Spaces:
Runtime error
Upload 73 files
This view is limited to 50 files because it contains too many changes.
- model/slam_model_s2s.py +444 -0
- s2s.py +178 -0
- s2s_config.py +272 -0
- slam_llm/__init__.py +0 -0
- slam_llm/data/__init__.py +2 -0
- slam_llm/data/concatenator.py +34 -0
- slam_llm/data/sampler.py +57 -0
- slam_llm/models/BEATs/BEATs.py +181 -0
- slam_llm/models/BEATs/Tokenizers.py +173 -0
- slam_llm/models/BEATs/backbone.py +783 -0
- slam_llm/models/BEATs/modules.py +219 -0
- slam_llm/models/BEATs/quantizer.py +215 -0
- slam_llm/models/EAT/EAT.py +32 -0
- slam_llm/models/SpatialAST/SpatialAST.py +122 -0
- slam_llm/models/SpatialAST/vision_transformer.py +239 -0
- slam_llm/models/avhubert/__init__.py +10 -0
- slam_llm/models/avhubert/decoder.py +243 -0
- slam_llm/models/avhubert/hubert.py +792 -0
- slam_llm/models/avhubert/hubert_asr.py +523 -0
- slam_llm/models/avhubert/hubert_criterion.py +169 -0
- slam_llm/models/avhubert/hubert_dataset.py +529 -0
- slam_llm/models/avhubert/hubert_pretraining.py +401 -0
- slam_llm/models/avhubert/infer_s2s.py +318 -0
- slam_llm/models/avhubert/resnet.py +169 -0
- slam_llm/models/avhubert/sequence_generator.py +985 -0
- slam_llm/models/avhubert/utils.py +298 -0
- slam_llm/models/encoder.py +158 -0
- slam_llm/models/musicfm/model/__init__.py +2 -0
- slam_llm/models/musicfm/model/musicfm_25hz.py +253 -0
- slam_llm/models/musicfm/modules/__init__.py +2 -0
- slam_llm/models/musicfm/modules/conv.py +82 -0
- slam_llm/models/musicfm/modules/features.py +45 -0
- slam_llm/models/musicfm/modules/flash_conformer.py +2114 -0
- slam_llm/models/musicfm/modules/random_quantizer.py +83 -0
- slam_llm/models/projector.py +81 -0
- slam_llm/models/slam_model.py +443 -0
- slam_llm/models/vallex/__init__.py +0 -0
- slam_llm/models/vallex/activation.py +179 -0
- slam_llm/models/vallex/scaling.py +1404 -0
- slam_llm/models/vallex/transformers.py +613 -0
- slam_llm/models/vallex/vallex_config.py +56 -0
- slam_llm/models/vallex/vallex_model.py +772 -0
- slam_llm/models/wavlm/WavLM.py +743 -0
- slam_llm/models/wavlm/modules.py +827 -0
- slam_llm/policies/__init__.py +7 -0
- slam_llm/policies/activation_checkpointing_functions.py +29 -0
- slam_llm/policies/anyprecision_optimizer.py +179 -0
- slam_llm/policies/mixed_precision.py +38 -0
- slam_llm/policies/wrapping.py +33 -0
- slam_llm/utils/__init__.py +7 -0
model/slam_model_s2s.py
ADDED
@@ -0,0 +1,444 @@
import torch
import os
import logging
import torch.nn.functional as F
from slam_llm.models.slam_model import (
    slam_model,
    setup_tokenizer,
    setup_encoder,
    setup_encoder_projector,
    setup_llm,
)
from slam_llm.utils.train_utils import print_model_size
from typing import List, Optional
from slam_llm.utils.metric import compute_accuracy
from transformers import T5ForConditionalGeneration
from tqdm import tqdm
from utils.tts_adapter_utils import setup_tts_adapter
from utils.codec_utils import setup_codec
from utils.trick_utils import partial_freeze_weights, train_embedding_layer_only
from utils.snac_utils import layershift

logger = logging.getLogger(__name__)


def model_factory(train_config, model_config, ckpt_path, **kwargs):
    # return necessary components for training
    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)

    if train_config.task_type == "s2s" or train_config.task_type == "asr":
        encoder = setup_encoder(train_config, model_config, **kwargs)
    elif train_config.task_type == "tts":
        encoder = None
    else:
        raise NotImplementedError

    # llm
    llm = setup_llm(train_config, model_config, **kwargs)

    # projector
    if encoder is not None:
        encoder_projector = setup_encoder_projector(
            train_config, model_config, **kwargs
        )
    else:
        encoder_projector = None

    codec_decoder = None
    if model_config.codec_decode:
        codec_decoder = setup_codec(train_config, model_config, **kwargs)

    tts_adapter = None
    if model_config.tts_adapter:
        adapter_config = model_config.tts_adapter_config
        tts_adapter = setup_tts_adapter(adapter_config, model_config, **kwargs)

    model = slam_model_s2s(
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        tts_adapter,
        codec_decoder,
        train_config,
        model_config,
        **kwargs,
    )

    if ckpt_path is not None:
        logger.info("loading other parts from: {}".format(ckpt_path))
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt_dict, strict=False)

    if train_config.train_audio_embed_only:
        partial_freeze_weights(model, model_config.vocab_config.padded_text_vocabsize, model_config.vocab_config.total_vocabsize)

    if train_config.train_embed_only:
        train_embedding_layer_only(model)

    print_model_size(
        model,
        train_config,
        (
            int(os.environ["RANK"])
            if train_config.enable_fsdp or train_config.enable_ddp
            else 0
        ),
    )
    return model, tokenizer


class slam_model_s2s(slam_model):
    def __init__(
        self,
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        tts_adapter,
        codec_decoder,
        train_config,
        model_config,
        **kwargs,
    ):
        super().__init__(
            encoder,
            llm,
            encoder_projector,
            tokenizer,
            train_config,
            model_config,
            **kwargs,
        )

        # resize llm embedding layer
        self.original_vocabsize = self.llm.lm_head.weight.size(0)
        if self.model_config.vocab_config.total_vocabsize != self.original_vocabsize:
            self.llm.resize_token_embeddings(self.model_config.vocab_config.total_vocabsize)

            if int(os.environ.get("RANK", "0")) == 0:
                logger.info("Resize llm embedding layer's vocab size to {}".format(self.model_config.vocab_config.total_vocabsize))

        self.codec_decoder = codec_decoder
        self.tts_adapter = tts_adapter
        self.code_layer = self.model_config.vocab_config.code_layer

    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                **kwargs,
                ):
        audio_mel = kwargs.get("audio_mel", None)
        audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None)  # 2x downsample for whisper

        audio = kwargs.get("audio", None)
        audio_mask = kwargs.get("audio_mask", None)

        modality_mask = kwargs.get("modality_mask", None)

        encoder_outs = None
        if audio_mel is not None or audio is not None:
            if self.train_config.freeze_encoder:  # freeze encoder
                self.encoder.eval()

            if self.model_config.encoder_name == "whisper":
                encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1))  # bs*seq*dim
            if self.model_config.encoder_name == "wavlm":
                encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask)  # (FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask
            if self.model_config.encoder_name == "hubert":
                results = self.encoder(source=audio, padding_mask=1 - audio_mask)
                if self.model_config.encoder_type == "pretrain":
                    encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"]
                if self.model_config.encoder_type == "finetune":
                    encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
                    encoder_outs = encoder_outs.transpose(0, 1)
            if self.encoder is None:
                encoder_outs = audio_mel if audio_mel is not None else audio

            if self.model_config.encoder_projector == "q-former":
                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)
            if self.model_config.encoder_projector == "cov1d-linear":
                encoder_outs = self.encoder_projector(encoder_outs)

        if input_ids is not None:
            input_ids[input_ids == -1] = 0  # [btz, 8, seq_length]

            if isinstance(self.llm, T5ForConditionalGeneration):
                inputs_embeds = self.llm.shared(input_ids)
            else:
                if hasattr(self.llm.model, "embed_tokens"):
                    inputs_embeds = self.llm.model.embed_tokens(input_ids)  # [btz, 8, seq_length, emb_dim]
                elif hasattr(self.llm.model.model, "embed_tokens"):
                    inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
                else:
                    inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)

        if modality_mask is not None and encoder_outs is not None:
            modality_mask = modality_mask.unsqueeze(1).repeat(1, self.code_layer, 1)  # [btz, 8, seq_length]
            modality_mask_start_indices = (modality_mask == True).float().argmax(dim=2)
            modality_lengths = torch.clamp(modality_mask.sum(dim=2), max=encoder_outs.shape[1]).tolist()

            encoder_outs_pad = torch.zeros_like(inputs_embeds)
            for i in range(encoder_outs.shape[0]):
                for j in range(self.code_layer):
                    start_idx = modality_mask_start_indices[i, j].item()
                    length = modality_lengths[i][j]
                    encoder_outs_pad[i, j, start_idx:start_idx + length] = encoder_outs[i, :length]

            inputs_embeds[:, :self.code_layer, :, :] = encoder_outs_pad[:, :self.code_layer, :, :] + inputs_embeds[:, :self.code_layer, :, :] * (~modality_mask[:, :, :, None])

        inputs_embeds = torch.mean(inputs_embeds, dim=1)  # [btz, seq_length, emb_dim], average over the 8 layers

        if kwargs.get("inference_mode", False):
            return inputs_embeds, attention_mask

        text_labels = labels[:, self.code_layer] if labels is not None else None
        audio_labels = labels[:, :self.code_layer] if labels is not None else None
        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=text_labels)  # here we use the text token layer as the target label

        # parallel generation
        # TODO: add tts adapter forward
        x_ori = model_outputs.logits
        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
        xt = x_ori[..., :text_vocab_size]
        xa = []
        for i in range(self.code_layer):
            xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])

        loss_recorder = []
        total_loss, loss_recorder = self.compute_parallel_loss(xt, text_labels, xa, audio_labels)
        model_outputs.loss = total_loss

        text_acc = -1
        audio_acc = [-1 for _ in range(self.code_layer)]
        if self.metric:
            with torch.no_grad():
                preds = torch.argmax(xt, -1)
                text_acc = compute_accuracy(preds.detach()[:, :-1], text_labels.detach()[:, 1:], ignore_label=-100)

                preds_audio = [torch.argmax(xa[i], -1) for i in range(self.code_layer)]
                audio_acc = [compute_accuracy(preds_audio[i].detach()[:, :-1], audio_labels[:, i, 1:], ignore_label=-100) for i in range(self.code_layer)]

        # metrics = {"text_acc": text_acc, "audio_acc": audio_acc, "layer_loss": loss_recorder}
        return model_outputs, text_acc, audio_acc, loss_recorder

    def compute_parallel_loss(self, xt, text_labels, xa, audio_labels):
        """
        Compute the parallel loss for text and audio layers.
        """
        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
        layer_loss = [0 for _ in range(self.code_layer + 1)]

        if text_labels is not None:
            # text_loss = F.cross_entropy(xt.reshape(-1, text_vocab_size), text_labels.reshape(-1), ignore_index=-100)
            text_loss = F.cross_entropy(xt[:, :-1, :].reshape(-1, text_vocab_size), text_labels[:, 1:].reshape(-1), ignore_index=-100)
            layer_loss[self.code_layer] = text_loss
        else:
            text_loss = 0

        total_audio_loss = 0
        single_audio_loss = 0
        for i in range(self.code_layer):
            if audio_labels[:, i] is not None:
                # audio_loss += F.cross_entropy(xa[i].reshape(-1, audio_vocab_size), audio_labels[:,i].reshape(-1), ignore_index=-100)
                single_audio_loss = F.cross_entropy(xa[i][:, :-1, :].reshape(-1, audio_vocab_size), audio_labels[:, i, 1:].reshape(-1), ignore_index=-100)
                layer_loss[i] = single_audio_loss
                total_audio_loss += single_audio_loss

        total_loss = (text_loss + total_audio_loss) / (self.code_layer + 1)
        return total_loss, layer_loss

    @torch.no_grad()
    def generate(self,
                 input_ids: torch.LongTensor = None,
                 attention_mask: Optional[torch.Tensor] = None,
                 position_ids: Optional[torch.LongTensor] = None,
                 past_key_values: Optional[List[torch.FloatTensor]] = None,
                 inputs_embeds: Optional[torch.FloatTensor] = None,
                 labels: Optional[torch.LongTensor] = None,
                 use_cache: Optional[bool] = None,
                 output_attentions: Optional[bool] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 **kwargs,
                 ):
        kwargs["inference_mode"] = True

        inputs_embeds, attention_mask = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        generated_ids = [[] for _ in range((self.code_layer + 1))]
        current_input_text = None
        current_audio_tokens = [None for _ in range(self.code_layer)]
        # input_pos = torch.arange(input_ids.size(-1), device=input_ids.device).unsqueeze(0)
        past_key_values = None

        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize

        max_new_tokens = kwargs.get("max_new_tokens", 360)
        repetition_penalty = kwargs.get("repetition_penalty", 1.0)
        decode_text_only = kwargs.get("decode_text_only", False)

        pad_t = self.model_config.vocab_config.pad_t
        pad_a = self.model_config.vocab_config.pad_a
        eot = self.model_config.vocab_config.eot
        eoa = self.model_config.vocab_config.eoa

        text_end = False   # Track whether text generation has ended
        audio_end = False  # Track whether audio generation has ended

        # NOTE: currently, we only support greedy decoding and sampling for parallel generation, no beam search
        for step in tqdm(range(max_new_tokens), desc="Generating"):
            if current_input_text is not None:
                audio_tokens = torch.cat([layershift(current_audio_tokens[i], i).unsqueeze(1) for i in range(self.code_layer)], dim=1)
                combined_input_ids = torch.cat([audio_tokens, current_input_text.unsqueeze(1)], dim=1)
                inputs_embeds = self.llm.model.embed_tokens(combined_input_ids)
                inputs_embeds = torch.mean(inputs_embeds, dim=1).unsqueeze(1)

            outputs = self.llm(
                inputs_embeds=inputs_embeds,    # [btz, seq_len / 1, emb_dim]
                attention_mask=attention_mask,  # single sample, no need for attention mask
                past_key_values=past_key_values,
                # position_ids=input_pos,
                use_cache=True,
            )

            logits = outputs.logits
            past_key_values = outputs.past_key_values  # Update past_key_values for the next step

            # Split logits into text and audio layers based on vocab size
            xt_logits = logits[..., :text_vocab_size]
            xa_logits = [logits[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)] for i in range(self.code_layer)]

            # Apply repetition penalty to the logits
            if repetition_penalty != 1.0:
                xt_logits = self.repetition_penalty(xt_logits, generated_ids[self.code_layer], repetition_penalty)
                for i in range(self.code_layer):
                    xa_logits[i] = self.repetition_penalty(xa_logits[i], generated_ids[i], repetition_penalty)

            if not text_end:
                next_token_text = self.sample_next_token(xt_logits[:, -1, :], **kwargs)
            else:
                next_token_text = torch.tensor([pad_t], device=input_ids.device)

            next_tokens_audio = []
            for i in range(self.code_layer):
                if not audio_end and not decode_text_only:
                    next_token_audio = self.sample_next_token(xa_logits[i][:, -1, :], **kwargs)
                else:
                    next_token_audio = torch.full((input_ids.size(0),), pad_a, device=input_ids.device)
                next_tokens_audio.append(next_token_audio)

            if next_tokens_audio[-1] == eoa or decode_text_only:
                audio_end = True
            if next_token_text == eot:
                text_end = True

            # Update input_ids for the next step
            current_input_text = next_token_text
            for i in range(self.code_layer):
                current_audio_tokens[i] = next_tokens_audio[i]

            # if input_pos.size(-1) > 1:
            #     input_pos = torch.tensor(input_pos.size(-1), device=input_ids.device).unsqueeze(0)
            # else:
            #     input_pos = input_pos.add_(1)
            attention_mask = torch.cat([attention_mask, torch.ones((input_ids.size(0), 1), device=input_ids.device)], dim=1)

            if audio_end and text_end:
                break

            # Append generated tokens to the list
            for i in range(self.code_layer):
                generated_ids[i].append(next_tokens_audio[i].clone().tolist()[0])  # Audio layers
            generated_ids[self.code_layer].append(next_token_text.clone().tolist()[0])  # Text layer

        # Concatenate the generated tokens to form the complete sequence
        text_tokens = generated_ids[-1]
        generated_ids[-1] = text_tokens[: text_tokens.index(eot)] if eot in text_tokens else text_tokens
        generated_ids = [torch.tensor(layer) for layer in generated_ids]
        return generated_ids

    @torch.no_grad()
    def sample_next_token(self, logits, **kwargs):
        """
        Generate the next token based on the model output logits.
        Supports greedy decoding, top-k sampling, and top-p (nucleus) sampling.
        """
        do_sample = kwargs.get("do_sample", False)
        temperature = kwargs.get("temperature", 1.0)
        top_k = kwargs.get("top_k", 50)
        top_p = kwargs.get("top_p", 1.0)
        num_samples = kwargs.get("num_samples", 1)

        # Adjust logits with temperature
        logits = logits.squeeze(0)
        logits = logits / temperature

        # Top-k filtering
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Make sure top_k is within the vocab size
            values, indices = torch.topk(logits, top_k)
            logits[logits < values[..., [-1]]] = -float('Inf')  # Filter tokens not in top_k

        # Top-p filtering (nucleus sampling)
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[indices_to_remove] = -float('Inf')

        if do_sample:
            # Perform sampling
            return torch.multinomial(F.softmax(logits, dim=-1), num_samples=num_samples)
        else:
            # Greedy decoding (argmax)
            return torch.argmax(logits, dim=-1, keepdim=True)

    def repetition_penalty(self, logits, generated_ids, repetition_penalty):
        """
        Apply repetition penalty to the logits.
        """
        for token_id in set(generated_ids):
            if logits[0, -1, token_id] < 0:
                logits[0, -1, token_id] *= repetition_penalty
            else:
                logits[0, -1, token_id] /= repetition_penalty

        return logits
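
Note: the forward and generate methods above assume a single fused output head: the first padded_text_vocabsize logits are the text layer, followed by code_layer contiguous blocks of padded_audio_vocabsize logits, one per SNAC codebook. A minimal sketch of that slicing (not part of the uploaded files), using the default sizes from s2s_config.py below:

import torch

# Assumption: defaults from s2s_config.VocabConfig are unchanged.
text_vocab_size = 151936 + 64   # padded_text_vocabsize
audio_vocab_size = 4096 + 64    # padded_audio_vocabsize
code_layer = 7

# Fake logits for a batch of 1 and 10 decoding steps.
logits = torch.randn(1, 10, text_vocab_size + audio_vocab_size * code_layer)

xt = logits[..., :text_vocab_size]                      # text-layer logits
xa = [logits[..., text_vocab_size + audio_vocab_size * i:
             text_vocab_size + audio_vocab_size * (i + 1)]
      for i in range(code_layer)]                       # one slice per audio codebook

assert xt.shape[-1] == text_vocab_size
assert all(x.shape[-1] == audio_vocab_size for x in xa)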
s2s.py
ADDED
@@ -0,0 +1,178 @@
import random
import torch
from slam_llm.utils.model_utils import get_custom_model_factory
from utils.snac_utils import reconscruct_snac, reconstruct_tensors, layershift
import whisper
import numpy as np
from s2s_config import InferenceConfig, CKPT_PATH, CKPT_REPO, CKPT_LOCAL_DIR, CKPT_NAME
import os
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from typing import Callable


def update_progress(progress_callback: Callable[[str], None] | None, message: str):
    if progress_callback:
        progress_callback(message)


def pull_model_ckpt():
    if not os.path.exists(CKPT_LOCAL_DIR):
        os.makedirs(CKPT_LOCAL_DIR)
    if os.path.exists(CKPT_PATH):
        return
    hf_hub_download(
        repo_id=CKPT_REPO,
        filename=CKPT_NAME,
        local_dir=CKPT_LOCAL_DIR,
        token=os.getenv("HF_TOKEN"),
    )


pull_model_ckpt()


def extract_audio_feature(audio_path, mel_size):
    print("Extracting audio features from", audio_path)
    audio_raw = whisper.load_audio(audio_path)
    audio_raw = whisper.pad_or_trim(audio_raw)
    audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size).permute(1, 0)
    audio_length = (audio_mel.shape[0] + 1) // 2
    audio_length = audio_length // 5
    audio_res = audio_mel

    return audio_res, audio_length


def get_input_ids(length, special_token_a, special_token_t, vocab_config):
    input_ids = []
    for i in range(vocab_config.code_layer):
        input_ids_item = []
        input_ids_item.append(layershift(vocab_config.input_a, i))
        input_ids_item += [layershift(vocab_config.pad_a, i)] * length
        input_ids_item += [
            (layershift(vocab_config.eoa, i)),
            layershift(special_token_a, i),
        ]
        input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
    input_id_T = torch.tensor(
        [vocab_config.input_t]
        + [vocab_config.pad_t] * length
        + [vocab_config.eot, special_token_t]
    )
    input_ids.append(input_id_T.unsqueeze(0))
    return input_ids


def generate_from_wav(
    wav_path, model, codec_decoder, dataset_config, decode_config, device
):
    mel_size = dataset_config.mel_size
    prompt = dataset_config.prompt
    prompt_template = "USER: {}\n ASSISTANT: "
    vocab_config = dataset_config.vocab_config
    special_token_a = vocab_config.answer_a
    special_token_t = vocab_config.answer_t
    code_layer = vocab_config.code_layer
    task_type = dataset_config.task_type

    audio_mel, audio_length = extract_audio_feature(wav_path, mel_size)

    prompt = prompt_template.format(prompt)
    prompt_ids = model.tokenizer.encode(prompt)
    prompt_length = len(prompt_ids)
    prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)

    example_ids = get_input_ids(
        audio_length + prompt_length, special_token_a, special_token_t, vocab_config
    )
    text_layer = example_ids[code_layer]
    text_layer = torch.cat(
        (
            text_layer[:, : audio_length + 1],
            prompt_ids.unsqueeze(0),
            text_layer[:, -2:],
        ),
        dim=1,
    )  # <bos> <audio> <prompt> <eos> <task>
    example_ids[code_layer] = text_layer

    input_length = audio_length
    example_mask = example_ids[0][0].ge(-1)
    example_ids = torch.stack(example_ids).squeeze()

    input_ids = example_ids.unsqueeze(0).to(device)
    attention_mask = example_mask.unsqueeze(0).to(device)
    audio_mel = audio_mel.unsqueeze(0).to(device)
    input_length = torch.tensor([input_length]).to(device)
    audio_length = torch.tensor([audio_length]).to(device)
    task_type = [task_type]

    modality_mask = torch.zeros_like(attention_mask)
    padding_left = 1  # +1 for <bos>
    modality_mask[0, padding_left : padding_left + audio_length] = True

    batch = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "audio_mel": audio_mel,
        "input_length": input_length,
        "audio_length": audio_length,
        "modality_mask": modality_mask,
        "task_types": task_type,
    }

    model_outputs = model.generate(**batch, **decode_config)
    text_outputs = model_outputs[7]
    audio_outputs = model_outputs[:7]
    output_text = model.tokenizer.decode(
        text_outputs, add_special_tokens=False, skip_special_tokens=True
    )

    if decode_config.decode_text_only:
        return None, output_text

    audio_tokens = [audio_outputs[layer] for layer in range(7)]
    audiolist = reconscruct_snac(audio_tokens)
    audio = reconstruct_tensors(audiolist)
    with torch.inference_mode():
        audio_hat = codec_decoder.decode(audio)

    return audio_hat, output_text


def generate(
    wav_path: str, progress_callback: Callable[[str], None] | None = None
) -> tuple[np.ndarray, int | float]:
    config = OmegaConf.structured(InferenceConfig())
    train_config, model_config, dataset_config, decode_config = (
        config.train_config,
        config.model_config,
        config.dataset_config,
        config.decode_config,
    )

    torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)

    update_progress(progress_callback, "Loading model")

    model_factory = get_custom_model_factory(model_config)
    model, _ = model_factory(train_config, model_config, CKPT_PATH)
    codec_decoder = model.codec_decoder
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    update_progress(progress_callback, "Generating")
    output_wav, output_text = generate_from_wav(
        wav_path, model, codec_decoder, dataset_config, decode_config, device
    )

    return output_wav.squeeze().cpu().numpy(), 24000


if __name__ == "__main__":
    wav_path = "sample.wav"
    generate(wav_path)
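
For reference, a hypothetical usage sketch of the generate() entry point above; soundfile is an assumed helper for writing the waveform and is not declared anywhere in this upload:

import soundfile as sf  # assumption: any wav writer would do
from s2s import generate

# generate() pulls the checkpoint on import, runs Whisper feature extraction,
# parallel text/audio decoding, and SNAC decoding, then returns (waveform, 24000).
waveform, sample_rate = generate("sample.wav", progress_callback=print)
sf.write("answer.wav", waveform, sample_rate)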
s2s_config.py
ADDED
@@ -0,0 +1,272 @@
from dataclasses import dataclass, field
from typing import Optional, List
import os

CKPT_NAME = "model.pt"
CKPT_LOCAL_DIR = "model_ckpts"
CKPT_PATH = os.path.join(CKPT_LOCAL_DIR, CKPT_NAME)
CKPT_REPO = "xcczach/mini-omni"


@dataclass
class VocabConfig:
    text_vocabsize: int = 151936
    text_specialtokens: int = 64
    audio_vocabsize: int = 4096
    audio_specialtokens: int = 64
    total_vocabsize: int = 181120
    code_layer: int = 7

    padded_text_vocabsize: int = field(init=False)
    padded_audio_vocabsize: int = field(init=False)
    total_audio_vocabsize: int = field(init=False)

    eot: int = field(init=False)       # end of text token
    pad_t: int = field(init=False)     # padding text token
    input_t: int = field(init=False)   # input text token
    answer_t: int = field(init=False)  # answer text token
    asr: int = field(init=False)       # ASR token

    eoa: int = field(init=False)       # end of audio token
    pad_a: int = field(init=False)     # padding audio token
    input_a: int = field(init=False)   # input audio token
    answer_a: int = field(init=False)  # answer audio token
    split: int = field(init=False)     # split token

    def __post_init__(self):
        self.padded_text_vocabsize = self.text_vocabsize + self.text_specialtokens
        self.padded_audio_vocabsize = self.audio_vocabsize + self.audio_specialtokens
        self.total_audio_vocabsize = self.padded_audio_vocabsize * self.code_layer

        self.eot = self.text_vocabsize
        self.pad_t = self.text_vocabsize + 1
        self.input_t = self.text_vocabsize + 2
        self.answer_t = self.text_vocabsize + 3
        self.asr = self.text_vocabsize + 4

        self.eoa = self.audio_vocabsize
        self.pad_a = self.audio_vocabsize + 1
        self.input_a = self.audio_vocabsize + 2
        self.answer_a = self.audio_vocabsize + 3
        self.split = self.audio_vocabsize + 4


@dataclass
class TTSAdapterConfig:
    add_qkv_bias: Optional[bool] = True
    bias: bool = False
    gelu_approximate: Optional[str] = None
    head_size: Optional[int] = 64
    intermediate_size: Optional[int] = 4864
    lm_head_bias: bool = False
    mlp_class_name: str = "GptNeoxMLP"
    n_layer: int = 6
    n_head: int = 14
    n_embd: int = 896
    n_query_groups: Optional[int] = 2
    norm_class_name: str = "RMSNorm"
    norm_eps: float = 1e-6
    parallel_residual: bool = False
    rotary_percentage: float = 1
    shared_attention_norm: bool = False

    def __post_init__(self):
        self.rope_n_elem = int(self.rotary_percentage * self.head_size)


@dataclass
class ModelConfig:
    file: str = "model/slam_model_s2s.py:model_factory"
    llm_name: str = "qwen2-0.5b"
    llm_path: str = "Qwen/Qwen2-0.5B"
    llm_type: str = "decoder_only"
    llm_dim: int = 896
    encoder_name: Optional[str] = "whisper"
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = "small"
    encoder_dim: int = 768
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether input is normalized, used for models such as wavlm"},
    )
    encoder_type: str = field(
        default="finetune",
        metadata={
            "help": "whether model is only pretrained or finetuned, used for models such as hubert"
        },
    )
    vocab_config: VocabConfig = field(default_factory=VocabConfig)
    codec_decode: bool = True
    codec_decoder_type: str = "SNAC"
    codec_decoder_path: Optional[str] = "hubertsiuzdak/snac_24khz"
    tts_adapter: bool = False
    tts_adapter_config: TTSAdapterConfig = field(default_factory=TTSAdapterConfig)


@dataclass
class PeftConfig:
    peft_method: str = "lora"  # None, llama_adapter, prefix
    r: int = 8
    lora_alpha: int = 32
    target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False


@dataclass
class TrainConfig:
    model_name: str = "s2s"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(
        default="custom", metadata={"help": "alternative: padding"}
    )
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 1
    num_workers_dataloader: int = 2
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1

    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = (
        "PATH/to/save/FSDP/model"  # will be used if using FSDP
    )
    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = (
        False  # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
    )
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|S2S|>"
    freeze_llm: bool = field(
        default=True,
        metadata={
            "help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
        },
    )
    freeze_encoder: bool = True
    train_embed_only: bool = False
    train_audio_embed_only: bool = False
    task_type: str = "s2s"


@dataclass
class DataConfig:
    dataset: str = "speech_dataset_s2s"
    file: str = "examples/s2s/speech_dataset_s2s.py:get_speech_dataset"
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    train_split: str = "train"
    test_split: str = "validation"
    prompt: Optional[str] = None
    data_path: Optional[str] = None
    max_words: Optional[int] = None
    max_mel: Optional[float] = None
    fix_length_audio: int = -1
    inference_mode: bool = True
    input_type: str = field(
        default="mel",
        metadata={"help": "Use raw when input is wav, mel when for whisper"},
    )
    mel_size: int = field(
        default=80, metadata={"help": "80 for whisper large v1 and v2, 128 for v3"}
    )
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether input is normalized, used for models such as wavlm"},
    )
    seed: int = 42
    manifest_format: str = field(
        default="datasets", metadata={"help": "alternative: jsonl"}
    )
    split_size: float = 0.1

    vocab_config: VocabConfig = field(default_factory=VocabConfig)
    load_from_cache_file: bool = False
    task_type: str = "s2s"


@dataclass
class DecodeConfig:
    do_sample: bool = False
    max_new_tokens: int = 300
    min_length: int = 10
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 0.9
    num_beams: int = 1
    num_return_sequences: int = 1
    num_samples: int = 1
    max_time: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    early_stopping: bool = False
    no_repeat_ngram_size: int = 0
    bad_words_ids: List = field(default_factory=list)
    num_beam_groups: int = 1
    diversity_penalty: float = 0.0
    task_type: str = "s2s"
    decode_text_only: bool = False


@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD"  # ShardingStrategy.FULL_SHARD
    sharding_strategy: str = (
        "NO_SHARD"  # ShardingStrategy.NO_SHARD  # MZY: set NO_SHARD when using DDP
    )
    checkpoint_type: str = (
        "SHARDED_STATE_DICT"  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world-size.
    )
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"


@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "/valleblob/v-wenxichen/exp/wandb_log"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "/valleblob/v-wenxichen/exp/log/test.log"
    log_interval: int = 10
    online_output_dir: Optional[str] = None


@dataclass
class InferenceConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    decode_config: DecodeConfig = field(default_factory=DecodeConfig)
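
A small sketch (not part of the uploaded files) of the token-id arithmetic that VocabConfig.__post_init__ derives from the defaults above:

from s2s_config import VocabConfig

v = VocabConfig()
print(v.padded_text_vocabsize)     # 151936 + 64 = 152000
print(v.padded_audio_vocabsize)    # 4096 + 64   = 4160
print(v.total_audio_vocabsize)     # 4160 * 7    = 29120
print(v.eot, v.pad_t, v.answer_t)  # 151936, 151937, 151939
print(v.eoa, v.pad_a, v.answer_a)  # 4096, 4097, 4099
# 152000 + 29120 = 181120 == total_vocabsize, the size the LLM embedding is resized to.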
slam_llm/__init__.py
ADDED
File without changes
slam_llm/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
slam_llm/data/concatenator.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from tqdm import tqdm
from itertools import chain

from torch.utils.data import Dataset


class ConcatDataset(Dataset):
    def __init__(self, dataset, chunk_size=4096):
        self.dataset = dataset
        self.chunk_size = chunk_size

        self.samples = []

        buffer = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
        }

        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
            buffer = {k: v + sample[k] for k, v in buffer.items()}

            while len(next(iter(buffer.values()))) > self.chunk_size:
                self.samples.append({k: v[:self.chunk_size] for k, v in buffer.items()})
                buffer = {k: v[self.chunk_size:] for k, v in buffer.items()}

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)
slam_llm/data/sampler.py
ADDED
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import random
from itertools import islice

import numpy as np
import torch


class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool = True) -> None:
        if isinstance(next(iter(data_source)), dict):
            first_key = next(iter(next(iter(data_source)).keys()))
            self.lengths = [len(d[first_key]) for d in data_source]
        else:
            self.lengths = [len(d) for d in data_source]
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle

    def __iter__(self):
        ids = np.argsort(self.lengths)
        if self.drop_last:
            ids = ids[:len(ids) // self.batch_size * self.batch_size]

        batches = [ids[i:i + self.batch_size] for i in range(0, len(ids), self.batch_size)]

        if self.shuffle:
            random.shuffle(batches)

        for b in batches:
            yield b

    def __len__(self):
        if self.drop_last:
            return len(self.lengths) // self.batch_size
        else:
            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)


class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
    def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
        random.seed(seed)
        self.batch_sampler = LengthBasedBatchSampler(
            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
        )
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)

    def __len__(self):
        return len(self.batch_sampler) // self.num_replicas
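
A toy usage sketch (not part of the uploaded files) showing how LengthBasedBatchSampler groups samples of similar length:

import torch
from slam_llm.data.sampler import LengthBasedBatchSampler

data = [torch.zeros(n) for n in [5, 3, 9, 4, 8, 2]]  # six samples of varying length
sampler = LengthBasedBatchSampler(data, batch_size=2, drop_last=False, shuffle=False)
# Indices come out sorted by length and chunked into pairs: (5, 1), (3, 0), (4, 2).
for batch in sampler:
    print(batch)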
slam_llm/models/BEATs/BEATs.py
ADDED
@@ -0,0 +1,181 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------


import torch
import torch.nn as nn
from torch.nn import LayerNorm
import torchaudio.compliance.kaldi as ta_kaldi

from .backbone import (
    TransformerEncoder,
)

import logging
from typing import Optional

logger = logging.getLogger(__name__)


class BEATsConfig:
    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_wise_gradient_decay_ratio: float = 1.0  # ratio for layer-wise gradient decay
        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # label predictor
        self.finetuned_model: bool = False  # whether the model is a fine-tuned model.
        self.predictor_dropout: float = 0.1  # dropout probability for the predictor
        self.predictor_class: int = 527  # target class number for the predictor

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)


class BEATs(nn.Module):
    def __init__(
        self,
        cfg: BEATsConfig,
    ) -> None:
        super().__init__()
        logger.info(f"BEATs Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
                                         bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        if cfg.finetuned_model:
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
        else:
            self.predictor = None

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    @classmethod
    def preprocess(
        cls,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        if len(source.shape) > 1:  # batch
            fbanks = []
            for waveform in source:
                waveform = waveform.unsqueeze(0) * 2 ** 15
                fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
                fbanks.append(fbank)
            fbank = torch.stack(fbanks, dim=0)
        else:  # single
            waveform = source.unsqueeze(0) * 2 ** 15
            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)

        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_features(
        self,
        fbank: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ):
        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        if self.predictor is not None:
            x = self.predictor_dropout(x)
            logits = self.predictor(x)

            if padding_mask is not None and padding_mask.any():
                logits[padding_mask] = 0
                logits = logits.sum(dim=1)
                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
            else:
                logits = logits.mean(dim=1)

            lprobs = torch.sigmoid(logits)

            return lprobs, padding_mask
        else:
            return x, padding_mask
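
A brief sketch (not part of the uploaded files) of the BEATs.preprocess classmethod above, which converts 16 kHz waveforms into the normalized 128-bin fbank features expected by extract_features:

import torch
from slam_llm.models.BEATs.BEATs import BEATs

waveform = torch.randn(2, 16000)    # batch of two 1-second clips at 16 kHz
fbank = BEATs.preprocess(waveform)  # (2, num_frames, 128), mean/std-normalized
print(fbank.shape)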
slam_llm/models/BEATs/Tokenizers.py
ADDED
@@ -0,0 +1,173 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------


import torch
import torch.nn as nn
from torch.nn import LayerNorm
import torchaudio.compliance.kaldi as ta_kaldi

from .backbone import (
    TransformerEncoder,
)
from .quantizer import (
    NormEMAVectorQuantizer,
)

import logging
from typing import Optional

logger = logging.getLogger(__name__)


class TokenizersConfig:
    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # quantizer
        self.quant_n: int = 1024  # codebook number in quantizer
        self.quant_dim: int = 256  # codebook dimension in quantizer

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)


class Tokenizers(nn.Module):
    def __init__(
        self,
        cfg: TokenizersConfig,
    ) -> None:
        super().__init__()
        logger.info(f"Tokenizers Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
                                         bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.quantize = NormEMAVectorQuantizer(
            n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
        )
        self.quant_n = cfg.quant_n
        self.quantize_layer = nn.Sequential(
            nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
            nn.Tanh(),
            nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim)  # for quantize
        )

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        fbanks = []
        for waveform in source:
            waveform = waveform.unsqueeze(0) * 2 ** 15
            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_labels(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ):
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        quantize_input = self.quantize_layer(x)
        quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)

        return embed_ind
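A minimal sketch of driving the tokenizer end to end. The import path follows the layout of this upload, the patch size is a toy placeholder, and a real run would load cfg and weights from a released BEATs tokenizer checkpoint instead of random initialization (the k-means codebook init also assumes einops is installed):

import torch
from slam_llm.models.BEATs.Tokenizers import Tokenizers, TokenizersConfig  # path as uploaded here

cfg = TokenizersConfig({"input_patch_size": 16})   # illustrative value; normally taken from a checkpoint
tokenizer = Tokenizers(cfg).eval()

waveform = torch.randn(2, 16000)                   # fake 16 kHz batch
with torch.no_grad():
    codes = tokenizer.extract_labels(waveform)     # discrete codebook indices from the VQ layer
print(codes.shape, codes.dtype)                    # flattened (batch * frames,) LongTensor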
slam_llm/models/BEATs/backbone.py
ADDED
@@ -0,0 +1,783 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import math
import numpy as np
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn import LayerNorm, Parameter
from .modules import (
    GradMultiply,
    SamePad,
    get_activation_fn,
    GLU_Linear,
    quant_noise,
)


class TransformerEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        if hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = args.relative_position_embedding
            self.num_buckets = args.num_buckets
            self.max_distance = args.max_distance
        else:
            self.relative_position_embedding = False
            self.num_buckets = 0
            self.max_distance = 0

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    deep_norm=args.deep_norm,
                    has_relative_attention_bias=self.relative_position_embedding,
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                    encoder_layers=args.encoder_layers,
                )
                for i in range(args.encoder_layers)
            ]
        )
        if self.relative_position_embedding:
            for i in range(1, args.encoder_layers):
                del self.layers[i].self_attn.relative_attention_bias
                self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

        if args.deep_norm:
            deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
            for i in range(args.encoder_layers):
                nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
                nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
                nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)

        self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)

    def forward(self, x, padding_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):

        if padding_mask is not None:
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        z = None
        if tgt_layer is not None:
            layer_results.append((x, z))
        r = None
        pos_bias = None
        for i, layer in enumerate(self.layers):
            if self.layer_wise_gradient_decay_ratio != 1.0:
                x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
            if tgt_layer is not None:
                layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results


class TransformerSentenceEncoderLayer(nn.Module):
    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: float = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        deep_norm: bool = False,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 0,
        max_distance: int = 0,
        rescale_init: bool = False,
        gru_rel_pos: bool = False,
        encoder_layers: int = 0,
    ) -> None:

        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        if self.activation_name == "glu":
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        self.final_layer_norm = LayerNorm(self.embedding_dim)

        self.deep_norm = deep_norm
        if self.deep_norm:
            self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
        else:
            self.deep_norm_alpha = 1

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        pos_bias=None
    ):
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )

            x = self.dropout1(x)
            x = residual * self.deep_norm_alpha + x

            x = self.self_attn_layer_norm(x)

            residual = x
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual * self.deep_norm_alpha + x
            x = self.final_layer_norm(x)

        return x, attn, pos_bias


class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        has_relative_attention_bias=False,
        num_buckets=32,
        max_distance=128,
        gru_rel_pos=False,
        rescale_init=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = nn.Dropout(dropout)

        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

        self.head_dim = embed_dim // num_heads
        self.q_head_dim = self.head_dim
        self.k_head_dim = self.head_dim
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        k_bias = True
        if rescale_init:
            k_bias = False

        k_embed_dim = embed_dim
        q_embed_dim = embed_dim

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            self.grep_linear = nn.Linear(self.q_head_dim, 8)
            self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
        if self.has_relative_attention_bias:
            nn.init.xavier_normal_(self.relative_attention_bias.weight)

    def _relative_positions_bucket(self, relative_positions, bidirectional=True):
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        relative_buckets = 0

        if bidirectional:
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_postion_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(
            relative_position,
            bidirectional=True
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([2, 0, 1])
        return values

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        position_bias: Optional[Tensor] = None
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert src_len, bsz == value.shape[:2]

        if self.has_relative_attention_bias and position_bias is None:
            position_bias = self.compute_bias(tgt_len, src_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling
        alpha = 32
        q *= 1 / alpha

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.k_head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v, position_bias

        if position_bias is not None:
            attn_mask_rel_pos = position_bias
            if self.gru_rel_pos == 1:
                query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
                _B, _H, _L, __ = query_layer.size()
                gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
                    _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias

            attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())

            attn_weights = attn_weights + attn_mask_rel_pos

        attn_weights_float = F.softmax(
            attn_weights, dim=-1
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights, position_bias

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(), filler.float()], dim=1
                )
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1
                )
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights


def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set then weights of linear
       layer will be initialized using the normal distribution and
       bias will be set to the specified value.
    2. If normal_init_embed_weights is set then weights of embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then weights of
       in_project_weight for MultiHeadAttention initialized using
       the normal distribution (to be validated).
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(
            data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
        )

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)
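A small usage sketch for the encoder above. The args namespace is a hypothetical stand-in covering only the fields TransformerEncoder reads, with illustrative values (2 layers instead of 12), and the import path follows the layout of this upload:

import types
import torch
from slam_llm.models.BEATs.backbone import TransformerEncoder  # path as uploaded here

args = types.SimpleNamespace(
    dropout=0.1, encoder_embed_dim=768, conv_pos=128, conv_pos_groups=16,
    relative_position_embedding=True, num_buckets=320, max_distance=1280,
    encoder_layers=2, encoder_ffn_embed_dim=3072, encoder_attention_heads=12,
    attention_dropout=0.1, activation_dropout=0.0, activation_fn="gelu",
    layer_norm_first=False, deep_norm=False, gru_rel_pos=False, encoder_layerdrop=0.0,
)
encoder = TransformerEncoder(args).eval()

x = torch.randn(2, 50, 768)                 # B x T x C features (e.g. patch embeddings)
with torch.no_grad():
    out, layer_results = encoder(x)
print(out.shape)                            # torch.Size([2, 50, 768])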
slam_llm/models/BEATs/modules.py
ADDED
@@ -0,0 +1,219 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import math
import warnings
import torch
from torch import Tensor, nn
import torch.nn.functional as F


class GradMultiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None


class SamePad(nn.Module):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()
        self.act = torch.nn.Sigmoid()

    def forward(self, x):
        return x * self.act(x)


class GLU_Linear(nn.Module):
    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
        super(GLU_Linear, self).__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim

        if glu_type == "sigmoid":
            self.glu_act = torch.nn.Sigmoid()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "relu":
            self.glu_act = torch.nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = torch.nn.GELU()

        if bias_in_glu:
            self.linear = nn.Linear(input_dim, output_dim * 2, True)
        else:
            self.linear = nn.Linear(input_dim, output_dim * 2, False)

    def forward(self, x):
        # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
        x = self.linear(x)

        if self.glu_type == "bilinear":
            x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
        else:
            x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))

        return x


def gelu_accurate(x):
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return (
        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
    )


def gelu(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x.float()).type_as(x)


def get_activation_fn(activation: str):
    """Returns the activation function corresponding to `activation`"""

    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        warnings.warn(
            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
        )
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "glu":
        return lambda x: x
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))


def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if mod.training:
            if not is_conv:
                # gather weight and sizes
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                # split weight matrix into blocks and randomly drop selected blocks
                mask = torch.zeros(
                    in_features // block_size * out_features, device=weight.device
                )
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

            else:
                # gather weight and sizes
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                # split weight matrix into blocks and randomly drop selected blocks
                if mod.kernel_size == (1, 1):
                    mask = torch.zeros(
                        int(in_channels // block_size * out_channels),
                        device=weight.device,
                    )
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                else:
                    mask = torch.zeros(
                        weight.size(0), weight.size(1), device=weight.device
                    )
                    mask.bernoulli_(p)
                    mask = (
                        mask.unsqueeze(2)
                        .unsqueeze(3)
                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                    )

            # scale weights and apply mask
            mask = mask.to(
                torch.bool
            )  # x.bool() is not currently supported in TorchScript
            s = 1 / (1 - p)
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
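A short sketch exercising the helpers above; shapes and values are illustrative and the import path assumes the layout of this upload:

import torch
from slam_llm.models.BEATs.modules import GLU_Linear, quant_noise, get_activation_fn  # path as uploaded here

# GLU_Linear projects to 2 * output_dim and gates one half with the chosen activation.
glu = GLU_Linear(input_dim=16, output_dim=32, glu_type="swish")
y = glu(torch.randn(4, 10, 16))
print(y.shape)                                     # torch.Size([4, 10, 32])

# quant_noise registers a training-time pre-hook that randomly zeroes weight blocks.
layer = quant_noise(torch.nn.Linear(16, 16), p=0.1, block_size=8)

print(get_activation_fn("gelu")(torch.randn(3)))   # fp32-safe gelu from this module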
slam_llm/models/BEATs/quantizer.py
ADDED
@@ -0,0 +1,215 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on VQGAN code bases
|
7 |
+
# https://github.com/CompVis/taming-transformers
|
8 |
+
# --------------------------------------------------------'
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torch.distributed as distributed
|
14 |
+
|
15 |
+
try:
|
16 |
+
from einops import rearrange, repeat
|
17 |
+
except ImportError:
|
18 |
+
pass
|
19 |
+
|
20 |
+
|
21 |
+
def l2norm(t):
|
22 |
+
return F.normalize(t, p=2, dim=-1)
|
23 |
+
|
24 |
+
|
25 |
+
def ema_inplace(moving_avg, new, decay):
|
26 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
27 |
+
|
28 |
+
|
29 |
+
def sample_vectors(samples, num):
|
30 |
+
num_samples, device = samples.shape[0], samples.device
|
31 |
+
|
32 |
+
if num_samples >= num:
|
33 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
34 |
+
else:
|
35 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
36 |
+
|
37 |
+
return samples[indices]
|
38 |
+
|
39 |
+
|
40 |
+
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
|
41 |
+
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
|
42 |
+
|
43 |
+
means = sample_vectors(samples, num_clusters)
|
44 |
+
|
45 |
+
for _ in range(num_iters):
|
46 |
+
if use_cosine_sim:
|
47 |
+
dists = samples @ means.t()
|
48 |
+
else:
|
49 |
+
diffs = rearrange(samples, 'n d -> n () d') \
|
50 |
+
- rearrange(means, 'c d -> () c d')
|
51 |
+
dists = -(diffs ** 2).sum(dim=-1)
|
52 |
+
|
53 |
+
buckets = dists.max(dim=-1).indices
|
54 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
55 |
+
zero_mask = bins == 0
|
56 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
57 |
+
|
58 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
59 |
+
new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
|
60 |
+
new_means = new_means / bins_min_clamped[..., None]
|
61 |
+
|
62 |
+
if use_cosine_sim:
|
63 |
+
new_means = l2norm(new_means)
|
64 |
+
|
65 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
66 |
+
|
67 |
+
return means, bins
|
68 |
+
|
69 |
+
|
70 |
+
class EmbeddingEMA(nn.Module):
|
71 |
+
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
|
72 |
+
super().__init__()
|
73 |
+
self.num_tokens = num_tokens
|
74 |
+
self.codebook_dim = codebook_dim
|
75 |
+
self.decay = decay
|
76 |
+
self.eps = eps
|
77 |
+
if codebook_init_path == '':
|
78 |
+
if not kmeans_init:
|
79 |
+
weight = torch.randn(num_tokens, codebook_dim)
|
80 |
+
weight = l2norm(weight)
|
81 |
+
else:
|
82 |
+
weight = torch.zeros(num_tokens, codebook_dim)
|
83 |
+
            self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        else:
            print(f"load init codebook weight from {codebook_init_path}")
            codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
            weight = codebook_ckpt_weight.clone()
            self.register_buffer('initted', torch.Tensor([True]))

        self.weight = nn.Parameter(weight, requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
        # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        self.update = True

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.initted:
            return
        print("Performing K-means init for codebook")
        embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
        self.weight.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def forward(self, embed_id):
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_embed_avg):
        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        n = self.cluster_size.sum()
        smoothed_cluster_size = (
            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
        )
        # normalize embedding average with smoothed cluster size
        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
        # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
        self.weight.data.copy_(embed_normalized)


def norm_ema_inplace(moving_avg, new, decay):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
    moving_avg.data.copy_(l2norm(moving_avg.data))


class NormEMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
        super().__init__()
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.decay = decay

        # learnable = True if orthogonal_reg_weight > 0 else False
        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)

        self.statistic_code_usage = statistic_code_usage
        if statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(n_embed))
        if distributed.is_available() and distributed.is_initialized():
            print("ddp is enabled, so use ddp all_reduce to sync the statistic_code_usage across gpus!")
            self.all_reduce_fn = distributed.all_reduce
        else:
            self.all_reduce_fn = nn.Identity()

    def reset_cluster_size(self, device):
        if self.statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
            self.cluster_size = self.cluster_size.to(device)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        # z, 'b c h w -> b h w c'
        # z = rearrange(z, 'b c h w -> b h w c')
        # z = z.transpose(1, 2)
        z = l2norm(z)
        z_flattened = z.reshape(-1, self.codebook_dim)

        self.embedding.init_embed_(z_flattened)

        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)

        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)

        if not self.training:
            with torch.no_grad():
                cluster_size = encodings.sum(0)
                self.all_reduce_fn(cluster_size)
                ema_inplace(self.cluster_size, cluster_size, self.decay)

        if self.training and self.embedding.update:
            # EMA cluster size

            bins = encodings.sum(0)
            self.all_reduce_fn(bins)

            # self.embedding.cluster_size_ema_update(bins)
            ema_inplace(self.cluster_size, bins, self.decay)

            zero_mask = (bins == 0)
            bins = bins.masked_fill(zero_mask, 1.)

            embed_sum = z_flattened.t() @ encodings
            self.all_reduce_fn(embed_sum)

            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
            embed_normalized = l2norm(embed_normalized)

            embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
                                           embed_normalized)
            norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        # z_q, 'b h w c -> b c h w'
        # z_q = rearrange(z_q, 'b h w c -> b c h w')
        # z_q = z_q.transpose(1, 2)
        return z_q, loss, encoding_indices
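A minimal usage sketch (not part of the upload) for the EMA codebook above; the import path follows the file list in this commit, and the codebook sizes and hyper-parameters are illustrative assumptions only.

# Hypothetical usage sketch: quantize a batch of frame embeddings with the
# NormEMAVectorQuantizer defined in this file. Shapes and sizes are made up.
import torch
from slam_llm.models.BEATs.quantizer import NormEMAVectorQuantizer

quantizer = NormEMAVectorQuantizer(
    n_embed=1024,       # codebook size (assumed for illustration)
    embedding_dim=256,  # dimension of each code vector
    beta=1.0,           # weight of the commitment loss
    decay=0.99,         # EMA decay for codebook updates
    kmeans_init=False,  # random, l2-normalized codebook init
)

z = torch.randn(8, 100, 256)            # [batch, frames, embedding_dim]
z_q, commit_loss, codes = quantizer(z)  # z_q: [8, 100, 256]; codes: flat indices of shape [8*100]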
slam_llm/models/EAT/EAT.py
ADDED
@@ -0,0 +1,32 @@
import torch
import torchaudio
import random

def EAT_preprocess(source, norm_mean = -4.268, norm_std = 4.569, target_length = 1024, fixed_length = False, random_crop = False):
    source = source - source.mean()
    source = source.unsqueeze(dim=0)

    source = torchaudio.compliance.kaldi.fbank(source, htk_compat=True, sample_frequency=16000, use_energy=False,
                                               window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10).unsqueeze(dim=0)

    n_frames = source.shape[1]
    if not fixed_length:
        target_length = n_frames
        if target_length % 16 != 0:
            target_length = n_frames + (16 - n_frames % 16)
    diff = target_length - n_frames
    if diff > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, diff))
        source = m(source)
    elif diff < 0:
        if random_crop:
            start_index = random.randint(0, n_frames - target_length)
            source = source[:, start_index: start_index+target_length, :]
        else:
            source = source[:, 0:target_length, :]

    # Normalize the mel spectrogram
    source = (source - norm_mean) / (norm_std * 2)
    source = source.squeeze()

    return source
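A minimal usage sketch (not part of the upload) for EAT_preprocess; the 10-second mono 16 kHz waveform is an assumption for illustration.

# Hypothetical usage sketch: turn a mono 16 kHz waveform into the normalized
# 128-bin mel spectrogram expected by the EAT encoder.
import torch
from slam_llm.models.EAT.EAT import EAT_preprocess

waveform = torch.randn(16000 * 10)                   # ~10 s of mono audio at 16 kHz (made up)
mel = EAT_preprocess(waveform, target_length=1024,   # pad/crop to 1024 frames
                     fixed_length=True, random_crop=False)
print(mel.shape)                                     # roughly (1024, 128)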
slam_llm/models/SpatialAST/SpatialAST.py
ADDED
@@ -0,0 +1,122 @@
import torch
import torch.nn as nn

from torchlibrosa.stft import STFT, LogmelFilterBank
from timm.models.layers import to_2tuple

from .vision_transformer import VisionTransformer as _VisionTransformer

def conv3x3(in_channels, out_channels, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)

class PatchEmbed_new(nn.Module):
    """ Flexible Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)  # with overlapped patches
        _, _, h, w = self.get_output_shape(img_size)  # n, emb_dim, h, w
        self.patch_hw = (h, w)
        self.num_patches = h*w

    def get_output_shape(self, img_size):
        return self.proj(torch.randn(1, self.in_chans, img_size[0], img_size[1])).shape

    def forward(self, x):
        B, C, H, W = x.shape

        x = self.proj(x)       # 32, 1, 1024, 128 -> 32, 768, 101, 12
        x = x.flatten(2)       # 32, 768, 101, 12 -> 32, 768, 1212
        x = x.transpose(1, 2)  # 32, 768, 1212 -> 32, 1212, 768
        return x

class BinauralEncoder(_VisionTransformer):
    """ Spatial Audio Spectrogram Transformer designed for Sound Event Localization and Detection
        --------------------------------------------------------
        References:
        Spatial-AST from BAT: https://github.com/zszheng147/Spatial-AST and https://arxiv.org/abs/2402.01591
        --------------------------------------------------------
    """
    def __init__(self, num_cls_tokens=3, **kwargs):
        super(BinauralEncoder, self).__init__(**kwargs)
        img_size = (1024, 128)  # 1024, 128
        in_chans = 1
        emb_dim = 768

        del self.cls_token
        self.num_cls_tokens = num_cls_tokens
        self.cls_tokens = nn.Parameter(torch.zeros(1, num_cls_tokens, emb_dim))

        self.patch_embed = PatchEmbed_new(
            img_size=img_size, patch_size=(16, 16),
            in_chans=in_chans, embed_dim=emb_dim, stride=16
        )  # no overlap. stride=img_size=16

        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False)  # fixed sin-cos embedding

        self.spectrogram_extractor = STFT(
            n_fft=1024, hop_length=320, win_length=1024, window='hann',
            center=True, pad_mode='reflect', freeze_parameters=True
        )
        self.logmel_extractor = LogmelFilterBank(
            sr=32000, n_fft=1024, n_mels=128, fmin=50,
            fmax=14000, ref=1.0, amin=1e-10, top_db=None, freeze_parameters=True
        )

        self.conv_downsample = nn.Sequential(
            conv3x3(4, 1),
            nn.BatchNorm2d(1),
            nn.GELU(),
        )

        self.bn = nn.BatchNorm2d(2, affine=False)
        del self.norm  # remove the original norm

        self.target_frame = 1024

    def forward_features_mask(self, x):
        B = x.shape[0]  # bsz, 512, 768 (unmasked)

        x = x + self.pos_embed[:, 1:, :]

        cls_tokens = self.cls_tokens
        cls_tokens = cls_tokens.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # bsz, 512 + 2 + 10, 768
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    @torch.no_grad()
    def forward(self, waveforms):
        B, C, T = waveforms.shape

        waveforms = waveforms.reshape(B * C, T)
        real, imag = self.spectrogram_extractor(waveforms)

        log_mel = self.logmel_extractor(torch.sqrt(real**2 + imag**2)).reshape(B, C, -1, 128)
        log_mel = self.bn(log_mel)

        IPD = torch.atan2(imag[1::2], real[1::2]) - torch.atan2(imag[::2], real[::2])
        x = torch.cat([log_mel, torch.matmul(torch.cat([torch.cos(IPD), torch.sin(IPD)], dim=1), self.logmel_extractor.melW)], dim=1)

        if x.shape[2] < self.target_frame:
            x = nn.functional.interpolate(x, (self.target_frame, x.shape[3]), mode="bicubic", align_corners=True)

        x = self.conv_downsample(x)
        x = self.patch_embed(x)
        x = self.forward_features_mask(x)

        return x
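The forward pass above builds inter-channel phase difference (IPD) features from the stereo STFT before fusing them with the log-mel channels. The sketch below reproduces just that step with plain torch.stft instead of torchlibrosa, so it is an approximation under that assumption rather than the class's exact pipeline.

# Standalone IPD sketch (assumption: torch.stft in place of torchlibrosa's STFT).
import torch

waveform = torch.randn(2, 32000)  # [channels=2, samples] of stereo audio at 32 kHz (made up)
spec = torch.stft(waveform, n_fft=1024, hop_length=320,
                  window=torch.hann_window(1024), return_complex=True)
real, imag = spec.real, spec.imag  # [2, freq_bins, frames]

# phase difference between right (index 1) and left (index 0) channels
ipd = torch.atan2(imag[1], real[1]) - torch.atan2(imag[0], real[0])
ipd_feat = torch.stack([torch.cos(ipd), torch.sin(ipd)], dim=0)  # [2, freq_bins, frames]
print(ipd_feat.shape)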
slam_llm/models/SpatialAST/vision_transformer.py
ADDED
@@ -0,0 +1,239 @@
import torch
from torch import nn

from timm.models.layers import to_2tuple, DropPath, trunc_normal_


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        x = self.backbone(x)[-1]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class PatchEmbed_new(nn.Module):
    """ Flexible Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)  # with overlapped patches
        _, _, h, w = self.get_output_shape(img_size)  # n, emb_dim, h, w
        self.patch_hw = (h, w)
        self.num_patches = h*w

    def get_output_shape(self, img_size):
        return self.proj(torch.randn(1, self.in_chans, img_size[0], img_size[1])).shape

    def forward(self, x):
        B, C, H, W = x.shape

        x = self.proj(x)       # 32, 1, 1024, 128 -> 32, 768, 101, 12
        x = x.flatten(2)       # 32, 768, 101, 12 -> 32, 768, 1212
        x = x.transpose(1, 2)  # 32, 768, 1212 -> 32, 1212, 768
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
        #self.repr = nn.Linear(embed_dim, representation_size)
        #self.repr_act = nn.Tanh()

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
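A minimal sketch (not part of the upload) that instantiates a small VisionTransformer from this file and runs a forward pass; the tiny sizes are assumptions chosen only to keep the example light and do not come from any config in this repo.

# Hypothetical usage sketch for the VisionTransformer defined above.
import torch
from slam_llm.models.SpatialAST.vision_transformer import VisionTransformer

vit = VisionTransformer(img_size=32, patch_size=16, in_chans=3,
                        num_classes=10, embed_dim=64, depth=2, num_heads=4)
images = torch.randn(2, 3, 32, 32)  # [batch, channels, H, W]
logits = vit(images)                # [2, 10], computed from the class token
print(logits.shape)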
slam_llm/models/avhubert/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .hubert import *  # noqa
from .hubert_asr import *  # noqa
from .hubert_dataset import *
from .hubert_pretraining import *
from .hubert_criterion import *
slam_llm/models/avhubert/decoder.py
ADDED
@@ -0,0 +1,243 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace
import contextlib
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
from omegaconf import MISSING, II, open_dict
from typing import Any, Optional

from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks import FairseqTask
from fairseq.models import (
    BaseFairseqModel,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
)
# from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
from fairseq.modules import (
    LayerNorm,
    PositionalEmbedding,
    TransformerDecoderLayer,
)


class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)

        self.dropout = cfg.decoder_dropout
        self.share_input_output_embed = cfg.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = cfg.decoder_embed_dim
        self.output_embed_dim = cfg.decoder_embed_dim

        self.layerdrop = cfg.decoder_layerdrop

        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = cfg.max_target_positions

        self.embed_tokens = embed_tokens
        # self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim
        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_target_positions,
                embed_dim,
                padding_idx,
                learned=cfg.decoder_learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )

        # TODO: update this when transformer gets converted to dataclass configs
        transformer_cfg = copy.deepcopy(cfg)
        # with open_dict(transformer_cfg):
        transformer_cfg.dropout = transformer_cfg.decoder_dropout
        transformer_cfg.attention_dropout = (
            transformer_cfg.decoder_attention_dropout
        )
        transformer_cfg.activation_dropout = (
            transformer_cfg.decoder_activation_dropout
        )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerDecoderLayer(transformer_cfg, no_encoder_attn)
                for _ in range(transformer_cfg.decoder_layers)
            ]
        )

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)

        if transformer_cfg.decoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        prev_output_tokens = prev_output_tokens.long()
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state
        )
        x = self.output_layer(x)
        return x, extra

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """

        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        inner_states = [x]

        # decoder layers
        for layer in self.layers:
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, attn, _ = layer(
                    x,
                    encoder_out["encoder_out"] if encoder_out is not None else None,
                    encoder_out["padding_mask"] if encoder_out is not None else None,
                    incremental_state,
                    self_attn_mask=self.buffered_future_mask(x)
                    if incremental_state is None
                    else None,
                )
                inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {"attn": attn, "inner_states": inner_states}

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        # project back to size of vocabulary
        emb_mat = self.embed_tokens.weight if self.share_input_output_embed else self.embed_out
        return torch.matmul(features, emb_mat.transpose(0, 1))
        # if self.share_input_output_embed:
        #     return F.linear(features, self.embed_tokens.weight)
        # else:
        #     return F.linear(features, self.embed_out)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
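The decoder above blocks attention to future positions via buffered_future_mask. The standalone sketch below builds the same upper-triangular -inf mask with plain torch, independent of fairseq, to show what is passed as self_attn_mask when incremental_state is None.

# Standalone sketch of the causal self-attention mask (plain torch, no fairseq).
import torch

tgt_len = 5
future_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)
print(future_mask)
# Row t has 0.0 for positions <= t and -inf for positions > t, so each target
# token can only attend to itself and earlier tokens.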
slam_llm/models/avhubert/hubert.py
ADDED
@@ -0,0 +1,792 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import os,sys
|
8 |
+
import logging
|
9 |
+
from typing import Dict, List, Optional, Tuple
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
from dataclasses import dataclass, field
|
16 |
+
from fairseq import utils
|
17 |
+
from fairseq.data.data_utils import compute_mask_indices
|
18 |
+
from fairseq.data.dictionary import Dictionary
|
19 |
+
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
20 |
+
from fairseq.models import BaseFairseqModel, register_model
|
21 |
+
from fairseq.models.wav2vec.wav2vec2 import (
|
22 |
+
ConvFeatureExtractionModel,
|
23 |
+
TransformerEncoder,
|
24 |
+
)
|
25 |
+
from fairseq.modules import GradMultiply, LayerNorm
|
26 |
+
from copy import deepcopy
|
27 |
+
|
28 |
+
DBG=True if len(sys.argv) == 1 else False
|
29 |
+
|
30 |
+
if DBG:
|
31 |
+
from hubert_pretraining import (
|
32 |
+
AVHubertPretrainingConfig,
|
33 |
+
AVHubertPretrainingTask,
|
34 |
+
)
|
35 |
+
from resnet import ResEncoder
|
36 |
+
logging.basicConfig(
|
37 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
38 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
39 |
+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
40 |
+
stream=sys.stdout,
|
41 |
+
)
|
42 |
+
from utils import compute_mask_indices
|
43 |
+
from decoder import TransformerDecoder
|
44 |
+
|
45 |
+
else:
|
46 |
+
from .hubert_pretraining import (
|
47 |
+
AVHubertPretrainingConfig,
|
48 |
+
AVHubertPretrainingTask,
|
49 |
+
)
|
50 |
+
from .resnet import ResEncoder
|
51 |
+
from .utils import compute_mask_indices
|
52 |
+
from .decoder import TransformerDecoder
|
53 |
+
|
54 |
+
from omegaconf import II
|
55 |
+
|
56 |
+
logger = logging.getLogger(__name__)
|
57 |
+
|
58 |
+
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
|
59 |
+
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(
|
60 |
+
["static", "uniform", "normal", "poisson"]
|
61 |
+
)
|
62 |
+
# LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer", "trf_adp"])
|
63 |
+
|
64 |
+
|
65 |
+
@dataclass
|
66 |
+
class AVHubertConfig(FairseqDataclass):
|
67 |
+
label_rate: int = II("task.label_rate")
|
68 |
+
input_modality: str = II("task.input_modality")
|
69 |
+
extractor_mode: EXTRACTOR_MODE_CHOICES = field(
|
70 |
+
default="default",
|
71 |
+
metadata={
|
72 |
+
"help": "mode for feature extractor. default has a single group "
|
73 |
+
"norm with d groups in the first conv block, whereas layer_norm "
|
74 |
+
"has layer norms in every block (meant to use with normalize=True)"
|
75 |
+
},
|
76 |
+
)
|
77 |
+
encoder_layers: int = field(
|
78 |
+
default=12, metadata={"help": "num encoder layers in the transformer"}
|
79 |
+
)
|
80 |
+
encoder_embed_dim: int = field(
|
81 |
+
default=768, metadata={"help": "encoder embedding dimension"}
|
82 |
+
)
|
83 |
+
encoder_ffn_embed_dim: int = field(
|
84 |
+
default=3072, metadata={"help": "encoder embedding dimension for FFN"}
|
85 |
+
)
|
86 |
+
encoder_attention_heads: int = field(
|
87 |
+
default=12, metadata={"help": "num encoder attention heads"}
|
88 |
+
)
|
89 |
+
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
90 |
+
default="gelu", metadata={"help": "activation function to use"}
|
91 |
+
)
|
92 |
+
|
93 |
+
# dropouts
|
94 |
+
dropout: float = field(
|
95 |
+
default=0.1,
|
96 |
+
metadata={"help": "dropout probability for the transformer"},
|
97 |
+
)
|
98 |
+
attention_dropout: float = field(
|
99 |
+
default=0.1,
|
100 |
+
metadata={"help": "dropout probability for attention weights"},
|
101 |
+
)
|
102 |
+
activation_dropout: float = field(
|
103 |
+
default=0.0,
|
104 |
+
metadata={"help": "dropout probability after activation in FFN"},
|
105 |
+
)
|
106 |
+
encoder_layerdrop: float = field(
|
107 |
+
default=0.0,
|
108 |
+
metadata={"help": "probability of dropping a tarnsformer layer"},
|
109 |
+
)
|
110 |
+
dropout_input: float = field(
|
111 |
+
default=0.0,
|
112 |
+
metadata={"help": "dropout to apply to the input (after feat extr)"},
|
113 |
+
)
|
114 |
+
dropout_features: float = field(
|
115 |
+
default=0.0,
|
116 |
+
metadata={
|
117 |
+
"help": "dropout to apply to the features (after feat extr)"
|
118 |
+
},
|
119 |
+
)
|
120 |
+
|
121 |
+
final_dim: int = field(
|
122 |
+
default=0,
|
123 |
+
metadata={
|
124 |
+
"help": "project final representations and targets to this many "
|
125 |
+
"dimensions. set to encoder_embed_dim is <= 0"
|
126 |
+
},
|
127 |
+
)
|
128 |
+
untie_final_proj: bool = field(
|
129 |
+
default=False,
|
130 |
+
metadata={"help": "use separate projection for each target"},
|
131 |
+
)
|
132 |
+
layer_norm_first: bool = field(
|
133 |
+
default=False,
|
134 |
+
metadata={"help": "apply layernorm first in the transformer"},
|
135 |
+
)
|
136 |
+
conv_feature_layers: str = field(
|
137 |
+
default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
|
138 |
+
metadata={
|
139 |
+
"help": "string describing convolutional feature extraction "
|
140 |
+
"layers in form of a python list that contains "
|
141 |
+
"[(dim, kernel_size, stride), ...]"
|
142 |
+
},
|
143 |
+
)
|
144 |
+
conv_bias: bool = field(
|
145 |
+
default=False, metadata={"help": "include bias in conv encoder"}
|
146 |
+
)
|
147 |
+
logit_temp: float = field(
|
148 |
+
default=0.1, metadata={"help": "temperature to divide logits by"}
|
149 |
+
)
|
150 |
+
target_glu: bool = field(
|
151 |
+
default=False, metadata={"help": "adds projection + glu to targets"}
|
152 |
+
)
|
153 |
+
feature_grad_mult: float = field(
|
154 |
+
default=1.0,
|
155 |
+
metadata={"help": "multiply feature extractor var grads by this"},
|
156 |
+
)
|
157 |
+
|
158 |
+
# masking
|
159 |
+
mask_length_audio: int = field(default=10, metadata={"help": "mask length"})
|
160 |
+
mask_prob_audio: float = field(
|
161 |
+
default=0.65,
|
162 |
+
metadata={"help": "probability of replacing a token with mask"},
|
163 |
+
)
|
164 |
+
mask_length_image: int = field(default=10, metadata={"help": "mask length"})
|
165 |
+
mask_prob_image: float = field(
|
166 |
+
default=0.65,
|
167 |
+
metadata={"help": "probability of replacing a token with mask"},
|
168 |
+
)
|
169 |
+
mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
170 |
+
default="static", metadata={"help": "how to choose mask length"}
|
171 |
+
)
|
172 |
+
mask_other: float = field(
|
173 |
+
default=0,
|
174 |
+
metadata={
|
175 |
+
"help": "secondary mask argument "
|
176 |
+
"(used for more complex distributions), "
|
177 |
+
"see help in compute_mask_indicesh"
|
178 |
+
},
|
179 |
+
)
|
180 |
+
no_mask_overlap: bool = field(
|
181 |
+
default=False, metadata={"help": "whether to allow masks to overlap"}
|
182 |
+
)
|
183 |
+
mask_min_space: int = field(
|
184 |
+
default=1,
|
185 |
+
metadata={
|
186 |
+
"help": "min space between spans (if no overlap is enabled)"
|
187 |
+
},
|
188 |
+
)
|
189 |
+
|
190 |
+
# channel masking
|
191 |
+
mask_channel_length: int = field(
|
192 |
+
default=10,
|
193 |
+
metadata={"help": "length of the mask for features (channels)"},
|
194 |
+
)
|
195 |
+
mask_channel_prob: float = field(
|
196 |
+
default=0.0,
|
197 |
+
metadata={"help": "probability of replacing a feature with 0"},
|
198 |
+
)
|
199 |
+
mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
200 |
+
default="static",
|
201 |
+
metadata={"help": "how to choose mask length for channel masking"},
|
202 |
+
)
|
203 |
+
mask_channel_other: float = field(
|
204 |
+
default=0,
|
205 |
+
metadata={
|
206 |
+
"help": "secondary mask argument "
|
207 |
+
"(used for more complex distributions), "
|
208 |
+
"see help in compute_mask_indicesh"
|
209 |
+
},
|
210 |
+
)
|
211 |
+
no_mask_channel_overlap: bool = field(
|
212 |
+
default=False,
|
213 |
+
metadata={"help": "whether to allow channel masks to overlap"},
|
214 |
+
)
|
215 |
+
mask_channel_min_space: int = field(
|
216 |
+
default=1,
|
217 |
+
metadata={
|
218 |
+
"help": "min space between spans (if no overlap is enabled)"
|
219 |
+
},
|
220 |
+
)
|
221 |
+
|
222 |
+
# positional embeddings
|
223 |
+
conv_pos: int = field(
|
224 |
+
default=128,
|
225 |
+
metadata={
|
226 |
+
"help": "number of filters for convolutional positional embeddings"
|
227 |
+
},
|
228 |
+
)
|
229 |
+
conv_pos_groups: int = field(
|
230 |
+
default=16,
|
231 |
+
metadata={
|
232 |
+
"help": "number of groups for convolutional positional embedding"
|
233 |
+
},
|
234 |
+
)
|
235 |
+
|
236 |
+
latent_temp: Tuple[float, float, float] = field(
|
237 |
+
default=(2, 0.5, 0.999995),
|
238 |
+
metadata={"help": "legacy (to be removed)"},
|
239 |
+
)
|
240 |
+
|
241 |
+
# loss computation
|
242 |
+
skip_masked: bool = field(
|
243 |
+
default=False,
|
244 |
+
metadata={"help": "skip computing losses over masked frames"},
|
245 |
+
)
|
246 |
+
skip_nomask: bool = field(
|
247 |
+
default=False,
|
248 |
+
metadata={"help": "skip computing losses over unmasked frames"},
|
249 |
+
)
|
250 |
+
resnet_relu_type: str = field(default='prelu', metadata={"help": 'relu type for resnet'})
|
251 |
+
resnet_weights: Optional[str] = field(default=None, metadata={"help": 'resnet weights'})
|
252 |
+
sim_type: str = field(default='cosine', metadata={"help": 'similarity type'})
|
253 |
+
|
254 |
+
sub_encoder_layers: int = field(default=0, metadata={'help': 'number of transformer layers for single modality'})
|
255 |
+
audio_feat_dim: int = field(default=-1, metadata={'help': 'audio feature dimension'})
|
256 |
+
modality_dropout: float = field(default=0, metadata={'help': 'drop one modality'})
|
257 |
+
audio_dropout: float = field(default=0, metadata={'help': 'drop audio feature'})
|
258 |
+
modality_fuse: str = field(default='concat', metadata={'help': 'fusing two modalities: add,concat'})
|
259 |
+
selection_type : str = field(default='same_other_seq', metadata={'help': 'type of selectig images, same_other_seq: replace masked span with span from another sequence, same_seq: repace masked span with span of the same sequence'})
|
260 |
+
masking_type : str = field(default='input', metadata={'help': 'input or feature masking'})
|
261 |
+
|
262 |
+
decoder_embed_dim: int = field(
|
263 |
+
default=768, metadata={"help": "decoder embedding dimension"}
|
264 |
+
)
|
265 |
+
decoder_ffn_embed_dim: int = field(
|
266 |
+
default=3072, metadata={"help": "decoder embedding dimension for FFN"}
|
267 |
+
)
|
268 |
+
decoder_layers: int = field(
|
269 |
+
default=6, metadata={"help": "num of decoder layers"}
|
270 |
+
)
|
271 |
+
decoder_layerdrop: float = field(
|
272 |
+
default=0.0, metadata={"help": "decoder layerdrop chance"}
|
273 |
+
)
|
274 |
+
decoder_attention_heads: int = field(
|
275 |
+
default=4, metadata={"help": "num decoder attention heads"}
|
276 |
+
)
|
277 |
+
decoder_learned_pos: bool = field(
|
278 |
+
default=False,
|
279 |
+
metadata={"help": "use learned positional embeddings in the decoder"},
|
280 |
+
)
|
281 |
+
decoder_normalize_before: bool = field(
|
282 |
+
default=False,
|
283 |
+
metadata={"help": "apply layernorm before each decoder block"},
|
284 |
+
)
|
285 |
+
no_token_positional_embeddings: bool = field(
|
286 |
+
default=False,
|
287 |
+
metadata={
|
288 |
+
"help": "if set, disables positional embeddings "
|
289 |
+
"(outside self attention)"
|
290 |
+
},
|
291 |
+
)
|
292 |
+
decoder_dropout: float = field(
|
293 |
+
default=0.1, metadata={"help": "dropout probability in the decoder"}
|
294 |
+
)
|
295 |
+
decoder_attention_dropout: float = field(
|
296 |
+
default=0.1,
|
297 |
+
metadata={
|
298 |
+
"help": "dropout probability for attention weights "
|
299 |
+
"inside the decoder"
|
300 |
+
},
|
301 |
+
)
|
302 |
+
decoder_activation_dropout: float = field(
|
303 |
+
default=0.0,
|
304 |
+
metadata={
|
305 |
+
"help": "dropout probability after activation in FFN "
|
306 |
+
"inside the decoder"
|
307 |
+
},
|
308 |
+
)
|
309 |
+
max_target_positions: int = field(
|
310 |
+
default=2048, metadata={"help": "max target positions"}
|
311 |
+
)
|
312 |
+
share_decoder_input_output_embed: bool = field(
|
313 |
+
default=False,
|
314 |
+
metadata={"help": "share decoder input and output embeddings"},
|
315 |
+
)
|
316 |
+
no_scale_embedding: bool = field(default=True, metadata={'help': 'scale embedding'})
|
317 |
+
|
318 |
+
# # new fairseq
|
319 |
+
# required_seq_len_multiple: int = field(
|
320 |
+
# default=1,
|
321 |
+
# metadata={
|
322 |
+
# "help": "pad the input to encoder such that the sequence length is divisible by multiple"
|
323 |
+
# },
|
324 |
+
# )
|
325 |
+
|
326 |
+
# layer_type: LAYER_TYPE_CHOICES = field(
|
327 |
+
# default="transformer", metadata={"help": "layer type in encoder"}
|
328 |
+
# )
|
329 |
+
|
330 |
+
class SubModel(nn.Module):
|
331 |
+
def __init__(self, resnet=None, input_dim=None, cfg=None):
|
332 |
+
super().__init__()
|
333 |
+
self.resnet = resnet
|
334 |
+
self.proj = nn.Linear(input_dim, cfg.encoder_embed_dim)
|
335 |
+
self.encoder = TransformerEncoder(cfg) if cfg.encoder_layers > 0 else None
|
336 |
+
|
337 |
+
def forward(self, x): #torch.Size([1, 1, 106, 112, 112])
|
338 |
+
if self.resnet is not None:
|
339 |
+
x = self.resnet(x) #torch.Size([1, 512, 106]) #torch.Size([12, 26, 314])
|
340 |
+
x = self.proj(x.transpose(1, 2)) # for audio this is Linear(in_features=104, out_features=1024, bias=True), which seems absurdly small
|
341 |
+
if self.encoder is not None:
|
342 |
+
x = self.encoder(x)[0].transpose(1, 2)
|
343 |
+
else: #
|
344 |
+
x = x.transpose(1, 2)
|
345 |
+
return x #torch.Size([1, 1024, 106])
|
346 |
+
|
347 |
+
@register_model("av_hubert", dataclass=AVHubertConfig)
|
348 |
+
class AVHubertModel(BaseFairseqModel):
|
349 |
+
def __init__(
|
350 |
+
self,
|
351 |
+
cfg: AVHubertConfig,
|
352 |
+
task_cfg: AVHubertPretrainingConfig,
|
353 |
+
dictionaries: List[Dictionary],
|
354 |
+
**kwargs
|
355 |
+
) -> None:
|
356 |
+
super().__init__()
|
357 |
+
logger.info(f"HubertModel Config: {cfg}")
|
358 |
+
|
359 |
+
feature_ds_rate = 1
|
360 |
+
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
|
361 |
+
sub_cfg = deepcopy(cfg)
|
362 |
+
sub_cfg.encoder_layers = sub_cfg.sub_encoder_layers
|
363 |
+
resnet = ResEncoder(relu_type=cfg.resnet_relu_type, weights=cfg.resnet_weights)
|
364 |
+
self.feature_extractor_audio = SubModel(resnet=None, input_dim=cfg.audio_feat_dim, cfg=sub_cfg)
|
365 |
+
self.feature_extractor_video = SubModel(resnet=resnet, input_dim=resnet.backend_out, cfg=sub_cfg)
|
366 |
+
self.modality_dropout, self.audio_dropout = cfg.modality_dropout, cfg.audio_dropout
|
367 |
+
self.modality_fuse = cfg.modality_fuse
|
368 |
+
self.encoder_embed_dim = cfg.encoder_embed_dim
|
369 |
+
if self.modality_fuse == 'concat':
|
370 |
+
self.embed = cfg.encoder_embed_dim * 2
|
371 |
+
elif self.modality_fuse == 'add':
|
372 |
+
self.embed = cfg.encoder_embed_dim
|
373 |
+
self.post_extract_proj = (
|
374 |
+
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
375 |
+
if self.embed != cfg.encoder_embed_dim
|
376 |
+
else None
|
377 |
+
)
|
378 |
+
|
379 |
+
self.mask_prob_image, self.mask_prob_audio = cfg.mask_prob_image, cfg.mask_prob_audio
|
380 |
+
self.mask_selection = cfg.mask_selection
|
381 |
+
self.mask_other = cfg.mask_other
|
382 |
+
self.mask_length_image, self.mask_length_audio = cfg.mask_length_image, cfg.mask_length_audio
|
383 |
+
self.no_mask_overlap = cfg.no_mask_overlap
|
384 |
+
self.mask_min_space = cfg.mask_min_space
|
385 |
+
|
386 |
+
self.mask_channel_prob = cfg.mask_channel_prob
|
387 |
+
self.mask_channel_selection = cfg.mask_channel_selection
|
388 |
+
self.mask_channel_other = cfg.mask_channel_other
|
389 |
+
self.mask_channel_length = cfg.mask_channel_length
|
390 |
+
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
391 |
+
self.mask_channel_min_space = cfg.mask_channel_min_space
|
392 |
+
|
393 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
394 |
+
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
395 |
+
|
396 |
+
self.feature_grad_mult = cfg.feature_grad_mult
|
397 |
+
self.logit_temp = cfg.logit_temp
|
398 |
+
self.skip_masked = cfg.skip_masked
|
399 |
+
self.skip_nomask = cfg.skip_nomask
|
400 |
+
self.sim_type = cfg.sim_type
|
401 |
+
self.selection_type = cfg.selection_type
|
402 |
+
self.masking_type = cfg.masking_type
|
403 |
+
|
404 |
+
final_dim = (
|
405 |
+
cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
|
406 |
+
)
|
407 |
+
|
408 |
+
self.mask_emb = nn.Parameter(
|
409 |
+
torch.FloatTensor(cfg.audio_feat_dim).uniform_() if self.masking_type == 'input' else torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
|
410 |
+
)
|
411 |
+
|
412 |
+
self.encoder = TransformerEncoder(cfg)
|
413 |
+
self.layer_norm = LayerNorm(self.embed)
|
414 |
+
|
415 |
+
self.target_glu = None
|
416 |
+
if cfg.target_glu:
|
417 |
+
self.target_glu = nn.Sequential(
|
418 |
+
nn.Linear(final_dim, final_dim * 2), nn.GLU()
|
419 |
+
)
|
420 |
+
|
421 |
+
self.untie_final_proj = cfg.untie_final_proj
|
422 |
+
if self.untie_final_proj:
|
423 |
+
self.final_proj = nn.Linear(
|
424 |
+
cfg.encoder_embed_dim, final_dim * len(dictionaries)
|
425 |
+
)
|
426 |
+
else:
|
427 |
+
self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
|
428 |
+
|
429 |
+
# modules below are not needed during fine-tuning
|
430 |
+
if any([d is None for d in dictionaries]):
|
431 |
+
logger.info(
|
432 |
+
"cannot find dictionary. assume will be used for fine-tuning"
|
433 |
+
)
|
434 |
+
else:
|
435 |
+
self.num_classes = [len(d) for d in dictionaries]
|
436 |
+
self.label_embs_concat = nn.Parameter(
|
437 |
+
torch.FloatTensor(sum(self.num_classes), final_dim)
|
438 |
+
)
|
439 |
+
nn.init.uniform_(self.label_embs_concat)
|
440 |
+
|
441 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
442 |
+
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
443 |
+
|
444 |
+
super().upgrade_state_dict_named(state_dict, name)
|
445 |
+
return state_dict
|
446 |
+
|
447 |
+
@classmethod
|
448 |
+
def build_model(cls, cfg: AVHubertConfig, task: AVHubertPretrainingTask):
|
449 |
+
"""Build a new model instance."""
|
450 |
+
|
451 |
+
kwargs = {}
|
452 |
+
model = AVHubertModel(cfg, task.cfg, task.dictionaries, **kwargs)
|
453 |
+
return model
|
454 |
+
|
455 |
+
def apply_input_mask(self, x, padding_mask, target_list):
|
456 |
+
B, C, T = x.shape[:3]
|
457 |
+
is_audio = True if len(x.shape) == 3 else False
|
458 |
+
if is_audio:
|
459 |
+
mask_prob, mask_length = self.mask_prob_audio, self.mask_length_audio
|
460 |
+
else:
|
461 |
+
mask_prob, mask_length = self.mask_prob_image, self.mask_length_image
|
462 |
+
if mask_prob > 0:
|
463 |
+
|
464 |
+
mask_indices, starts, ends, batch_indexes = compute_mask_indices(
|
465 |
+
(B, T),
|
466 |
+
padding_mask,
|
467 |
+
mask_prob,
|
468 |
+
mask_length,
|
469 |
+
self.mask_selection,
|
470 |
+
self.mask_other,
|
471 |
+
min_masks=2,
|
472 |
+
no_overlap=self.no_mask_overlap,
|
473 |
+
min_space=self.mask_min_space,
|
474 |
+
)
|
475 |
+
mask_indices_np = mask_indices
|
476 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
477 |
+
x = x.transpose(1, 2).contiguous() # [B, T, C, H, W]
|
478 |
+
if B == 1:
|
479 |
+
x[mask_indices] = 0
|
480 |
+
elif is_audio:
|
481 |
+
x[mask_indices] = self.mask_emb
|
482 |
+
elif self.selection_type == 'same_other_seq':
|
483 |
+
perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B
|
484 |
+
x_perm = x[perm]
|
485 |
+
x[mask_indices] = x_perm[mask_indices]
|
486 |
+
elif self.selection_type == 'same_seq':
|
487 |
+
batch_indexes_, other_indexes = [], []
|
488 |
+
for batch_index, start, end in zip(batch_indexes, starts, ends):
|
489 |
+
length = end-start
|
490 |
+
other_start = np.setdiff1d(np.arange(T), np.arange(max(0, start-length), end))
|
491 |
+
if len(other_start) > 0:
|
492 |
+
other_start = np.random.choice(other_start, size=1)
|
493 |
+
else:
|
494 |
+
other_start = 0
|
495 |
+
other_end = other_start + length
|
496 |
+
other_indexes.append(np.arange(other_start, other_end).clip(max=T-1))
|
497 |
+
batch_indexes_.append(np.zeros([length], dtype=np.int64)+batch_index)
|
498 |
+
batch_indexes, other_indexes = np.concatenate(batch_indexes_), np.concatenate(other_indexes)
|
499 |
+
x[mask_indices] = x[batch_indexes, other_indexes]
|
500 |
+
|
501 |
+
x = x.transpose(1, 2).contiguous()
|
502 |
+
else:
|
503 |
+
mask_indices = None
|
504 |
+
|
505 |
+
if self.mask_channel_prob > 0:
|
506 |
+
logger.info(f"No mask channel prob for input masking")
|
507 |
+
return x, mask_indices
|
508 |
+
|
509 |
+
def apply_feature_mask(self, x, padding_mask, target_list):
|
510 |
+
B, T, C = x.shape
|
511 |
+
assert self.mask_prob_audio == self.mask_prob_image and self.mask_length_audio == self.mask_length_image, f"masking prob/length for image/audio must be the same for feature masking"
|
512 |
+
mask_prob, mask_length = self.mask_prob_audio, self.mask_length_image
|
513 |
+
if mask_prob > 0:
|
514 |
+
mask_indices, _, _, _ = compute_mask_indices(
|
515 |
+
(B, T),
|
516 |
+
padding_mask,
|
517 |
+
mask_prob,
|
518 |
+
mask_length,
|
519 |
+
self.mask_selection,
|
520 |
+
self.mask_other,
|
521 |
+
min_masks=2,
|
522 |
+
no_overlap=self.no_mask_overlap,
|
523 |
+
min_space=self.mask_min_space,
|
524 |
+
)
|
525 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
526 |
+
x[mask_indices] = self.mask_emb
|
527 |
+
else:
|
528 |
+
mask_indices = None
|
529 |
+
|
530 |
+
if self.mask_channel_prob > 0:
|
531 |
+
mask_channel_indices, _, _, _ = compute_mask_indices(
|
532 |
+
(B, C),
|
533 |
+
None,
|
534 |
+
self.mask_channel_prob,
|
535 |
+
self.mask_channel_length,
|
536 |
+
self.mask_channel_selection,
|
537 |
+
self.mask_channel_other,
|
538 |
+
no_overlap=self.no_mask_channel_overlap,
|
539 |
+
min_space=self.mask_channel_min_space,
|
540 |
+
)
|
541 |
+
mask_channel_indices = (
|
542 |
+
torch.from_numpy(mask_channel_indices)
|
543 |
+
.to(x.device)
|
544 |
+
.unsqueeze(1)
|
545 |
+
.expand(-1, T, -1)
|
546 |
+
)
|
547 |
+
x[mask_channel_indices] = 0
|
548 |
+
|
549 |
+
return x, mask_indices
|
550 |
+
|
551 |
+
def forward_features(self, source: torch.Tensor, modality: str) -> torch.Tensor:
|
552 |
+
extractor = eval(f"self.feature_extractor_{modality}")
|
553 |
+
if self.feature_grad_mult > 0:
|
554 |
+
features = extractor(source)
|
555 |
+
if self.feature_grad_mult != 1.0:
|
556 |
+
features = GradMultiply.apply(features, self.feature_grad_mult)
|
557 |
+
else:
|
558 |
+
with torch.no_grad():
|
559 |
+
features = extractor(source)
|
560 |
+
return features
|
561 |
+
|
562 |
+
def forward_targets(
|
563 |
+
self, features: torch.Tensor, mask_indices: torch.Tensor, target_list: List[torch.Tensor],
|
564 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
565 |
+
# Trim features to ensure labels exist and then get aligned labels
|
566 |
+
feat_tsz = features.size(2)
|
567 |
+
targ_tsz = min([t.size(1) for t in target_list])
|
568 |
+
if self.feat2tar_ratio * feat_tsz > targ_tsz:
|
569 |
+
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
|
570 |
+
features = features[..., :feat_tsz]
|
571 |
+
if mask_indices is not None:
|
572 |
+
mask_indices = mask_indices[..., :feat_tsz]
|
573 |
+
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
|
574 |
+
target_list = [t[:, target_inds.long()] for t in target_list]
|
575 |
+
return features, mask_indices, target_list
|
576 |
+
|
577 |
+
def forward_padding_mask(
|
578 |
+
self, features: torch.Tensor, padding_mask: torch.Tensor,
|
579 |
+
) -> torch.Tensor:
|
580 |
+
extra = padding_mask.size(1) % features.size(1)
|
581 |
+
if extra > 0:
|
582 |
+
padding_mask = padding_mask[:, :-extra]
|
583 |
+
padding_mask = padding_mask.view(
|
584 |
+
padding_mask.size(0), features.size(1), -1
|
585 |
+
)
|
586 |
+
padding_mask = padding_mask.all(-1)
|
587 |
+
return padding_mask
|
588 |
+
|
589 |
+
def compute_logits(self, feats, emb_mat):
|
590 |
+
# feats: [B, T, F], emb_mat: [V, F]
|
591 |
+
if self.sim_type == 'dot':
|
592 |
+
logits = torch.matmul(feats, emb_mat.transpose(0, 1))
|
593 |
+
elif self.sim_type == 'cosine':
|
594 |
+
batch_size, timesteps, emb_dim = feats.size()
|
595 |
+
feats_ = feats.view(-1, emb_dim)
|
596 |
+
nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1) # [B*T, V]
|
597 |
+
denom = (feats_**2).sum(dim=-1).sqrt().unsqueeze(dim=1) * (emb_mat**2).sum(dim=-1).sqrt().unsqueeze(dim=0) # [B*T, V]
|
598 |
+
logits = (nom/denom.clamp(min=1e-6)).view(batch_size, timesteps, -1)
|
599 |
+
else:
|
600 |
+
raise NotImplementedError
|
601 |
+
logits = logits / self.logit_temp
|
602 |
+
return logits
|
603 |
+
|
604 |
+
def forward(
|
605 |
+
self,
|
606 |
+
source: torch.Tensor,
|
607 |
+
target_list: Optional[List[torch.Tensor]] = None,
|
608 |
+
padding_mask: Optional[torch.Tensor] = None,
|
609 |
+
mask: bool = True,
|
610 |
+
features_only: bool = False,
|
611 |
+
output_layer: Optional[int] = None
|
612 |
+
) -> Dict[str, torch.Tensor]:
|
613 |
+
"""output layer is 1-based"""
|
614 |
+
src_audio, src_video = source['audio'], source['video']
|
615 |
+
if mask and self.masking_type == 'input':
|
616 |
+
src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list)
|
617 |
+
src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list)
|
618 |
+
mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video)
|
619 |
+
else:
|
620 |
+
src_audio, src_video, mask_indices = src_audio, src_video, None
|
621 |
+
|
622 |
+
features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T]
|
623 |
+
features_video = self.forward_features(src_video, modality='video')
|
624 |
+
modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random()
|
625 |
+
if self.training:
|
626 |
+
if modality_drop_prob < self.modality_dropout:
|
627 |
+
if audio_drop_prob < self.audio_dropout:
|
628 |
+
features_audio = 0 * features_audio
|
629 |
+
else:
|
630 |
+
features_video = 0 * features_video
|
631 |
+
if self.modality_fuse == 'concat':
|
632 |
+
features = torch.cat([features_audio, features_video], dim=1)
|
633 |
+
elif self.modality_fuse == 'add':
|
634 |
+
features = features_audio + features_video
|
635 |
+
if target_list is not None:
|
636 |
+
features, mask_indices, target_list = self.forward_targets(features, mask_indices, target_list)
|
637 |
+
|
638 |
+
features_pen = features.float().pow(2).mean()
|
639 |
+
|
640 |
+
features = features.transpose(1, 2)
|
641 |
+
features = self.layer_norm(features)
|
642 |
+
|
643 |
+
if padding_mask is not None:
|
644 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
645 |
+
|
646 |
+
if self.post_extract_proj is not None:
|
647 |
+
features = self.post_extract_proj(features)
|
648 |
+
|
649 |
+
features = self.dropout_input(features)
|
650 |
+
if self.masking_type == 'feature' and mask:
|
651 |
+
x, mask_indices = self.apply_feature_mask(features, padding_mask, target_list)
|
652 |
+
else:
|
653 |
+
x = features
|
654 |
+
|
655 |
+
# feature: (B, T, D), float
|
656 |
+
# target: (B, T), long
|
657 |
+
# x: (B, T, D), float
|
658 |
+
# padding_mask: (B, T), bool
|
659 |
+
# mask_indices: (B, T), bool
|
660 |
+
x, _ = self.encoder(
|
661 |
+
x,
|
662 |
+
padding_mask=padding_mask,
|
663 |
+
layer=None if output_layer is None else output_layer - 1
|
664 |
+
)
|
665 |
+
|
666 |
+
if features_only:
|
667 |
+
return {"x": x, "padding_mask": padding_mask, "features": features}
|
668 |
+
|
669 |
+
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
|
670 |
+
proj_x = self.final_proj(x)
|
671 |
+
if self.untie_final_proj:
|
672 |
+
proj_x_list = proj_x.chunk(len(self.num_classes), dim=-1)
|
673 |
+
else:
|
674 |
+
proj_x_list = [proj_x for _ in self.num_classes]
|
675 |
+
logit_list = [self.compute_logits(proj, emb).view(-1, num_class) for proj, emb, num_class in zip(proj_x_list, label_embs_list, self.num_classes)] # [[B*T, V]]
|
676 |
+
mask, unmask = torch.logical_and(mask_indices, ~padding_mask).view(-1), torch.logical_and(~mask_indices, ~padding_mask).view(-1) # [B*T]
|
677 |
+
logit_m_list, logit_u_list = [logit[mask] for logit in logit_list], [logit[unmask] for logit in logit_list]
|
678 |
+
target_m_list, target_u_list = [target.view(-1)[mask].long() for target in target_list], [target.view(-1)[unmask].long() for target in target_list]
|
679 |
+
result = {
|
680 |
+
"logit_m_list": logit_m_list,
|
681 |
+
"logit_u_list": logit_u_list,
|
682 |
+
"target_m_list": target_m_list,
|
683 |
+
"target_u_list": target_u_list,
|
684 |
+
"padding_mask": padding_mask,
|
685 |
+
"features_pen": features_pen,
|
686 |
+
}
|
687 |
+
return result
|
688 |
+
|
689 |
+
def extract_features(
|
690 |
+
self,
|
691 |
+
source: torch.Tensor,
|
692 |
+
padding_mask: Optional[torch.Tensor] = None,
|
693 |
+
mask: bool = False,
|
694 |
+
ret_conv: bool = False,
|
695 |
+
output_layer: Optional[int] = None,
|
696 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
697 |
+
res = self.forward(
|
698 |
+
source,
|
699 |
+
padding_mask=padding_mask,
|
700 |
+
mask=mask,
|
701 |
+
features_only=True,
|
702 |
+
output_layer=output_layer,
|
703 |
+
)
|
704 |
+
feature = res["features"] if ret_conv else res["x"]
|
705 |
+
return feature, res["padding_mask"]
|
706 |
+
|
707 |
+
def extract_finetune(self, source, padding_mask=None, mask=False, ret_conv=False, output_layer=None):
|
708 |
+
src_audio, src_video = source['audio'], source['video'] #torch.Size([1, 1, 106, 112, 112])
|
709 |
+
if mask and self.masking_type == 'input':
|
710 |
+
src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list=None)
|
711 |
+
src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list=None)
|
712 |
+
mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video) # mask_indices not used in fine-tuning
|
713 |
+
else: #
|
714 |
+
src_audio, src_video, mask_indices = src_audio, src_video, None
|
715 |
+
|
716 |
+
if src_audio is not None and src_video is None:
|
717 |
+
features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T]
|
718 |
+
features_video = features_audio.new_zeros(features_audio.size(0), self.encoder_embed_dim, features_audio.size(-1))
|
719 |
+
elif src_audio is None and src_video is not None:
|
720 |
+
features_video = self.forward_features(src_video, modality='video')
|
721 |
+
features_audio = features_video.new_zeros(features_video.size(0), self.encoder_embed_dim, features_video.size(-1)) #全0!
|
722 |
+
elif src_audio is not None and src_video is not None:
|
723 |
+
features_video = self.forward_features(src_video, modality='video') #torch.Size([1, 1024, 106]) #scr torch.Size([12, 1, 314, 88, 88])
|
724 |
+
features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T] #torch.Size([12, 26, 314])
|
725 |
+
|
726 |
+
if self.modality_fuse == 'concat': #
|
727 |
+
features = torch.cat([features_audio, features_video], dim=1) #torch.Size([1, 2048, 106])
|
728 |
+
elif self.modality_fuse == 'add':
|
729 |
+
features = features_audio + features_video
|
730 |
+
features_pen = features.float().pow(2).mean()
|
731 |
+
|
732 |
+
features = features.transpose(1, 2)
|
733 |
+
features = self.layer_norm(features)
|
734 |
+
unmasked_features = features.clone()
|
735 |
+
|
736 |
+
if padding_mask is not None: #features:torch.Size([1, 106, 2048])
|
737 |
+
padding_mask = self.forward_padding_mask(features, padding_mask) #torch.Size([4, 154])
|
738 |
+
|
739 |
+
if self.post_extract_proj is not None:
|
740 |
+
features = self.post_extract_proj(features) #torch.Size([1, 106, 1024])
|
741 |
+
|
742 |
+
features = self.dropout_input(features)
|
743 |
+
unmasked_features = self.dropout_features(unmasked_features)
|
744 |
+
x = features
|
745 |
+
mask_indices = None
|
746 |
+
|
747 |
+
# feature: (B, T, D), float
|
748 |
+
# target: (B, T), long
|
749 |
+
# x: (B, T, D), float
|
750 |
+
# padding_mask: (B, T), bool
|
751 |
+
# mask_indices: (B, T), bool
|
752 |
+
x, _ = self.encoder(
|
753 |
+
x,
|
754 |
+
padding_mask=padding_mask,
|
755 |
+
layer=None if output_layer is None else output_layer - 1
|
756 |
+
)
|
757 |
+
|
758 |
+
return x, padding_mask #torch.Size([1, 106, 1024]), None
|
759 |
+
|
760 |
+
|
761 |
+
def get_extra_losses(self, net_output):
|
762 |
+
extra_losses = []
|
763 |
+
names = []
|
764 |
+
if "features_pen" in net_output:
|
765 |
+
extra_losses.append(net_output["features_pen"])
|
766 |
+
names.append("features_pen")
|
767 |
+
|
768 |
+
return extra_losses, names
|
769 |
+
|
770 |
+
def remove_pretraining_modules(self):
|
771 |
+
self.target_glu = None
|
772 |
+
self.final_proj = None
|
773 |
+
|
774 |
+
def get_logits(self, net_output, is_masked=True):
|
775 |
+
raise NotImplementedError
|
776 |
+
|
777 |
+
def get_targets(self, net_output, is_masked=True):
|
778 |
+
raise NotImplementedError
|
779 |
+
|
780 |
+
def compute_nce(self, x, pos, negs):
|
781 |
+
neg_is_pos = (pos == negs).all(-1)
|
782 |
+
pos = pos.unsqueeze(0)
|
783 |
+
targets = torch.cat([pos, negs], dim=0)
|
784 |
+
|
785 |
+
logits = torch.cosine_similarity(
|
786 |
+
x.float(), targets.float(), dim=-1
|
787 |
+
).type_as(x)
|
788 |
+
logits /= self.logit_temp
|
789 |
+
if neg_is_pos.any():
|
790 |
+
logits[1:][neg_is_pos] = float("-inf")
|
791 |
+
logits = logits.transpose(0, 1) # (num_x, num_cls+1)
|
792 |
+
return logits
|
slam_llm/models/avhubert/hubert_asr.py
ADDED
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import sys,logging
|
8 |
+
import contextlib
|
9 |
+
import tempfile
|
10 |
+
from argparse import Namespace
|
11 |
+
from typing import Any, Optional
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
from dataclasses import dataclass, field
|
16 |
+
from fairseq import checkpoint_utils, tasks, utils
|
17 |
+
from fairseq.dataclass import FairseqDataclass
|
18 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
19 |
+
from fairseq.models import BaseFairseqModel, FairseqEncoder, FairseqEncoderDecoderModel, register_model
|
20 |
+
from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES
|
21 |
+
from fairseq.tasks import FairseqTask
|
22 |
+
from omegaconf import II, MISSING
|
23 |
+
|
24 |
+
DBG=True if len(sys.argv) == 1 else False
|
25 |
+
|
26 |
+
if DBG:
|
27 |
+
from hubert import AVHubertModel
|
28 |
+
from decoder import TransformerDecoder
|
29 |
+
else:
|
30 |
+
from .hubert import AVHubertModel
|
31 |
+
from .decoder import TransformerDecoder
|
32 |
+
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class AVHubertAsrConfig(FairseqDataclass):
|
38 |
+
w2v_path: str = field(
|
39 |
+
default=MISSING, metadata={"help": "path to hubert model"}
|
40 |
+
)
|
41 |
+
no_pretrained_weights: bool = field(
|
42 |
+
default=False,
|
43 |
+
metadata={"help": "if true, does not load pretrained weights"},
|
44 |
+
)
|
45 |
+
dropout_input: float = field(
|
46 |
+
default=0.0,
|
47 |
+
metadata={"help": "dropout to apply to the input (after feat extr)"},
|
48 |
+
)
|
49 |
+
final_dropout: float = field(
|
50 |
+
default=0.0,
|
51 |
+
metadata={
|
52 |
+
"help": "dropout after transformer and before final projection"
|
53 |
+
},
|
54 |
+
)
|
55 |
+
dropout: float = field(
|
56 |
+
default=0.0,
|
57 |
+
metadata={"help": "dropout probability inside hubert model"},
|
58 |
+
)
|
59 |
+
attention_dropout: float = field(
|
60 |
+
default=0.0,
|
61 |
+
metadata={
|
62 |
+
"help": "dropout probability for attention weights "
|
63 |
+
"inside hubert model"
|
64 |
+
},
|
65 |
+
)
|
66 |
+
activation_dropout: float = field(
|
67 |
+
default=0.0,
|
68 |
+
metadata={
|
69 |
+
"help": "dropout probability after activation in FFN "
|
70 |
+
"inside hubert model"
|
71 |
+
},
|
72 |
+
)
|
73 |
+
|
74 |
+
# masking
|
75 |
+
apply_mask: bool = field(
|
76 |
+
default=False, metadata={"help": "apply masking during fine-tuning"}
|
77 |
+
)
|
78 |
+
mask_length: int = field(
|
79 |
+
default=10, metadata={"help": "repeat the mask indices multiple times"}
|
80 |
+
)
|
81 |
+
mask_prob: float = field(
|
82 |
+
default=0.5,
|
83 |
+
metadata={
|
84 |
+
"help": "probability of replacing a token with mask "
|
85 |
+
"(normalized by length)"
|
86 |
+
},
|
87 |
+
)
|
88 |
+
mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
89 |
+
default="static", metadata={"help": "how to choose masks"}
|
90 |
+
)
|
91 |
+
mask_other: float = field(
|
92 |
+
default=0,
|
93 |
+
metadata={
|
94 |
+
"help": "secondary mask argument "
|
95 |
+
"(used for more complex distributions), "
|
96 |
+
"see help in compute_mask_indices"
|
97 |
+
},
|
98 |
+
)
|
99 |
+
no_mask_overlap: bool = field(
|
100 |
+
default=False, metadata={"help": "whether to allow masks to overlap"}
|
101 |
+
)
|
102 |
+
|
103 |
+
# channel masking
|
104 |
+
mask_channel_length: int = field(
|
105 |
+
default=10,
|
106 |
+
metadata={"help": "length of the mask for features (channels)"},
|
107 |
+
)
|
108 |
+
mask_channel_prob: float = field(
|
109 |
+
default=0.0,
|
110 |
+
metadata={"help": "probability of replacing a feature with 0"},
|
111 |
+
)
|
112 |
+
mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
113 |
+
default="static",
|
114 |
+
metadata={"help": "how to choose mask length for channel masking"},
|
115 |
+
)
|
116 |
+
mask_channel_other: float = field(
|
117 |
+
default=0,
|
118 |
+
metadata={
|
119 |
+
"help": "secondary mask argument "
|
120 |
+
"(used for more complex distributions), "
|
121 |
+
"see help in compute_mask_indices"
|
122 |
+
},
|
123 |
+
)
|
124 |
+
no_mask_channel_overlap: bool = field(
|
125 |
+
default=False,
|
126 |
+
metadata={"help": "whether to allow channel masks to overlap"},
|
127 |
+
)
|
128 |
+
freeze_finetune_updates: int = field(
|
129 |
+
default=0,
|
130 |
+
metadata={"help": "dont finetune hubert for this many updates"},
|
131 |
+
)
|
132 |
+
feature_grad_mult: float = field(
|
133 |
+
default=0.0,
|
134 |
+
metadata={"help": "reset feature grad mult in hubert to this"},
|
135 |
+
)
|
136 |
+
layerdrop: float = field(
|
137 |
+
default=0.0,
|
138 |
+
metadata={"help": "probability of dropping a layer in hubert"},
|
139 |
+
)
|
140 |
+
normalize: bool = II("task.normalize")
|
141 |
+
data: str = II("task.data")
|
142 |
+
|
143 |
+
# this holds the loaded hubert args
|
144 |
+
w2v_args: Any = None
|
145 |
+
|
146 |
+
|
147 |
+
@dataclass
|
148 |
+
class AVHubertCtcConfig(AVHubertAsrConfig):
|
149 |
+
pass
|
150 |
+
|
151 |
+
|
152 |
+
@register_model("av_hubert_ctc", dataclass=AVHubertCtcConfig)
|
153 |
+
class AVHubertCtc(BaseFairseqModel):
|
154 |
+
def __init__(self, cfg: AVHubertCtcConfig, w2v_encoder: BaseFairseqModel):
|
155 |
+
super().__init__()
|
156 |
+
self.cfg = cfg
|
157 |
+
self.w2v_encoder = w2v_encoder
|
158 |
+
|
159 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
160 |
+
super().upgrade_state_dict_named(state_dict, name)
|
161 |
+
return state_dict
|
162 |
+
|
163 |
+
@classmethod
|
164 |
+
def build_model(cls, cfg: AVHubertCtcConfig, task: FairseqTask):
|
165 |
+
"""Build a new model instance."""
|
166 |
+
w2v_encoder = HubertEncoder(cfg, task.target_dictionary)
|
167 |
+
return cls(cfg, w2v_encoder)
|
168 |
+
|
169 |
+
def get_normalized_probs(self, net_output, log_probs):
|
170 |
+
"""Get normalized probabilities (or log probs) from a net's output."""
|
171 |
+
|
172 |
+
logits = net_output["encoder_out"]
|
173 |
+
if log_probs:
|
174 |
+
return utils.log_softmax(logits.float(), dim=-1)
|
175 |
+
else:
|
176 |
+
return utils.softmax(logits.float(), dim=-1)
|
177 |
+
|
178 |
+
def get_logits(self, net_output):
|
179 |
+
logits = net_output["encoder_out"]
|
180 |
+
padding = net_output["encoder_padding_mask"]
|
181 |
+
if padding is not None and padding.any():
|
182 |
+
padding = padding.T
|
183 |
+
logits[padding][..., 0] = 0
|
184 |
+
logits[padding][..., 1:] = float("-inf")
|
185 |
+
|
186 |
+
return logits
|
187 |
+
|
188 |
+
def forward(self, **kwargs):
|
189 |
+
x = self.w2v_encoder(**kwargs)
|
190 |
+
return x
|
191 |
+
|
192 |
+
|
193 |
+
@dataclass
|
194 |
+
class AVHubertSeq2SeqConfig(AVHubertAsrConfig):
|
195 |
+
decoder_embed_dim: int = field(
|
196 |
+
default=768, metadata={"help": "decoder embedding dimension"}
|
197 |
+
)
|
198 |
+
decoder_ffn_embed_dim: int = field(
|
199 |
+
default=3072, metadata={"help": "decoder embedding dimension for FFN"}
|
200 |
+
)
|
201 |
+
decoder_layers: int = field(
|
202 |
+
default=6, metadata={"help": "num of decoder layers"}
|
203 |
+
)
|
204 |
+
decoder_layerdrop: float = field(
|
205 |
+
default=0.0, metadata={"help": "decoder layerdrop chance"}
|
206 |
+
)
|
207 |
+
decoder_attention_heads: int = field(
|
208 |
+
default=4, metadata={"help": "num decoder attention heads"}
|
209 |
+
)
|
210 |
+
decoder_learned_pos: bool = field(
|
211 |
+
default=False,
|
212 |
+
metadata={"help": "use learned positional embeddings in the decoder"},
|
213 |
+
)
|
214 |
+
decoder_normalize_before: bool = field(
|
215 |
+
default=False,
|
216 |
+
metadata={"help": "apply layernorm before each decoder block"},
|
217 |
+
)
|
218 |
+
no_token_positional_embeddings: bool = field(
|
219 |
+
default=False,
|
220 |
+
metadata={
|
221 |
+
"help": "if set, disables positional embeddings "
|
222 |
+
"(outside self attention)"
|
223 |
+
},
|
224 |
+
)
|
225 |
+
decoder_dropout: float = field(
|
226 |
+
default=0.0, metadata={"help": "dropout probability in the decoder"}
|
227 |
+
)
|
228 |
+
decoder_attention_dropout: float = field(
|
229 |
+
default=0.0,
|
230 |
+
metadata={
|
231 |
+
"help": "dropout probability for attention weights "
|
232 |
+
"inside the decoder"
|
233 |
+
},
|
234 |
+
)
|
235 |
+
decoder_activation_dropout: float = field(
|
236 |
+
default=0.0,
|
237 |
+
metadata={
|
238 |
+
"help": "dropout probability after activation in FFN "
|
239 |
+
"inside the decoder"
|
240 |
+
},
|
241 |
+
)
|
242 |
+
max_target_positions: int = field(
|
243 |
+
default=2048, metadata={"help": "max target positions"}
|
244 |
+
)
|
245 |
+
share_decoder_input_output_embed: bool = field(
|
246 |
+
default=False,
|
247 |
+
metadata={"help": "share decoder input and output embeddings"},
|
248 |
+
)
|
249 |
+
no_scale_embedding: bool = field(default=True, metadata={'help': 'scale embedding'})
|
250 |
+
|
251 |
+
class HubertEncoder(FairseqEncoder):
|
252 |
+
def __init__(self, cfg: AVHubertAsrConfig, tgt_dict=None):
|
253 |
+
self.apply_mask = cfg.apply_mask
|
254 |
+
|
255 |
+
arg_overrides = {
|
256 |
+
"dropout": cfg.dropout,
|
257 |
+
"activation_dropout": cfg.activation_dropout,
|
258 |
+
"dropout_input": cfg.dropout_input,
|
259 |
+
"attention_dropout": cfg.attention_dropout,
|
260 |
+
"mask_length": cfg.mask_length,
|
261 |
+
"mask_prob": cfg.mask_prob,
|
262 |
+
"mask_selection": cfg.mask_selection,
|
263 |
+
"mask_other": cfg.mask_other,
|
264 |
+
"no_mask_overlap": cfg.no_mask_overlap,
|
265 |
+
"mask_channel_length": cfg.mask_channel_length,
|
266 |
+
"mask_channel_prob": cfg.mask_channel_prob,
|
267 |
+
"mask_channel_selection": cfg.mask_channel_selection,
|
268 |
+
"mask_channel_other": cfg.mask_channel_other,
|
269 |
+
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
|
270 |
+
"encoder_layerdrop": cfg.layerdrop,
|
271 |
+
"feature_grad_mult": cfg.feature_grad_mult,
|
272 |
+
}
|
273 |
+
|
274 |
+
if cfg.w2v_args is None:
|
275 |
+
state = checkpoint_utils.load_checkpoint_to_cpu(
|
276 |
+
cfg.w2v_path, arg_overrides
|
277 |
+
)
|
278 |
+
w2v_args = state.get("cfg", None)
|
279 |
+
if w2v_args is None:
|
280 |
+
w2v_args = convert_namespace_to_omegaconf(state["args"])
|
281 |
+
cfg.w2v_args = w2v_args
|
282 |
+
else:
|
283 |
+
state = None
|
284 |
+
w2v_args = cfg.w2v_args
|
285 |
+
if isinstance(w2v_args, Namespace):
|
286 |
+
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
|
287 |
+
w2v_args
|
288 |
+
)
|
289 |
+
|
290 |
+
assert cfg.normalize == w2v_args.task.normalize, (
|
291 |
+
"Fine-tuning works best when data normalization is the same. "
|
292 |
+
"Please check that --normalize is set or unset for "
|
293 |
+
"both pre-training and here"
|
294 |
+
)
|
295 |
+
|
296 |
+
w2v_args.task.data = cfg.data
|
297 |
+
|
298 |
+
task = tasks.setup_task(w2v_args.task)
|
299 |
+
model = task.build_model(w2v_args.model)
|
300 |
+
|
301 |
+
if state is not None and not cfg.no_pretrained_weights:
|
302 |
+
# set strict=False because we omit some modules
|
303 |
+
model.load_state_dict(state["model"], strict=False)
|
304 |
+
|
305 |
+
model.remove_pretraining_modules()
|
306 |
+
|
307 |
+
super().__init__(task.source_dictionary)
|
308 |
+
|
309 |
+
d = model.encoder.embedding_dim
|
310 |
+
|
311 |
+
self.w2v_model = model
|
312 |
+
|
313 |
+
self.final_dropout = nn.Dropout(cfg.final_dropout)
|
314 |
+
self.freeze_finetune_updates = cfg.freeze_finetune_updates
|
315 |
+
self.num_updates = 0
|
316 |
+
|
317 |
+
if tgt_dict is not None:
|
318 |
+
self.proj = Linear(d, len(tgt_dict))
|
319 |
+
elif getattr(cfg, "decoder_embed_dim", d) != d:
|
320 |
+
self.proj = Linear(d, cfg.decoder_embed_dim)
|
321 |
+
else:
|
322 |
+
self.proj = None
|
323 |
+
|
324 |
+
def set_num_updates(self, num_updates):
|
325 |
+
"""Set the number of parameters updates."""
|
326 |
+
super().set_num_updates(num_updates)
|
327 |
+
self.num_updates = num_updates
|
328 |
+
|
329 |
+
def forward(self, source, padding_mask, tbc=True, **kwargs):
|
330 |
+
|
331 |
+
w2v_args = {
|
332 |
+
"source": source,
|
333 |
+
"padding_mask": padding_mask,
|
334 |
+
"mask": self.apply_mask and self.training,
|
335 |
+
}
|
336 |
+
ft = self.freeze_finetune_updates <= self.num_updates
|
337 |
+
|
338 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
339 |
+
x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
|
340 |
+
|
341 |
+
if tbc:
|
342 |
+
# B x T x C -> T x B x C
|
343 |
+
x = x.transpose(0, 1)
|
344 |
+
|
345 |
+
x = self.final_dropout(x)
|
346 |
+
|
347 |
+
if self.proj:
|
348 |
+
x = self.proj(x)
|
349 |
+
|
350 |
+
return {
|
351 |
+
"encoder_out": x, # T x B x C
|
352 |
+
"encoder_padding_mask": padding_mask, # B x T
|
353 |
+
"padding_mask": padding_mask,
|
354 |
+
}
|
355 |
+
|
356 |
+
def reorder_encoder_out(self, encoder_out, new_order):
|
357 |
+
if encoder_out["encoder_out"] is not None:
|
358 |
+
encoder_out["encoder_out"] = encoder_out[
|
359 |
+
"encoder_out"
|
360 |
+
].index_select(1, new_order)
|
361 |
+
if encoder_out["encoder_padding_mask"] is not None:
|
362 |
+
encoder_out["encoder_padding_mask"] = encoder_out[
|
363 |
+
"encoder_padding_mask"
|
364 |
+
].index_select(0, new_order)
|
365 |
+
return encoder_out
|
366 |
+
|
367 |
+
def max_positions(self):
|
368 |
+
"""Maximum input length supported by the encoder."""
|
369 |
+
return None
|
370 |
+
|
371 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
372 |
+
return state_dict
|
373 |
+
|
374 |
+
|
375 |
+
class HubertEncoderWrapper(FairseqEncoder):
|
376 |
+
def __init__(self, w2v_model):
|
377 |
+
super().__init__(None)
|
378 |
+
self.w2v_model = w2v_model
|
379 |
+
|
380 |
+
def forward(self, source, padding_mask, **kwargs):
|
381 |
+
w2v_args = {
|
382 |
+
"source": source,
|
383 |
+
"padding_mask": padding_mask,
|
384 |
+
}
|
385 |
+
|
386 |
+
x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
|
387 |
+
# B x T x C -> T x B x C
|
388 |
+
x = x.transpose(0, 1) #torch.Size([106, 1, 1024])
|
389 |
+
|
390 |
+
return {
|
391 |
+
"encoder_out": x, # T x B x C
|
392 |
+
"encoder_padding_mask": padding_mask, # B x T
|
393 |
+
"padding_mask": padding_mask
|
394 |
+
}
|
395 |
+
|
396 |
+
def reorder_encoder_out(self, encoder_out, new_order):
|
397 |
+
if encoder_out["encoder_out"] is not None:
|
398 |
+
encoder_out["encoder_out"] = encoder_out[
|
399 |
+
"encoder_out"
|
400 |
+
].index_select(1, new_order)
|
401 |
+
if encoder_out["encoder_padding_mask"] is not None:
|
402 |
+
encoder_out["encoder_padding_mask"] = encoder_out[
|
403 |
+
"encoder_padding_mask"
|
404 |
+
].index_select(0, new_order)
|
405 |
+
if encoder_out["padding_mask"] is not None:
|
406 |
+
encoder_out["padding_mask"] = encoder_out[
|
407 |
+
"padding_mask"
|
408 |
+
].index_select(0, new_order)
|
409 |
+
return encoder_out
|
410 |
+
|
411 |
+
@register_model("av_hubert_seq2seq", dataclass=AVHubertSeq2SeqConfig)
|
412 |
+
class AVHubertSeq2Seq(FairseqEncoderDecoderModel):
|
413 |
+
def __init__(self, encoder, decoder, tgt_dict, cfg):
|
414 |
+
super().__init__(encoder, decoder)
|
415 |
+
self.cfg = cfg
|
416 |
+
self.freeze_finetune_updates = cfg.freeze_finetune_updates
|
417 |
+
|
418 |
+
@classmethod
|
419 |
+
def build_model(cls, cfg, task):
|
420 |
+
"""Build a new model instance."""
|
421 |
+
|
422 |
+
arg_overrides = {
|
423 |
+
"dropout": cfg.dropout,
|
424 |
+
"activation_dropout": cfg.activation_dropout,
|
425 |
+
"dropout_input": cfg.dropout_input,
|
426 |
+
"attention_dropout": cfg.attention_dropout,
|
427 |
+
"mask_length": cfg.mask_length,
|
428 |
+
"mask_prob": cfg.mask_prob,
|
429 |
+
"mask_selection": cfg.mask_selection,
|
430 |
+
"mask_other": cfg.mask_other,
|
431 |
+
"no_mask_overlap": cfg.no_mask_overlap,
|
432 |
+
"mask_channel_length": cfg.mask_channel_length,
|
433 |
+
"mask_channel_prob": cfg.mask_channel_prob,
|
434 |
+
"mask_channel_selection": cfg.mask_channel_selection,
|
435 |
+
"mask_channel_other": cfg.mask_channel_other,
|
436 |
+
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
|
437 |
+
"encoder_layerdrop": cfg.layerdrop,
|
438 |
+
"feature_grad_mult": cfg.feature_grad_mult,
|
439 |
+
}
|
440 |
+
|
441 |
+
if cfg.w2v_args is None:
|
442 |
+
state = checkpoint_utils.load_checkpoint_to_cpu(
|
443 |
+
cfg.w2v_path, arg_overrides
|
444 |
+
)
|
445 |
+
w2v_args = state.get("cfg", None)
|
446 |
+
if w2v_args is None:
|
447 |
+
w2v_args = convert_namespace_to_omegaconf(state["args"])
|
448 |
+
cfg.w2v_args = w2v_args
|
449 |
+
else:
|
450 |
+
state = None
|
451 |
+
w2v_args = cfg.w2v_args
|
452 |
+
if isinstance(w2v_args, Namespace):
|
453 |
+
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
|
454 |
+
w2v_args
|
455 |
+
)
|
456 |
+
|
457 |
+
assert cfg.normalize == w2v_args.task.normalize, (
|
458 |
+
"Fine-tuning works best when data normalization is the same. "
|
459 |
+
"Please check that --normalize is set or unset for "
|
460 |
+
"both pre-training and here"
|
461 |
+
)
|
462 |
+
|
463 |
+
w2v_args.task.data = cfg.data
|
464 |
+
|
465 |
+
task_pretrain = tasks.setup_task(w2v_args.task)
|
466 |
+
if state is not None:
|
467 |
+
task_pretrain.load_state_dict(state['task_state'])
|
468 |
+
|
469 |
+
encoder_ = task_pretrain.build_model(w2v_args.model)
|
470 |
+
|
471 |
+
encoder = HubertEncoderWrapper(encoder_)
|
472 |
+
if state is not None and not cfg.no_pretrained_weights:
|
473 |
+
# set strict=False because we omit some modules
|
474 |
+
del state['model']['mask_emb']
|
475 |
+
encoder.w2v_model.load_state_dict(state["model"], strict=False)
|
476 |
+
|
477 |
+
encoder.w2v_model.remove_pretraining_modules()
|
478 |
+
|
479 |
+
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
|
480 |
+
|
481 |
+
def build_embedding(dictionary, embed_dim):
|
482 |
+
num_embeddings = len(dictionary)
|
483 |
+
padding_idx = dictionary.pad()
|
484 |
+
emb = Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
|
485 |
+
return emb
|
486 |
+
|
487 |
+
decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)
|
488 |
+
decoder = TransformerDecoder(cfg, tgt_dict, decoder_embed_tokens)
|
489 |
+
|
490 |
+
return AVHubertSeq2Seq(encoder, decoder, tgt_dict, cfg)
|
491 |
+
|
492 |
+
|
493 |
+
def forward(self, **kwargs):
|
494 |
+
# ft = self.freeze_finetune_updates <= self.num_updates
|
495 |
+
# with torch.no_grad() if not ft else contextlib.ExitStack():
|
496 |
+
# output = self.encoder(**kwargs)
|
497 |
+
with torch.no_grad():
|
498 |
+
output = self.encoder(**kwargs) #encoder_out,encoder_padding_mask,padding_mask
|
499 |
+
# decoder_out = self.decoder(prev_output_tokens=kwargs['prev_output_tokens'], encoder_out=output)
|
500 |
+
return output
|
501 |
+
|
502 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
503 |
+
super().upgrade_state_dict_named(state_dict, name)
|
504 |
+
return state_dict
|
505 |
+
|
506 |
+
def set_num_updates(self, num_updates):
|
507 |
+
"""Set the number of parameters updates."""
|
508 |
+
super().set_num_updates(num_updates)
|
509 |
+
self.num_updates = num_updates
|
510 |
+
|
511 |
+
def Embedding(num_embeddings, embedding_dim, padding_idx):
|
512 |
+
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
513 |
+
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
|
514 |
+
nn.init.constant_(m.weight[padding_idx], 0)
|
515 |
+
return m
|
516 |
+
|
517 |
+
|
518 |
+
def Linear(in_features, out_features, bias=True):
|
519 |
+
m = nn.Linear(in_features, out_features, bias)
|
520 |
+
nn.init.xavier_uniform_(m.weight)
|
521 |
+
if bias:
|
522 |
+
nn.init.constant_(m.bias, 0.0)
|
523 |
+
return m
|
slam_llm/models/avhubert/hubert_criterion.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
import re
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
from typing import List, Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from fairseq import metrics, utils
|
15 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
16 |
+
from fairseq.dataclass import FairseqDataclass
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class AVHubertCriterionConfig(FairseqDataclass):
|
21 |
+
pred_masked_weight: float = field(
|
22 |
+
default=1.0,
|
23 |
+
metadata={"help": "weight for predictive loss for masked frames"},
|
24 |
+
)
|
25 |
+
pred_nomask_weight: float = field(
|
26 |
+
default=0.0,
|
27 |
+
metadata={"help": "weight for predictive loss for unmasked frames"},
|
28 |
+
)
|
29 |
+
loss_weights: Optional[List[float]] = field(
|
30 |
+
default=None,
|
31 |
+
metadata={"help": "weights for additional loss terms (not first one)"},
|
32 |
+
)
|
33 |
+
log_keys: List[str] = field(
|
34 |
+
default_factory=lambda: [],
|
35 |
+
metadata={"help": "output keys to log"},
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
@register_criterion("av_hubert", dataclass=AVHubertCriterionConfig)
|
40 |
+
class AVHubertCriterion(FairseqCriterion):
|
41 |
+
def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
|
42 |
+
super().__init__(task)
|
43 |
+
self.pred_masked_weight = pred_masked_weight
|
44 |
+
self.pred_nomask_weight = pred_nomask_weight
|
45 |
+
self.loss_weights = loss_weights
|
46 |
+
self.log_keys = [] if log_keys is None else log_keys
|
47 |
+
|
48 |
+
def forward(self, model, sample, reduce=True, log_pred=False):
|
49 |
+
"""Compute the loss for the given sample.
|
50 |
+
Returns a tuple with three elements:
|
51 |
+
1) the loss
|
52 |
+
2) the sample size, which is used as the denominator for the gradient
|
53 |
+
3) logging outputs to display while training
|
54 |
+
"""
|
55 |
+
net_output = model(target_list=sample["target_list"], **sample["net_input"])
|
56 |
+
loss = 0.
|
57 |
+
sample_size = 0
|
58 |
+
logging_output = {}
|
59 |
+
reduction = "sum" if reduce else "none"
|
60 |
+
|
61 |
+
loss_m_list = []
|
62 |
+
logp_m_list, targ_m_list = net_output['logit_m_list'], net_output['target_m_list']
|
63 |
+
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
|
64 |
+
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
|
65 |
+
loss_m_list.append(loss_m)
|
66 |
+
logging_output[f"loss_m_{i}"] = loss_m.detach().item()
|
67 |
+
if self.pred_masked_weight > 0:
|
68 |
+
loss += self.pred_masked_weight * sum(loss_m_list)
|
69 |
+
sample_size += targ_m_list[0].numel()
|
70 |
+
|
71 |
+
loss_u_list = []
|
72 |
+
logp_u_list, targ_u_list = net_output['logit_u_list'], net_output['target_u_list']
|
73 |
+
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
|
74 |
+
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
|
75 |
+
loss_u_list.append(loss_u)
|
76 |
+
logging_output[f"loss_u_{i}"] = loss_u.detach().item()
|
77 |
+
if self.pred_nomask_weight > 0:
|
78 |
+
loss += self.pred_nomask_weight * sum(loss_u_list)
|
79 |
+
sample_size += targ_u_list[0].numel()
|
80 |
+
|
81 |
+
if self.loss_weights is not None:
|
82 |
+
assert hasattr(model, "get_extra_losses")
|
83 |
+
extra_losses, names = model.get_extra_losses(net_output)
|
84 |
+
if torch.is_tensor(extra_losses):
|
85 |
+
extra_losses = [extra_losses]
|
86 |
+
names = [names]
|
87 |
+
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
|
88 |
+
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
|
89 |
+
assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
|
90 |
+
for p, n, coef in zip(extra_losses, names, self.loss_weights):
|
91 |
+
if coef != 0 and p is not None:
|
92 |
+
p = coef * p.float() * sample_size
|
93 |
+
loss += p
|
94 |
+
logging_output[f"loss_{n}"] = p.item()
|
95 |
+
|
96 |
+
logging_output = {
|
97 |
+
"loss": loss.item() if reduce else loss,
|
98 |
+
"ntokens": sample_size,
|
99 |
+
"nsentences": sample["id"].numel(),
|
100 |
+
"sample_size": sample_size,
|
101 |
+
**logging_output,
|
102 |
+
}
|
103 |
+
|
104 |
+
for lk in self.log_keys:
|
105 |
+
if lk in net_output:
|
106 |
+
logging_output[lk] = float((net_output[lk]))
|
107 |
+
|
108 |
+
with torch.no_grad():
|
109 |
+
for i, logp_m in enumerate(logp_m_list):
|
110 |
+
# corr_m, count_m = compute_correct(logp_m)
|
111 |
+
if logp_m.numel() == 0:
|
112 |
+
corr_m, count_m = 0, 0
|
113 |
+
else:
|
114 |
+
corr_m, count_m = (logp_m.argmax(dim=-1)==targ_m_list[i]).sum().item(), len(targ_m_list[i])
|
115 |
+
logging_output[f"correct_m_{i}"] = corr_m
|
116 |
+
logging_output[f"count_m_{i}"] = count_m
|
117 |
+
|
118 |
+
for i, logp_u in enumerate(logp_u_list):
|
119 |
+
if logp_u.numel() == 0:
|
120 |
+
corr_u, count_u = 0, 0
|
121 |
+
else:
|
122 |
+
corr_u, count_u = (logp_u.argmax(dim=-1)==targ_u_list[i]).sum().item(), len(targ_u_list[i])
|
123 |
+
logging_output[f"correct_u_{i}"] = corr_u
|
124 |
+
logging_output[f"count_u_{i}"] = count_u
|
125 |
+
|
126 |
+
return loss, sample_size, logging_output
|
127 |
+
|
128 |
+
@staticmethod
|
129 |
+
def reduce_metrics(logging_outputs) -> None:
|
130 |
+
"""Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
|
131 |
+
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
132 |
+
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
133 |
+
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
134 |
+
|
135 |
+
metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
|
136 |
+
if sample_size != ntokens:
|
137 |
+
metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
|
138 |
+
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
|
139 |
+
else:
|
140 |
+
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
|
141 |
+
|
142 |
+
counts = {}
|
143 |
+
for lk in logging_outputs[0].keys():
|
144 |
+
if lk.startswith("count_"):
|
145 |
+
val = sum(log[lk] for log in logging_outputs)
|
146 |
+
metrics.log_scalar(lk, val)
|
147 |
+
counts[lk] = val
|
148 |
+
|
149 |
+
for lk in logging_outputs[0].keys():
|
150 |
+
if lk.startswith("loss_"):
|
151 |
+
val = sum(log[lk] for log in logging_outputs)
|
152 |
+
metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
|
153 |
+
elif lk.startswith("correct_"):
|
154 |
+
val = sum(log[lk] for log in logging_outputs)
|
155 |
+
metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
|
156 |
+
|
157 |
+
@staticmethod
|
158 |
+
def aggregate_logging_outputs(logging_outputs):
|
159 |
+
"""Aggregate logging outputs from data parallel training."""
|
160 |
+
raise NotImplementedError()
|
161 |
+
|
162 |
+
@staticmethod
|
163 |
+
def logging_outputs_can_be_summed() -> bool:
|
164 |
+
"""
|
165 |
+
Whether the logging outputs returned by `forward` can be summed
|
166 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
167 |
+
to True will improves distributed training speed.
|
168 |
+
"""
|
169 |
+
return False
|
slam_llm/models/avhubert/hubert_dataset.py
ADDED
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import itertools
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
import time
|
12 |
+
from typing import Any, List, Optional, Union
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from fairseq.data import data_utils
|
19 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
20 |
+
from python_speech_features import logfbank
|
21 |
+
from scipy.io import wavfile
|
22 |
+
|
23 |
+
DBG=True if len(sys.argv) == 1 else False
|
24 |
+
|
25 |
+
if DBG:
|
26 |
+
import utils as custom_utils
|
27 |
+
logging.basicConfig(
|
28 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
29 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
30 |
+
level=os.environ.get("LOGLEVEL", "DEBUG").upper(),
|
31 |
+
stream=sys.stdout,
|
32 |
+
)
|
33 |
+
else:
|
34 |
+
from . import utils as custom_utils
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
|
38 |
+
|
39 |
+
def load_audio_visual(manifest_path, max_keep, min_keep, frame_rate, label_paths, label_rates, tol=0.1):
|
40 |
+
def is_audio_label_aligned(audio_dur, label_durs):
|
41 |
+
return all([abs(audio_dur - label_dur)<tol for label_dur in label_durs])
|
42 |
+
|
43 |
+
n_long, n_short, n_unaligned = 0, 0, 0
|
44 |
+
names, inds, sizes = [], [], []
|
45 |
+
dur_from_label_list = []
|
46 |
+
is_seq_label = any([x==-1 for x in label_rates])
|
47 |
+
for label_path, label_rate in zip(label_paths, label_rates):
|
48 |
+
label_lengths = [len(line.rstrip().split())/label_rate for line in open(label_path).readlines()]
|
49 |
+
dur_from_label_list.append(label_lengths)
|
50 |
+
dur_from_label_list = list(zip(*dur_from_label_list))
|
51 |
+
|
52 |
+
with open(manifest_path) as f:
|
53 |
+
root = f.readline().strip()
|
54 |
+
for ind, line in enumerate(f):
|
55 |
+
items = line.strip().split("\t")
|
56 |
+
sz = int(items[-2]) #
|
57 |
+
if min_keep is not None and sz < min_keep:
|
58 |
+
n_short += 1
|
59 |
+
elif max_keep is not None and sz > max_keep:
|
60 |
+
n_long += 1
|
61 |
+
elif (not is_seq_label) and (not is_audio_label_aligned(sz/frame_rate, dur_from_label_list[ind])):
|
62 |
+
n_unaligned += 1
|
63 |
+
else:
|
64 |
+
video_path = items[1]
|
65 |
+
audio_path = items[2]
|
66 |
+
audio_id = items[0]
|
67 |
+
names.append((video_path, audio_path+':'+audio_id))
|
68 |
+
inds.append(ind)
|
69 |
+
sizes.append(sz)
|
70 |
+
tot = ind + 1
|
71 |
+
logger.info(
|
72 |
+
(
|
73 |
+
f"max_keep={max_keep}, min_keep={min_keep}, "
|
74 |
+
f"loaded {len(names)}, skipped {n_short} short and {n_long} long and {n_unaligned} unaligned, "
|
75 |
+
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
|
76 |
+
)
|
77 |
+
)
|
78 |
+
return root, names, inds, tot, sizes
|
79 |
+
|
80 |
+
def load_label(label_path, inds, tot):
|
81 |
+
with open(label_path) as f:
|
82 |
+
labels = [line.rstrip() for line in f]
|
83 |
+
assert (
|
84 |
+
len(labels) == tot
|
85 |
+
), f"number of labels does not match ({len(labels)} != {tot})"
|
86 |
+
labels = [labels[i] for i in inds]
|
87 |
+
return labels
|
88 |
+
|
89 |
+
|
90 |
+
def load_label_offset(label_path, inds, tot):
|
91 |
+
with open(label_path) as f:
|
92 |
+
code_lengths = [len(line.encode("utf-8")) for line in f]
|
93 |
+
assert (
|
94 |
+
len(code_lengths) == tot
|
95 |
+
), f"number of labels does not match ({len(code_lengths)} != {tot})"
|
96 |
+
offsets = list(itertools.accumulate([0] + code_lengths))
|
97 |
+
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
|
98 |
+
return offsets
|
99 |
+
|
100 |
+
|
101 |
+
def verify_label_lengths(
|
102 |
+
audio_sizes,
|
103 |
+
audio_rate,
|
104 |
+
label_path,
|
105 |
+
label_rate,
|
106 |
+
inds,
|
107 |
+
tot,
|
108 |
+
tol=0.1, # tolerance in seconds
|
109 |
+
):
|
110 |
+
if label_rate < 0:
|
111 |
+
logger.info(f"{label_path} is sequence label. skipped")
|
112 |
+
return
|
113 |
+
|
114 |
+
with open(label_path) as f:
|
115 |
+
lengths = [len(line.rstrip().split()) for line in f]
|
116 |
+
assert len(lengths) == tot
|
117 |
+
lengths = [lengths[i] for i in inds]
|
118 |
+
num_invalid = 0
|
119 |
+
for i, ind in enumerate(inds):
|
120 |
+
dur_from_audio = audio_sizes[i] / audio_rate
|
121 |
+
dur_from_label = lengths[i] / label_rate
|
122 |
+
if abs(dur_from_audio - dur_from_label) > tol:
|
123 |
+
logger.warning(
|
124 |
+
(
|
125 |
+
f"audio and label duration differ too much "
|
126 |
+
f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
|
127 |
+
f"in line {ind+1} of {label_path}. Check if `label_rate` "
|
128 |
+
f"is correctly set (currently {label_rate}). "
|
129 |
+
f"num. of samples = {audio_sizes[i]}; "
|
130 |
+
f"label length = {lengths[i]}"
|
131 |
+
)
|
132 |
+
)
|
133 |
+
num_invalid += 1
|
134 |
+
if num_invalid > 0:
|
135 |
+
logger.warning(
|
136 |
+
f"total {num_invalid} (audio, label) pairs with mismatched lengths"
|
137 |
+
)
|
138 |
+
|
139 |
+
|
140 |
+
class AVHubertDataset(FairseqDataset):
|
141 |
+
def __init__(
|
142 |
+
self,
|
143 |
+
manifest_path: str,
|
144 |
+
sample_rate: float,
|
145 |
+
label_paths: List[str],
|
146 |
+
label_rates: Union[List[float], float], # -1 for sequence labels
|
147 |
+
pad_list: List[str],
|
148 |
+
eos_list: List[str],
|
149 |
+
label_processors: Optional[List[Any]] = None,
|
150 |
+
max_keep_sample_size: Optional[int] = None,
|
151 |
+
min_keep_sample_size: Optional[int] = None,
|
152 |
+
max_sample_size: Optional[int] = None,
|
153 |
+
shuffle: bool = True,
|
154 |
+
pad_audio: bool = False,
|
155 |
+
normalize: bool = False,
|
156 |
+
store_labels: bool = True,
|
157 |
+
random_crop: bool = False,
|
158 |
+
single_target: bool = False,
|
159 |
+
stack_order_audio: int=1,
|
160 |
+
skip_verify: bool=False,
|
161 |
+
image_mean: float=0,
|
162 |
+
image_std: float=1,
|
163 |
+
image_crop_size: int=88,
|
164 |
+
image_aug: bool=False,
|
165 |
+
modalities: Optional[List[str]]=None,
|
166 |
+
is_s2s=False,
|
167 |
+
noise_fn=None,
|
168 |
+
noise_prob=0,
|
169 |
+
noise_snr=0,
|
170 |
+
noise_num=1
|
171 |
+
):
|
172 |
+
self.label_rates = (
|
173 |
+
[label_rates for _ in range(len(label_paths))]
|
174 |
+
if isinstance(label_rates, int)
|
175 |
+
else label_rates
|
176 |
+
)
|
177 |
+
self.modalities = set(modalities)
|
178 |
+
self.audio_root, self.names, inds, tot, self.sizes = load_audio_visual(manifest_path, max_keep_sample_size, min_keep_sample_size, frame_rate=sample_rate, label_paths=label_paths, label_rates=self.label_rates)
|
179 |
+
self.sample_rate = sample_rate
|
180 |
+
self.stack_order_audio = stack_order_audio
|
181 |
+
self.shuffle = shuffle
|
182 |
+
self.random_crop = random_crop
|
183 |
+
|
184 |
+
self.num_labels = len(label_paths)
|
185 |
+
self.pad_list = pad_list
|
186 |
+
self.eos_list = eos_list
|
187 |
+
self.label_processors = label_processors
|
188 |
+
self.single_target = single_target
|
189 |
+
self.store_labels = store_labels
|
190 |
+
self.is_s2s = is_s2s
|
191 |
+
self.noise_wav, self.noise_prob, self.noise_snr, self.noise_num = [ln.strip() for ln in open(noise_fn).readlines()] if noise_fn is not None else [], noise_prob, noise_snr, noise_num
|
192 |
+
|
193 |
+
assert self.single_target == (self.label_rates[0] == -1), f"single target should be equivalent to sequence label (label_rate==-1)"
|
194 |
+
if store_labels:
|
195 |
+
self.label_list = [load_label(p, inds, tot) for p in label_paths]
|
196 |
+
else:
|
197 |
+
self.label_paths = label_paths
|
198 |
+
self.label_offsets_list = [
|
199 |
+
load_label_offset(p, inds, tot) for p in label_paths
|
200 |
+
]
|
201 |
+
assert (
|
202 |
+
label_processors is None
|
203 |
+
or len(label_processors) == self.num_labels
|
204 |
+
)
|
205 |
+
if not skip_verify:
|
206 |
+
for label_path, label_rate in zip(label_paths, self.label_rates):
|
207 |
+
verify_label_lengths(self.sizes, self.sample_rate, label_path, label_rate, inds, tot)
|
208 |
+
else:
|
209 |
+
logger.info(f"Skip label alignment verifying")
|
210 |
+
|
211 |
+
self.max_sample_size = (
|
212 |
+
max_sample_size if max_sample_size is not None else sys.maxsize
|
213 |
+
)
|
214 |
+
self.pad_audio = pad_audio
|
215 |
+
self.normalize = normalize
|
216 |
+
if image_aug:
|
217 |
+
self.transform = custom_utils.Compose([
|
218 |
+
custom_utils.Normalize( 0.0,255.0 ),
|
219 |
+
custom_utils.RandomCrop((image_crop_size, image_crop_size)),
|
220 |
+
custom_utils.HorizontalFlip(0.5),
|
221 |
+
custom_utils.Normalize(image_mean, image_std) ])
|
222 |
+
else:
|
223 |
+
self.transform = custom_utils.Compose([
|
224 |
+
custom_utils.Normalize( 0.0,255.0 ),
|
225 |
+
custom_utils.CenterCrop((image_crop_size, image_crop_size)),
|
226 |
+
custom_utils.Normalize(image_mean, image_std) ])
|
227 |
+
logger.info(f"image transform: {self.transform}")
|
228 |
+
|
229 |
+
logger.info(
|
230 |
+
f"pad_audio={pad_audio}, random_crop={random_crop}, "
|
231 |
+
f"normalize={normalize}, max_sample_size={self.max_sample_size}, "
|
232 |
+
f"seqs2seq data={self.is_s2s},")
|
233 |
+
logger.info(
|
234 |
+
f"Noise wav: {noise_fn}->{len(self.noise_wav)} wav, Prob: {self.noise_prob}, SNR: {self.noise_snr}, Number of mixture: {self.noise_num}"
|
235 |
+
)
|
236 |
+
|
237 |
+
def get_label(self, index, label_idx):
|
238 |
+
if self.store_labels:
|
239 |
+
label = self.label_list[label_idx][index]
|
240 |
+
else:
|
241 |
+
with open(self.label_paths[label_idx]) as f:
|
242 |
+
offset_s, offset_e = self.label_offsets_list[label_idx][index]
|
243 |
+
f.seek(offset_s)
|
244 |
+
label = f.read(offset_e - offset_s)
|
245 |
+
|
246 |
+
if self.label_processors is not None:
|
247 |
+
label = self.label_processors[label_idx](label)
|
248 |
+
return label
|
249 |
+
|
250 |
+
def get_labels(self, index):
|
251 |
+
return [self.get_label(index, i) for i in range(self.num_labels)]
|
252 |
+
|
253 |
+
def load_feature(self, mix_name):
|
254 |
+
"""
|
255 |
+
Load image and audio feature
|
256 |
+
Returns:
|
257 |
+
video_feats: numpy.ndarray of shape [T, H, W, 1], audio_feats: numpy.ndarray of shape [T, F]
|
258 |
+
"""
|
259 |
+
def stacker(feats, stack_order):
|
260 |
+
"""
|
261 |
+
Concatenating consecutive audio frames
|
262 |
+
Args:
|
263 |
+
feats - numpy.ndarray of shape [T, F]
|
264 |
+
stack_order - int (number of neighboring frames to concatenate
|
265 |
+
Returns:
|
266 |
+
feats - numpy.ndarray of shape [T', F']
|
267 |
+
"""
|
268 |
+
feat_dim = feats.shape[1]
|
269 |
+
if len(feats) % stack_order != 0:
|
270 |
+
res = stack_order - len(feats) % stack_order
|
271 |
+
res = np.zeros([res, feat_dim]).astype(feats.dtype)
|
272 |
+
feats = np.concatenate([feats, res], axis=0)
|
273 |
+
feats = feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order*feat_dim)
|
274 |
+
return feats
|
275 |
+
video_fn, audio_fn = mix_name
|
276 |
+
if 'video' in self.modalities:
|
277 |
+
video_feats = self.load_video(video_fn) # [T, H, W, 1]
|
278 |
+
else:
|
279 |
+
video_feats = None
|
280 |
+
if 'audio' in self.modalities:
|
281 |
+
audio_fn = audio_fn.split(':')[0]
|
282 |
+
sample_rate, wav_data = wavfile.read(audio_fn)
|
283 |
+
assert sample_rate == 16_000 and len(wav_data.shape) == 1
|
284 |
+
if np.random.rand() < self.noise_prob:
|
285 |
+
wav_data = self.add_noise(wav_data)
|
286 |
+
audio_feats = logfbank(wav_data, samplerate=sample_rate).astype(np.float32) # [T, F]
|
287 |
+
audio_feats = stacker(audio_feats, self.stack_order_audio) # [T/stack_order_audio, F*stack_order_audio]
|
288 |
+
else:
|
289 |
+
audio_feats = None
|
290 |
+
if audio_feats is not None and video_feats is not None:
|
291 |
+
diff = len(audio_feats) - len(video_feats)
|
292 |
+
if diff < 0:
|
293 |
+
audio_feats = np.concatenate([audio_feats, np.zeros([-diff, audio_feats.shape[-1]], dtype=audio_feats.dtype)])
|
294 |
+
elif diff > 0:
|
295 |
+
audio_feats = audio_feats[:-diff]
|
296 |
+
return video_feats, audio_feats
|
297 |
+
|
298 |
+
def load_video(self, audio_name):
|
299 |
+
feats = custom_utils.load_video(os.path.join(self.audio_root, audio_name))
|
300 |
+
feats = self.transform(feats)
|
301 |
+
feats = np.expand_dims(feats, axis=-1)
|
302 |
+
return feats
|
303 |
+
|
304 |
+
def select_noise(self):
|
305 |
+
rand_indexes = np.random.randint(0, len(self.noise_wav), size=self.noise_num)
|
306 |
+
noise_wav = []
|
307 |
+
for x in rand_indexes:
|
308 |
+
noise_wav.append(wavfile.read(self.noise_wav[x])[1].astype(np.float32))
|
309 |
+
if self.noise_num == 1:
|
310 |
+
return noise_wav[0]
|
311 |
+
else:
|
312 |
+
min_len = min([len(x) for x in noise_wav])
|
313 |
+
noise_wav = [x[:min_len] for x in noise_wav]
|
314 |
+
noise_wav = np.floor(np.stack(noise_wav).mean(axis=0))
|
315 |
+
return noise_wav
|
316 |
+
|
317 |
+
def add_noise(self, clean_wav):
|
318 |
+
clean_wav = clean_wav.astype(np.float32)
|
319 |
+
noise_wav = self.select_noise()
|
320 |
+
if type(self.noise_snr) == int or type(self.noise_snr) == float:
|
321 |
+
snr = self.noise_snr
|
322 |
+
elif type(self.noise_snr) == tuple:
|
323 |
+
snr = np.random.randint(self.noise_snr[0], self.noise_snr[1]+1)
|
324 |
+
clean_rms = np.sqrt(np.mean(np.square(clean_wav), axis=-1))
|
325 |
+
if len(clean_wav) > len(noise_wav):
|
326 |
+
ratio = int(np.ceil(len(clean_wav)/len(noise_wav)))
|
327 |
+
noise_wav = np.concatenate([noise_wav for _ in range(ratio)])
|
328 |
+
if len(clean_wav) < len(noise_wav):
|
329 |
+
start = 0
|
330 |
+
noise_wav = noise_wav[start: start + len(clean_wav)]
|
331 |
+
noise_rms = np.sqrt(np.mean(np.square(noise_wav), axis=-1))
|
332 |
+
adjusted_noise_rms = clean_rms / (10**(snr/20))
|
333 |
+
adjusted_noise_wav = noise_wav * (adjusted_noise_rms / noise_rms)
|
334 |
+
mixed = clean_wav + adjusted_noise_wav
|
335 |
+
|
336 |
+
#Avoid clipping noise
|
337 |
+
max_int16 = np.iinfo(np.int16).max
|
338 |
+
min_int16 = np.iinfo(np.int16).min
|
339 |
+
if mixed.max(axis=0) > max_int16 or mixed.min(axis=0) < min_int16:
|
340 |
+
if mixed.max(axis=0) >= abs(mixed.min(axis=0)):
|
341 |
+
reduction_rate = max_int16 / mixed.max(axis=0)
|
342 |
+
else :
|
343 |
+
reduction_rate = min_int16 / mixed.min(axis=0)
|
344 |
+
mixed = mixed * (reduction_rate)
|
345 |
+
mixed = mixed.astype(np.int16)
|
346 |
+
return mixed
|
347 |
+
|
348 |
+
def __getitem__(self, index):
|
349 |
+
video_feats, audio_feats = self.load_feature(self.names[index])
|
350 |
+
audio_feats, video_feats = torch.from_numpy(audio_feats.astype(np.float32)) if audio_feats is not None else None, torch.from_numpy(video_feats.astype(np.float32)) if video_feats is not None else None
|
351 |
+
if self.normalize and 'audio' in self.modalities:
|
352 |
+
with torch.no_grad():
|
353 |
+
audio_feats = F.layer_norm(audio_feats, audio_feats.shape[1:])
|
354 |
+
labels = self.get_labels(index)
|
355 |
+
fid = self.names[index][1].split(':')[1]
|
356 |
+
return {"id": index, 'fid': fid, "video_source": video_feats, 'audio_source': audio_feats, "label_list": labels}
|
357 |
+
|
358 |
+
def __len__(self):
|
359 |
+
return len(self.sizes)
|
360 |
+
|
361 |
+
def crop_to_max_size(self, wav, target_size, start=None):
|
362 |
+
size = len(wav)
|
363 |
+
diff = size - target_size
|
364 |
+
if diff <= 0:
|
365 |
+
return wav, 0
|
366 |
+
# longer utterances
|
367 |
+
if start is None:
|
368 |
+
start, end = 0, target_size
|
369 |
+
if self.random_crop:
|
370 |
+
start = np.random.randint(0, diff + 1)
|
371 |
+
end = size - diff + start
|
372 |
+
else:
|
373 |
+
end = start + target_size
|
374 |
+
return wav[start:end], start
|
375 |
+
|
376 |
+
def collater(self, samples):
|
377 |
+
samples = [s for s in samples if s["id"] is not None]
|
378 |
+
if len(samples) == 0:
|
379 |
+
return {}
|
380 |
+
|
381 |
+
audio_source, video_source = [s["audio_source"] for s in samples], [s["video_source"] for s in samples]
|
382 |
+
if audio_source[0] is None:
|
383 |
+
audio_source = None
|
384 |
+
if video_source[0] is None:
|
385 |
+
video_source = None
|
386 |
+
if audio_source is not None:
|
387 |
+
audio_sizes = [len(s) for s in audio_source]
|
388 |
+
else:
|
389 |
+
audio_sizes = [len(s) for s in video_source]
|
390 |
+
if self.pad_audio:
|
391 |
+
audio_size = min(max(audio_sizes), self.max_sample_size)
|
392 |
+
else:
|
393 |
+
audio_size = min(min(audio_sizes), self.max_sample_size)
|
394 |
+
if audio_source is not None:
|
395 |
+
collated_audios, padding_mask, audio_starts = self.collater_audio(audio_source, audio_size)
|
396 |
+
else:
|
397 |
+
collated_audios, audio_starts = None, None
|
398 |
+
if video_source is not None:
|
399 |
+
collated_videos, padding_mask, audio_starts = self.collater_audio(video_source, audio_size, audio_starts)
|
400 |
+
else:
|
401 |
+
collated_videos = None
|
402 |
+
targets_by_label = [
|
403 |
+
[s["label_list"][i] for s in samples]
|
404 |
+
for i in range(self.num_labels)
|
405 |
+
]
|
406 |
+
targets_list, lengths_list, ntokens_list = self.collater_label(
|
407 |
+
targets_by_label, audio_size, audio_starts
|
408 |
+
)
|
409 |
+
source = {"audio": collated_audios, "video": collated_videos}
|
410 |
+
net_input = {"source": source, "padding_mask": padding_mask}
|
411 |
+
batch = {
|
412 |
+
"id": torch.LongTensor([s["id"] for s in samples]),
|
413 |
+
"net_input": net_input,
|
414 |
+
"utt_id": [s['fid'] for s in samples]
|
415 |
+
}
|
416 |
+
|
417 |
+
if self.single_target:
|
418 |
+
batch["target_lengths"] = lengths_list[0]
|
419 |
+
batch["ntokens"] = ntokens_list[0]
|
420 |
+
if self.is_s2s:
|
421 |
+
batch['target'], net_input['prev_output_tokens'] = targets_list[0][0], targets_list[0][1]
|
422 |
+
else:
|
423 |
+
batch["target"] = targets_list[0]
|
424 |
+
else:
|
425 |
+
batch["target_lengths_list"] = lengths_list
|
426 |
+
batch["ntokens_list"] = ntokens_list
|
427 |
+
batch["target_list"] = targets_list
|
428 |
+
return batch
|
429 |
+
|
430 |
+
def collater_audio(self, audios, audio_size, audio_starts=None):
|
431 |
+
audio_feat_shape = list(audios[0].shape[1:])
|
432 |
+
collated_audios = audios[0].new_zeros([len(audios), audio_size]+audio_feat_shape)
|
433 |
+
padding_mask = (
|
434 |
+
torch.BoolTensor(len(audios), audio_size).fill_(False) #
|
435 |
+
)
|
436 |
+
start_known = audio_starts is not None
|
437 |
+
audio_starts = [0 for _ in audios] if not start_known else audio_starts
|
438 |
+
for i, audio in enumerate(audios):
|
439 |
+
diff = len(audio) - audio_size
|
440 |
+
if diff == 0:
|
441 |
+
collated_audios[i] = audio
|
442 |
+
elif diff < 0:
|
443 |
+
assert self.pad_audio
|
444 |
+
collated_audios[i] = torch.cat(
|
445 |
+
[audio, audio.new_full([-diff]+audio_feat_shape, 0.0)]
|
446 |
+
)
|
447 |
+
padding_mask[i, diff:] = True
|
448 |
+
else:
|
449 |
+
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
|
450 |
+
audio, audio_size, audio_starts[i] if start_known else None
|
451 |
+
)
|
452 |
+
if len(audios[0].shape) == 2:
|
453 |
+
collated_audios = collated_audios.transpose(1, 2) # [B, T, F] -> [B, F, T]
|
454 |
+
else:
|
455 |
+
collated_audios = collated_audios.permute((0, 4, 1, 2, 3)).contiguous() # [B, T, H, W, C] -> [B, C, T, H, W]
|
456 |
+
return collated_audios, padding_mask, audio_starts
|
457 |
+
|
458 |
+
def collater_frm_label(
|
459 |
+
self, targets, audio_size, audio_starts, label_rate, pad
|
460 |
+
):
|
461 |
+
assert label_rate > 0
|
462 |
+
s2f = label_rate / self.sample_rate # num label per sample
|
463 |
+
frm_starts = [int(round(s * s2f)) for s in audio_starts]
|
464 |
+
frm_size = int(round(audio_size * s2f))
|
465 |
+
if not self.pad_audio:
|
466 |
+
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
|
467 |
+
frm_size = min(frm_size, *rem_size)
|
468 |
+
targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)]
|
469 |
+
logger.debug(f"audio_starts={audio_starts}")
|
470 |
+
logger.debug(f"frame_starts={frm_starts}")
|
471 |
+
logger.debug(f"frame_size={frm_size}")
|
472 |
+
|
473 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
474 |
+
ntokens = lengths.sum().item()
|
475 |
+
targets = data_utils.collate_tokens(
|
476 |
+
targets, pad_idx=pad, left_pad=False
|
477 |
+
)
|
478 |
+
return targets, lengths, ntokens
|
479 |
+
|
480 |
+
def collater_seq_label(self, targets, pad):
|
481 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
482 |
+
ntokens = lengths.sum().item()
|
483 |
+
targets = data_utils.collate_tokens(
|
484 |
+
targets, pad_idx=pad, left_pad=False
|
485 |
+
)
|
486 |
+
return targets, lengths, ntokens
|
487 |
+
|
488 |
+
def collater_seq_label_s2s(self, targets, pad):
|
489 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
490 |
+
ntokens = lengths.sum().item()
|
491 |
+
pad, eos = self.label_processors[0].dictionary.pad(), self.label_processors[0].dictionary.eos()
|
492 |
+
targets_ = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False)
|
493 |
+
prev_output_tokens = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False, move_eos_to_beginning=True)
|
494 |
+
return (targets_, prev_output_tokens), lengths, ntokens
|
495 |
+
|
496 |
+
def collater_label(self, targets_by_label, audio_size, audio_starts):
|
497 |
+
targets_list, lengths_list, ntokens_list = [], [], []
|
498 |
+
itr = zip(targets_by_label, self.label_rates, self.pad_list)
|
499 |
+
for targets, label_rate, pad in itr:
|
500 |
+
if label_rate == -1:
|
501 |
+
if self.is_s2s:
|
502 |
+
targets, lengths, ntokens = self.collater_seq_label_s2s(targets, pad)
|
503 |
+
else:
|
504 |
+
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
|
505 |
+
else:
|
506 |
+
targets, lengths, ntokens = self.collater_frm_label(
|
507 |
+
targets, audio_size, audio_starts, label_rate, pad
|
508 |
+
)
|
509 |
+
targets_list.append(targets)
|
510 |
+
lengths_list.append(lengths)
|
511 |
+
ntokens_list.append(ntokens)
|
512 |
+
return targets_list, lengths_list, ntokens_list
|
513 |
+
|
514 |
+
def num_tokens(self, index):
|
515 |
+
return self.size(index)
|
516 |
+
|
517 |
+
def size(self, index):
|
518 |
+
if self.pad_audio:
|
519 |
+
return self.sizes[index]
|
520 |
+
return min(self.sizes[index], self.max_sample_size)
|
521 |
+
|
522 |
+
def ordered_indices(self):
|
523 |
+
if self.shuffle:
|
524 |
+
order = [np.random.permutation(len(self))]
|
525 |
+
else:
|
526 |
+
order = [np.arange(len(self))]
|
527 |
+
|
528 |
+
order.append(self.sizes)
|
529 |
+
return np.lexsort(order)[::-1]
|
slam_llm/models/avhubert/hubert_pretraining.py
ADDED
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import os, glob
|
9 |
+
import sys
|
10 |
+
from typing import Dict, List, Optional, Tuple
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from dataclasses import dataclass, field
|
15 |
+
from fairseq import metrics, search
|
16 |
+
from fairseq.data import Dictionary, encoders
|
17 |
+
from fairseq.dataclass.configs import FairseqDataclass
|
18 |
+
from fairseq.tasks import register_task
|
19 |
+
from fairseq.tasks.fairseq_task import FairseqTask
|
20 |
+
from omegaconf import MISSING, II
|
21 |
+
import numpy as np
|
22 |
+
from argparse import Namespace
|
23 |
+
|
24 |
+
DBG=True if len(sys.argv) == 1 else False
|
25 |
+
|
26 |
+
if DBG:
|
27 |
+
from hubert_dataset import AVHubertDataset
|
28 |
+
from sequence_generator import SequenceGenerator
|
29 |
+
else:
|
30 |
+
from .hubert_dataset import AVHubertDataset
|
31 |
+
from .sequence_generator import SequenceGenerator
|
32 |
+
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
class LabelEncoder(object):
|
37 |
+
def __init__(self, dictionary: Dictionary) -> None:
|
38 |
+
self.dictionary = dictionary
|
39 |
+
|
40 |
+
def __call__(self, label: str) -> List[str]:
|
41 |
+
return self.dictionary.encode_line(
|
42 |
+
label, append_eos=False, add_if_not_exist=False,
|
43 |
+
)
|
44 |
+
|
45 |
+
class LabelEncoderS2SToken(object):
|
46 |
+
def __init__(self, dictionary: Dictionary, bpe_tokenizer) -> None:
|
47 |
+
self.bpe_tokenizer = bpe_tokenizer
|
48 |
+
self.dictionary = dictionary
|
49 |
+
|
50 |
+
def __call__(self, label: str) -> List[str]:
|
51 |
+
label = self.bpe_tokenizer.encode(label.lower())
|
52 |
+
return self.dictionary.encode_line(
|
53 |
+
label, append_eos=True, add_if_not_exist=False,
|
54 |
+
).long()
|
55 |
+
|
56 |
+
def decode(self, tok, symbols_ignore=None):
|
57 |
+
tok = self.dictionary.string(tok, extra_symbols_to_ignore=symbols_ignore)
|
58 |
+
if self.bpe_tokenizer:
|
59 |
+
tok = self.bpe_tokenizer.decode(tok)
|
60 |
+
return tok
|
61 |
+
|
62 |
+
@dataclass
|
63 |
+
class AVHubertPretrainingConfig(FairseqDataclass):
|
64 |
+
input_modality: str = II("task.input_modality") #??
|
65 |
+
data: str = field(
|
66 |
+
default=MISSING, metadata={"help": "path to data directory"}
|
67 |
+
)
|
68 |
+
labels: List[str] = field(
|
69 |
+
default_factory=lambda: ["ltr"],
|
70 |
+
metadata={
|
71 |
+
"help": (
|
72 |
+
"extension of the label files to load, frame-level labels for"
|
73 |
+
" pre-training, and sequence-level label for fine-tuning"
|
74 |
+
)
|
75 |
+
},
|
76 |
+
)
|
77 |
+
label_dir: Optional[str] = field(
|
78 |
+
default=None,
|
79 |
+
metadata={
|
80 |
+
"help": "if set, looks for labels in this directory instead",
|
81 |
+
},
|
82 |
+
)
|
83 |
+
label_rate: int = field(
|
84 |
+
default=-1,
|
85 |
+
metadata={"help": "label frame rate. -1 for sequence label"},
|
86 |
+
)
|
87 |
+
|
88 |
+
sample_rate: int = field(
|
89 |
+
default=16_000,
|
90 |
+
metadata={
|
91 |
+
"help": "target sample rate. audio files will be up/down "
|
92 |
+
"sampled to this rate"
|
93 |
+
},
|
94 |
+
)
|
95 |
+
normalize: bool = field(
|
96 |
+
default=False,
|
97 |
+
metadata={
|
98 |
+
"help": "if set, normalizes input to have 0 mean and unit variance"
|
99 |
+
},
|
100 |
+
)
|
101 |
+
enable_padding: bool = field(
|
102 |
+
default=False,
|
103 |
+
metadata={"help": "pad shorter samples instead of cropping"},
|
104 |
+
)
|
105 |
+
max_sample_size: Optional[int] = field(
|
106 |
+
default=None,
|
107 |
+
metadata={"help": "max sample size to keep in training"},
|
108 |
+
)
|
109 |
+
min_sample_size: Optional[int] = field(
|
110 |
+
default=None,
|
111 |
+
metadata={"help": "min sample size to keep in training"},
|
112 |
+
)
|
113 |
+
max_trim_sample_size: Optional[int] = field(
|
114 |
+
default=II("task.max_sample_size"),
|
115 |
+
metadata={"help": "max sample size to trim to for batching"},
|
116 |
+
)
|
117 |
+
single_target: Optional[bool] = field(
|
118 |
+
default=False,
|
119 |
+
metadata={
|
120 |
+
"help": "if set, AddTargetDatasets outputs same keys "
|
121 |
+
"as AddTargetDataset"
|
122 |
+
},
|
123 |
+
)
|
124 |
+
random_crop: Optional[bool] = field(
|
125 |
+
default=True,
|
126 |
+
metadata={"help": "always crop from the beginning if false"},
|
127 |
+
)
|
128 |
+
pad_audio: Optional[bool] = field(
|
129 |
+
default=False,
|
130 |
+
metadata={"help": "pad audio to the longest one in the batch if true"},
|
131 |
+
)
|
132 |
+
pdb: Optional[bool] = field(
|
133 |
+
default=False,
|
134 |
+
metadata={"help": "pdb"},
|
135 |
+
)
|
136 |
+
stack_order_audio: int = field(
|
137 |
+
default=1,
|
138 |
+
metadata={"help": "concatenate n consecutive audio frames for one step"},
|
139 |
+
)
|
140 |
+
skip_verify: Optional[bool] = field(
|
141 |
+
default=False,
|
142 |
+
metadata={"help": "skip verifying label-audio alignment"},
|
143 |
+
)
|
144 |
+
image_aug: bool = field(default=False, metadata={'help': 'image data augmentation'})
|
145 |
+
image_crop_size: int = field(
|
146 |
+
default=88, metadata={"help": "image ROI size"})
|
147 |
+
image_mean: float = field(
|
148 |
+
default=0.421, metadata={"help": "image mean"})
|
149 |
+
image_std: float = field(
|
150 |
+
default=0.165, metadata={"help": "image std"})
|
151 |
+
modalities: Optional[List[str]] = field(default_factory=lambda: ["audio", "video"], metadata={'help': 'modalities to load'})
|
152 |
+
is_s2s: bool=field(default=False, metadata={'help': 'seq2seq fine-tuning only'})
|
153 |
+
tokenizer_bpe_name: Optional[str] = field(default=None, metadata={'help': 'tokenizer model name'})
|
154 |
+
tokenizer_bpe_model: Optional[str] = field(default=None, metadata={'help': 'tokenizer model path'})
|
155 |
+
noise_wav: Optional[str] = field(default=None, metadata={'help': 'manifest of noise wav files (one wav file path per line)'})
|
156 |
+
noise_prob: float = field(default=0, metadata={'help': 'noise probability'})
|
157 |
+
noise_snr: Optional[str] = field(default='0', metadata={'help': 'noise SNR in audio'})
|
158 |
+
noise_num: int = field(default=1, metadata={'help': 'number of noise wav files to mix'})
|
159 |
+
fine_tuning: bool = field(default=False, metadata={"help": "set to true if fine-tuning AV-Hubert"})
|
160 |
+
|
161 |
+
@register_task("av_hubert_pretraining", dataclass=AVHubertPretrainingConfig)
|
162 |
+
class AVHubertPretrainingTask(FairseqTask):
|
163 |
+
|
164 |
+
cfg: AVHubertPretrainingConfig
|
165 |
+
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
cfg: AVHubertPretrainingConfig,
|
169 |
+
) -> None:
|
170 |
+
super().__init__(cfg)
|
171 |
+
|
172 |
+
logger.info(f"current directory is {os.getcwd()}")
|
173 |
+
logger.info(f"AVHubertPretrainingTask Config {cfg}")
|
174 |
+
|
175 |
+
self.fine_tuning = cfg.fine_tuning
|
176 |
+
if cfg.fine_tuning:
|
177 |
+
self.state.add_factory("target_dictionary", self.load_dictionaries)
|
178 |
+
if cfg.is_s2s:
|
179 |
+
self.state.add_factory("s2s_tokenizer", self.load_tokenizer)
|
180 |
+
else:
|
181 |
+
self.state.add_factory("dictionaries", self.load_dictionaries)
|
182 |
+
|
183 |
+
self.blank_symbol = "<s>"
|
184 |
+
|
185 |
+
@property
|
186 |
+
def source_dictionary(self) -> Optional[Dictionary]:
|
187 |
+
return None # self._source_dictionary
|
188 |
+
|
189 |
+
@property
|
190 |
+
def target_dictionary(self) -> Optional[Dictionary]:
|
191 |
+
return self.state.target_dictionary # self._target_dictionary
|
192 |
+
|
193 |
+
@property
|
194 |
+
def dictionaries(self) -> List[Dictionary]:
|
195 |
+
return self.state.dictionaries
|
196 |
+
|
197 |
+
def load_dictionaries(self):
|
198 |
+
label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
|
199 |
+
dictionaries = [
|
200 |
+
Dictionary.load(f"{label_dir}/dict.{label}.txt")
|
201 |
+
for label in self.cfg.labels
|
202 |
+
]
|
203 |
+
return dictionaries[0] if self.cfg.fine_tuning else dictionaries
|
204 |
+
|
205 |
+
def load_tokenizer(self):
|
206 |
+
bpe_args = Namespace(**{'bpe': self.cfg.tokenizer_bpe_name, f"{self.cfg.tokenizer_bpe_name}_model": self.cfg.tokenizer_bpe_model})
|
207 |
+
bpe_tokenizer = encoders.build_bpe(bpe_args)
|
208 |
+
return bpe_tokenizer
|
209 |
+
|
210 |
+
@property
|
211 |
+
def s2s_tokenizer(self):
|
212 |
+
return self.state.s2s_tokenizer
|
213 |
+
|
214 |
+
@classmethod
|
215 |
+
def setup_task(
|
216 |
+
cls, cfg: AVHubertPretrainingConfig, **kwargs
|
217 |
+
) -> "AVHubertPretrainingTask":
|
218 |
+
if cfg.pdb:
|
219 |
+
import pdb
|
220 |
+
pdb.set_trace()
|
221 |
+
return cls(cfg)
|
222 |
+
|
223 |
+
def get_label_dir(self) -> str:
|
224 |
+
if self.cfg.label_dir is None:
|
225 |
+
return self.cfg.data
|
226 |
+
return self.cfg.label_dir
|
227 |
+
|
228 |
+
def load_dataset(self, split: str, **kwargs) -> None:
|
229 |
+
manifest = f"{self.cfg.data}/{split}.tsv"
|
230 |
+
dictionaries = [self.target_dictionary] if self.fine_tuning else self.dictionaries
|
231 |
+
pad_list = [dictionary.pad() for dictionary in dictionaries]
|
232 |
+
eos_list = [dictionary.eos() for dictionary in dictionaries]
|
233 |
+
if not self.cfg.is_s2s:
|
234 |
+
procs = [LabelEncoder(dictionary) for dictionary in dictionaries]
|
235 |
+
else:
|
236 |
+
logger.info(f"Using tokenizer")
|
237 |
+
bpe_tokenizer = self.s2s_tokenizer
|
238 |
+
procs = [LabelEncoderS2SToken(dictionary, bpe_tokenizer) for dictionary in dictionaries]
|
239 |
+
paths = [
|
240 |
+
f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels
|
241 |
+
]
|
242 |
+
image_aug = self.cfg.image_aug if split == 'train' else False
|
243 |
+
noise_fn, noise_snr = f"{self.cfg.noise_wav}/{split}.tsv" if self.cfg.noise_wav is not None else None, eval(self.cfg.noise_snr)
|
244 |
+
noise_num = self.cfg.noise_num #
|
245 |
+
self.datasets[split] = AVHubertDataset(
|
246 |
+
manifest,
|
247 |
+
sample_rate=self.cfg.sample_rate,
|
248 |
+
label_paths=paths,
|
249 |
+
label_rates=self.cfg.label_rate,
|
250 |
+
pad_list=pad_list,
|
251 |
+
eos_list=eos_list,
|
252 |
+
label_processors=procs,
|
253 |
+
max_keep_sample_size=self.cfg.max_sample_size,
|
254 |
+
min_keep_sample_size=self.cfg.min_sample_size,
|
255 |
+
max_sample_size=self.cfg.max_trim_sample_size,
|
256 |
+
pad_audio=self.cfg.pad_audio,
|
257 |
+
normalize=self.cfg.normalize,
|
258 |
+
store_labels=False,
|
259 |
+
random_crop=self.cfg.random_crop,
|
260 |
+
single_target=self.cfg.single_target,
|
261 |
+
stack_order_audio=self.cfg.stack_order_audio,
|
262 |
+
skip_verify=self.cfg.skip_verify,
|
263 |
+
image_mean=self.cfg.image_mean,
|
264 |
+
image_std=self.cfg.image_std,
|
265 |
+
image_crop_size=self.cfg.image_crop_size,
|
266 |
+
image_aug=image_aug,
|
267 |
+
modalities=self.cfg.modalities,
|
268 |
+
is_s2s=self.cfg.is_s2s,
|
269 |
+
noise_fn=noise_fn,
|
270 |
+
noise_prob=self.cfg.noise_prob,
|
271 |
+
noise_snr=noise_snr,
|
272 |
+
noise_num=noise_num
|
273 |
+
)
|
274 |
+
|
275 |
+
def max_positions(self) -> Tuple[int, int]:
|
276 |
+
return (sys.maxsize, sys.maxsize)
|
277 |
+
|
278 |
+
def filter_indices_by_size(
|
279 |
+
self, indices: np.array, *args, **kwargs
|
280 |
+
) -> np.array:
|
281 |
+
return indices
|
282 |
+
|
283 |
+
def build_generator(
|
284 |
+
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
|
285 |
+
):
|
286 |
+
"""
|
287 |
+
Build a :class:`~fairseq.SequenceGenerator` instance for this
|
288 |
+
task.
|
289 |
+
Args:
|
290 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models
|
291 |
+
args (fairseq.dataclass.configs.GenerationConfig):
|
292 |
+
configuration object (dataclass) for generation
|
293 |
+
extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
|
294 |
+
through to SequenceGenerator
|
295 |
+
prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
|
296 |
+
If provided, this function constrains the beam search to
|
297 |
+
allowed tokens only at each step. The provided function
|
298 |
+
should take 2 arguments: the batch ID (`batch_id: int`)
|
299 |
+
and a unidimensional tensor of token ids (`inputs_ids:
|
300 |
+
torch.Tensor`). It has to return a `List[int]` with the
|
301 |
+
allowed tokens for the next generation step conditioned
|
302 |
+
on the previously generated tokens (`inputs_ids`) and
|
303 |
+
the batch ID (`batch_id`). This argument is useful for
|
304 |
+
constrained generation conditioned on the prefix, as
|
305 |
+
described in "Autoregressive Entity Retrieval"
|
306 |
+
(https://arxiv.org/abs/2010.00904) and
|
307 |
+
https://github.com/facebookresearch/GENRE.
|
308 |
+
"""
|
309 |
+
if getattr(args, "score_reference", False):
|
310 |
+
from fairseq.sequence_scorer import SequenceScorer
|
311 |
+
|
312 |
+
return SequenceScorer(
|
313 |
+
self.target_dictionary,
|
314 |
+
compute_alignment=getattr(args, "print_alignment", False),
|
315 |
+
)
|
316 |
+
|
317 |
+
# Choose search strategy. Defaults to Beam Search.
|
318 |
+
sampling = getattr(args, "sampling", False)
|
319 |
+
sampling_topk = getattr(args, "sampling_topk", -1)
|
320 |
+
sampling_topp = getattr(args, "sampling_topp", -1.0)
|
321 |
+
diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
|
322 |
+
diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
|
323 |
+
match_source_len = getattr(args, "match_source_len", False)
|
324 |
+
diversity_rate = getattr(args, "diversity_rate", -1)
|
325 |
+
constrained = getattr(args, "constraints", False)
|
326 |
+
if prefix_allowed_tokens_fn is None:
|
327 |
+
prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
|
328 |
+
if (
|
329 |
+
sum(
|
330 |
+
int(cond)
|
331 |
+
for cond in [
|
332 |
+
sampling,
|
333 |
+
diverse_beam_groups > 0,
|
334 |
+
match_source_len,
|
335 |
+
diversity_rate > 0,
|
336 |
+
]
|
337 |
+
)
|
338 |
+
> 1
|
339 |
+
):
|
340 |
+
raise ValueError("Provided Search parameters are mutually exclusive.")
|
341 |
+
assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
|
342 |
+
assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"
|
343 |
+
|
344 |
+
if sampling:
|
345 |
+
search_strategy = search.Sampling(
|
346 |
+
self.target_dictionary, sampling_topk, sampling_topp
|
347 |
+
)
|
348 |
+
elif diverse_beam_groups > 0:
|
349 |
+
search_strategy = search.DiverseBeamSearch(
|
350 |
+
self.target_dictionary, diverse_beam_groups, diverse_beam_strength
|
351 |
+
)
|
352 |
+
elif match_source_len:
|
353 |
+
# this is useful for tagging applications where the output
|
354 |
+
# length should match the input length, so we hardcode the
|
355 |
+
# length constraints for simplicity
|
356 |
+
search_strategy = search.LengthConstrainedBeamSearch(
|
357 |
+
self.target_dictionary,
|
358 |
+
min_len_a=1,
|
359 |
+
min_len_b=0,
|
360 |
+
max_len_a=1,
|
361 |
+
max_len_b=0,
|
362 |
+
)
|
363 |
+
elif diversity_rate > -1:
|
364 |
+
search_strategy = search.DiverseSiblingsSearch(
|
365 |
+
self.target_dictionary, diversity_rate
|
366 |
+
)
|
367 |
+
elif constrained:
|
368 |
+
search_strategy = search.LexicallyConstrainedBeamSearch(
|
369 |
+
self.target_dictionary, args.constraints
|
370 |
+
)
|
371 |
+
elif prefix_allowed_tokens_fn:
|
372 |
+
search_strategy = search.PrefixConstrainedBeamSearch(
|
373 |
+
self.target_dictionary, prefix_allowed_tokens_fn
|
374 |
+
)
|
375 |
+
else:
|
376 |
+
search_strategy = search.BeamSearch(self.target_dictionary)
|
377 |
+
|
378 |
+
extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
|
379 |
+
if seq_gen_cls is None:
|
380 |
+
if getattr(args, "print_alignment", False):
|
381 |
+
seq_gen_cls = SequenceGeneratorWithAlignment
|
382 |
+
extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
|
383 |
+
else:
|
384 |
+
seq_gen_cls = SequenceGenerator
|
385 |
+
|
386 |
+
return seq_gen_cls(
|
387 |
+
models,
|
388 |
+
self.target_dictionary,
|
389 |
+
beam_size=getattr(args, "beam", 5),
|
390 |
+
max_len_a=getattr(args, "max_len_a", 0),
|
391 |
+
max_len_b=getattr(args, "max_len_b", 200),
|
392 |
+
min_len=getattr(args, "min_len", 1),
|
393 |
+
normalize_scores=(not getattr(args, "unnormalized", False)),
|
394 |
+
len_penalty=getattr(args, "lenpen", 1),
|
395 |
+
unk_penalty=getattr(args, "unkpen", 0),
|
396 |
+
temperature=getattr(args, "temperature", 1.0),
|
397 |
+
match_source_len=getattr(args, "match_source_len", False),
|
398 |
+
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
|
399 |
+
search_strategy=search_strategy,
|
400 |
+
**extra_gen_cls_kwargs,
|
401 |
+
)
|
slam_llm/models/avhubert/infer_s2s.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import ast
|
8 |
+
from itertools import chain
|
9 |
+
import logging
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
import json
|
14 |
+
import hashlib
|
15 |
+
import editdistance
|
16 |
+
from argparse import Namespace
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from fairseq import checkpoint_utils, options, tasks, utils, distributed_utils
|
21 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
22 |
+
from fairseq.logging import progress_bar
|
23 |
+
from fairseq.logging.meters import StopwatchMeter, TimeMeter
|
24 |
+
from fairseq.models import FairseqLanguageModel
|
25 |
+
from omegaconf import DictConfig
|
26 |
+
|
27 |
+
from pathlib import Path
|
28 |
+
import hydra
|
29 |
+
from hydra.core.config_store import ConfigStore
|
30 |
+
from fairseq.dataclass.configs import (
|
31 |
+
CheckpointConfig,
|
32 |
+
CommonConfig,
|
33 |
+
CommonEvalConfig,
|
34 |
+
DatasetConfig,
|
35 |
+
DistributedTrainingConfig,
|
36 |
+
GenerationConfig,
|
37 |
+
FairseqDataclass,
|
38 |
+
)
|
39 |
+
from dataclasses import dataclass, field, is_dataclass
|
40 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
41 |
+
from omegaconf import OmegaConf
|
42 |
+
|
43 |
+
logging.root.setLevel(logging.INFO)
|
44 |
+
logging.basicConfig(level=logging.INFO)
|
45 |
+
logger = logging.getLogger(__name__)
|
46 |
+
|
47 |
+
config_path = Path(__file__).resolve().parent / "conf"
|
48 |
+
|
49 |
+
@dataclass
|
50 |
+
class OverrideConfig(FairseqDataclass):
|
51 |
+
noise_wav: Optional[str] = field(default=None, metadata={'help': 'noise wav file'})
|
52 |
+
noise_prob: float = field(default=0, metadata={'help': 'noise probability'})
|
53 |
+
noise_snr: float = field(default=0, metadata={'help': 'noise SNR in audio'})
|
54 |
+
modalities: List[str] = field(default_factory=lambda: [""], metadata={'help': 'which modality to use'})
|
55 |
+
data: Optional[str] = field(default=None, metadata={'help': 'path to test data directory'})
|
56 |
+
label_dir: Optional[str] = field(default=None, metadata={'help': 'path to test label directory'})
|
57 |
+
|
58 |
+
@dataclass
|
59 |
+
class InferConfig(FairseqDataclass):
|
60 |
+
task: Any = None
|
61 |
+
generation: GenerationConfig = GenerationConfig()
|
62 |
+
common: CommonConfig = CommonConfig()
|
63 |
+
common_eval: CommonEvalConfig = CommonEvalConfig()
|
64 |
+
checkpoint: CheckpointConfig = CheckpointConfig()
|
65 |
+
distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
|
66 |
+
dataset: DatasetConfig = DatasetConfig()
|
67 |
+
override: OverrideConfig = OverrideConfig()
|
68 |
+
is_ax: bool = field(
|
69 |
+
default=False,
|
70 |
+
metadata={
|
71 |
+
"help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
|
72 |
+
},
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
def main(cfg: DictConfig):
|
77 |
+
|
78 |
+
if isinstance(cfg, Namespace):
|
79 |
+
cfg = convert_namespace_to_omegaconf(cfg)
|
80 |
+
|
81 |
+
assert cfg.common_eval.path is not None, "--path required for recognition!"
|
82 |
+
assert (
|
83 |
+
not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
|
84 |
+
), "--sampling requires --nbest to be equal to --beam"
|
85 |
+
|
86 |
+
if cfg.common_eval.results_path is not None:
|
87 |
+
os.makedirs(cfg.common_eval.results_path, exist_ok=True)
|
88 |
+
output_path = os.path.join(cfg.common_eval.results_path, "decode.log")
|
89 |
+
with open(output_path, "w", buffering=1, encoding="utf-8") as h:
|
90 |
+
return _main(cfg, h)
|
91 |
+
return _main(cfg, sys.stdout)
|
92 |
+
|
93 |
+
|
94 |
+
def get_symbols_to_strip_from_output(generator):
|
95 |
+
if hasattr(generator, "symbols_to_strip_from_output"):
|
96 |
+
return generator.symbols_to_strip_from_output
|
97 |
+
else:
|
98 |
+
return {generator.eos, generator.pad}
|
99 |
+
|
100 |
+
def _main(cfg, output_file):
|
101 |
+
logging.basicConfig(
|
102 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
103 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
104 |
+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
105 |
+
stream=output_file,
|
106 |
+
)
|
107 |
+
logger = logging.getLogger("hybrid.speech_recognize")
|
108 |
+
if output_file is not sys.stdout: # also print to stdout
|
109 |
+
logger.addHandler(logging.StreamHandler(sys.stdout))
|
110 |
+
|
111 |
+
utils.import_user_module(cfg.common)
|
112 |
+
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([cfg.common_eval.path])
|
113 |
+
models = [model.eval().cuda() for model in models] #!!
|
114 |
+
saved_cfg.task.modalities = cfg.override.modalities
|
115 |
+
task = tasks.setup_task(saved_cfg.task)
|
116 |
+
|
117 |
+
task.build_tokenizer(saved_cfg.tokenizer)
|
118 |
+
task.build_bpe(saved_cfg.bpe)
|
119 |
+
|
120 |
+
logger.info(cfg)
|
121 |
+
|
122 |
+
# Fix seed for stochastic decoding
|
123 |
+
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
|
124 |
+
np.random.seed(cfg.common.seed)
|
125 |
+
utils.set_torch_seed(cfg.common.seed)
|
126 |
+
|
127 |
+
use_cuda = torch.cuda.is_available()
|
128 |
+
|
129 |
+
# Set dictionary
|
130 |
+
dictionary = task.target_dictionary
|
131 |
+
|
132 |
+
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
|
133 |
+
task.cfg.noise_prob = cfg.override.noise_prob
|
134 |
+
task.cfg.noise_snr = cfg.override.noise_snr
|
135 |
+
task.cfg.noise_wav = cfg.override.noise_wav
|
136 |
+
if cfg.override.data is not None:
|
137 |
+
task.cfg.data = cfg.override.data
|
138 |
+
if cfg.override.label_dir is not None:
|
139 |
+
task.cfg.label_dir = cfg.override.label_dir
|
140 |
+
task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
|
141 |
+
|
142 |
+
lms = [None]
|
143 |
+
|
144 |
+
# Optimize ensemble for generation
|
145 |
+
for model in chain(models, lms):
|
146 |
+
if model is None:
|
147 |
+
continue
|
148 |
+
if cfg.common.fp16:
|
149 |
+
model.half()
|
150 |
+
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
|
151 |
+
model.cuda()
|
152 |
+
model.prepare_for_inference_(cfg)
|
153 |
+
|
154 |
+
# Load dataset (possibly sharded)
|
155 |
+
itr = task.get_batch_iterator(
|
156 |
+
dataset=task.dataset(cfg.dataset.gen_subset),
|
157 |
+
max_tokens=cfg.dataset.max_tokens,
|
158 |
+
max_sentences=cfg.dataset.batch_size,
|
159 |
+
max_positions=utils.resolve_max_positions(
|
160 |
+
task.max_positions(), *[m.max_positions() for m in models]
|
161 |
+
),
|
162 |
+
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
|
163 |
+
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
|
164 |
+
seed=cfg.common.seed,
|
165 |
+
num_shards=cfg.distributed_training.distributed_world_size,
|
166 |
+
shard_id=cfg.distributed_training.distributed_rank,
|
167 |
+
num_workers=cfg.dataset.num_workers,
|
168 |
+
data_buffer_size=cfg.dataset.data_buffer_size,
|
169 |
+
).next_epoch_itr(shuffle=False)
|
170 |
+
progress = progress_bar.progress_bar(
|
171 |
+
itr,
|
172 |
+
log_format=cfg.common.log_format,
|
173 |
+
log_interval=cfg.common.log_interval,
|
174 |
+
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
|
175 |
+
)
|
176 |
+
|
177 |
+
# Initialize generator
|
178 |
+
if cfg.generation.match_source_len:
|
179 |
+
logger.warning(
|
180 |
+
"The option match_source_len is not applicable to speech recognition. Ignoring it."
|
181 |
+
)
|
182 |
+
gen_timer = StopwatchMeter()
|
183 |
+
extra_gen_cls_kwargs = {
|
184 |
+
"lm_model": lms[0],
|
185 |
+
"lm_weight": cfg.generation.lm_weight,
|
186 |
+
}
|
187 |
+
cfg.generation.score_reference = False #
|
188 |
+
save_attention_plot = cfg.generation.print_alignment is not None
|
189 |
+
cfg.generation.print_alignment = None #
|
190 |
+
generator = task.build_generator(
|
191 |
+
models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
|
192 |
+
)
|
193 |
+
|
194 |
+
def decode_fn(x):
|
195 |
+
symbols_ignore = get_symbols_to_strip_from_output(generator)
|
196 |
+
symbols_ignore.add(dictionary.pad())
|
197 |
+
if hasattr(task.datasets[cfg.dataset.gen_subset].label_processors[0], 'decode'):
|
198 |
+
return task.datasets[cfg.dataset.gen_subset].label_processors[0].decode(x, symbols_ignore)
|
199 |
+
chars = dictionary.string(x, extra_symbols_to_ignore=symbols_ignore)
|
200 |
+
words = " ".join("".join(chars.split()).replace('|', ' ').split())
|
201 |
+
return words
|
202 |
+
|
203 |
+
num_sentences = 0
|
204 |
+
has_target = True
|
205 |
+
wps_meter = TimeMeter()
|
206 |
+
result_dict = {'utt_id': [], 'ref': [], 'hypo': []}
|
207 |
+
for sample in progress:
|
208 |
+
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
209 |
+
if "net_input" not in sample:
|
210 |
+
continue
|
211 |
+
|
212 |
+
prefix_tokens = None
|
213 |
+
if cfg.generation.prefix_size > 0:
|
214 |
+
prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
|
215 |
+
|
216 |
+
constraints = None
|
217 |
+
if "constraints" in sample:
|
218 |
+
constraints = sample["constraints"]
|
219 |
+
|
220 |
+
gen_timer.start()
|
221 |
+
hypos = task.inference_step(
|
222 |
+
generator,
|
223 |
+
models,
|
224 |
+
sample,
|
225 |
+
prefix_tokens=prefix_tokens,
|
226 |
+
constraints=constraints,
|
227 |
+
)
|
228 |
+
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
|
229 |
+
gen_timer.stop(num_generated_tokens)
|
230 |
+
|
231 |
+
for i in range(len(sample["id"])):
|
232 |
+
result_dict['utt_id'].append(sample['utt_id'][i])
|
233 |
+
ref_sent = decode_fn(sample['target'][i].int().cpu())
|
234 |
+
result_dict['ref'].append(ref_sent)
|
235 |
+
best_hypo = hypos[i][0]['tokens'].int().cpu()
|
236 |
+
hypo_str = decode_fn(best_hypo)
|
237 |
+
result_dict['hypo'].append(hypo_str)
|
238 |
+
logger.info(f"\nREF:{ref_sent}\nHYP:{hypo_str}\n")
|
239 |
+
wps_meter.update(num_generated_tokens)
|
240 |
+
progress.log({"wps": round(wps_meter.avg)})
|
241 |
+
num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
|
242 |
+
|
243 |
+
logger.info("NOTE: hypothesis and token scores are output in base 2")
|
244 |
+
logger.info("Recognized {:,} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
|
245 |
+
num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
|
246 |
+
|
247 |
+
yaml_str = OmegaConf.to_yaml(cfg.generation)
|
248 |
+
fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16)
|
249 |
+
fid = fid % 1000000
|
250 |
+
result_fn = f"{cfg.common_eval.results_path}/hypo-{fid}.json"
|
251 |
+
json.dump(result_dict, open(result_fn, 'w'), indent=4)
|
252 |
+
n_err, n_total = 0, 0
|
253 |
+
assert len(result_dict['hypo']) == len(result_dict['ref'])
|
254 |
+
for hypo, ref in zip(result_dict['hypo'], result_dict['ref']):
|
255 |
+
hypo, ref = hypo.strip().split(), ref.strip().split()
|
256 |
+
n_err += editdistance.eval(hypo, ref)
|
257 |
+
n_total += len(ref)
|
258 |
+
wer = 100 * n_err / n_total
|
259 |
+
wer_fn = f"{cfg.common_eval.results_path}/wer.{fid}"
|
260 |
+
with open(wer_fn, "w") as fo:
|
261 |
+
fo.write(f"WER: {wer}\n")
|
262 |
+
fo.write(f"err / num_ref_words = {n_err} / {n_total}\n\n")
|
263 |
+
fo.write(f"{yaml_str}")
|
264 |
+
logger.info(f"WER: {wer}%")
|
265 |
+
return
|
266 |
+
|
267 |
+
|
268 |
+
@hydra.main(config_path=config_path, config_name="infer")
|
269 |
+
def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]:
|
270 |
+
container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
|
271 |
+
cfg = OmegaConf.create(container)
|
272 |
+
OmegaConf.set_struct(cfg, True)
|
273 |
+
|
274 |
+
if cfg.common.reset_logging:
|
275 |
+
reset_logging()
|
276 |
+
|
277 |
+
wer = float("inf")
|
278 |
+
|
279 |
+
try:
|
280 |
+
if cfg.common.profile:
|
281 |
+
with torch.cuda.profiler.profile():
|
282 |
+
with torch.autograd.profiler.emit_nvtx():
|
283 |
+
distributed_utils.call_main(cfg, main)
|
284 |
+
else:
|
285 |
+
distributed_utils.call_main(cfg, main)
|
286 |
+
|
287 |
+
except BaseException as e: # pylint: disable=broad-except
|
288 |
+
if not cfg.common.suppress_crashes:
|
289 |
+
raise
|
290 |
+
else:
|
291 |
+
logger.error("Crashed! %s", str(e))
|
292 |
+
return
|
293 |
+
|
294 |
+
|
295 |
+
def cli_main() -> None:
|
296 |
+
try:
|
297 |
+
from hydra._internal.utils import (
|
298 |
+
get_args,
|
299 |
+
) # pylint: disable=import-outside-toplevel
|
300 |
+
|
301 |
+
cfg_name = get_args().config_name or "infer"
|
302 |
+
except ImportError:
|
303 |
+
logger.warning("Failed to get config name from hydra args")
|
304 |
+
cfg_name = "infer"
|
305 |
+
|
306 |
+
cs = ConfigStore.instance()
|
307 |
+
cs.store(name=cfg_name, node=InferConfig)
|
308 |
+
|
309 |
+
for k in InferConfig.__dataclass_fields__:
|
310 |
+
if is_dataclass(InferConfig.__dataclass_fields__[k].type):
|
311 |
+
v = InferConfig.__dataclass_fields__[k].default
|
312 |
+
cs.store(name=k, node=v)
|
313 |
+
|
314 |
+
hydra_main() # pylint: disable=no-value-for-parameter
|
315 |
+
|
316 |
+
|
317 |
+
if __name__ == "__main__":
|
318 |
+
cli_main()
|
slam_llm/models/avhubert/resnet.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import math
|
9 |
+
import torch.nn as nn
|
10 |
+
import pdb
|
11 |
+
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
16 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
17 |
+
padding=1, bias=False)
|
18 |
+
|
19 |
+
|
20 |
+
def downsample_basic_block( inplanes, outplanes, stride ):
|
21 |
+
return nn.Sequential(
|
22 |
+
nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False),
|
23 |
+
nn.BatchNorm2d(outplanes),
|
24 |
+
)
|
25 |
+
|
26 |
+
def downsample_basic_block_v2( inplanes, outplanes, stride ):
|
27 |
+
return nn.Sequential(
|
28 |
+
nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False),
|
29 |
+
nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False),
|
30 |
+
nn.BatchNorm2d(outplanes),
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
class BasicBlock(nn.Module):
|
36 |
+
expansion = 1
|
37 |
+
|
38 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type = 'relu' ):
|
39 |
+
super(BasicBlock, self).__init__()
|
40 |
+
|
41 |
+
assert relu_type in ['relu','prelu']
|
42 |
+
|
43 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
44 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
45 |
+
|
46 |
+
if relu_type == 'relu':
|
47 |
+
self.relu1 = nn.ReLU(inplace=True)
|
48 |
+
self.relu2 = nn.ReLU(inplace=True)
|
49 |
+
elif relu_type == 'prelu':
|
50 |
+
self.relu1 = nn.PReLU(num_parameters=planes)
|
51 |
+
self.relu2 = nn.PReLU(num_parameters=planes)
|
52 |
+
else:
|
53 |
+
raise Exception('relu type not implemented')
|
54 |
+
|
55 |
+
self.conv2 = conv3x3(planes, planes)
|
56 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
57 |
+
|
58 |
+
self.downsample = downsample
|
59 |
+
self.stride = stride
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
residual = x
|
63 |
+
out = self.conv1(x)
|
64 |
+
out = self.bn1(out)
|
65 |
+
out = self.relu1(out)
|
66 |
+
out = self.conv2(out)
|
67 |
+
out = self.bn2(out)
|
68 |
+
if self.downsample is not None:
|
69 |
+
residual = self.downsample(x)
|
70 |
+
|
71 |
+
out += residual
|
72 |
+
out = self.relu2(out)
|
73 |
+
|
74 |
+
return out
|
75 |
+
|
76 |
+
|
77 |
+
class ResNet(nn.Module):
|
78 |
+
|
79 |
+
def __init__(self, block, layers, num_classes=1000, relu_type = 'relu', gamma_zero = False, avg_pool_downsample = False):
|
80 |
+
self.inplanes = 64
|
81 |
+
self.relu_type = relu_type
|
82 |
+
self.gamma_zero = gamma_zero
|
83 |
+
self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block
|
84 |
+
|
85 |
+
super(ResNet, self).__init__()
|
86 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
87 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
88 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
89 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
90 |
+
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
91 |
+
|
92 |
+
for m in self.modules():
|
93 |
+
if isinstance(m, nn.Conv2d):
|
94 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
95 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
96 |
+
elif isinstance(m, nn.BatchNorm2d):
|
97 |
+
m.weight.data.fill_(1)
|
98 |
+
m.bias.data.zero_()
|
99 |
+
|
100 |
+
if self.gamma_zero:
|
101 |
+
for m in self.modules():
|
102 |
+
if isinstance(m, BasicBlock ):
|
103 |
+
m.bn2.weight.data.zero_()
|
104 |
+
|
105 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
106 |
+
|
107 |
+
|
108 |
+
downsample = None
|
109 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
110 |
+
downsample = self.downsample_block( inplanes = self.inplanes,
|
111 |
+
outplanes = planes * block.expansion,
|
112 |
+
stride = stride )
|
113 |
+
|
114 |
+
layers = []
|
115 |
+
layers.append(block(self.inplanes, planes, stride, downsample, relu_type = self.relu_type))
|
116 |
+
self.inplanes = planes * block.expansion
|
117 |
+
for i in range(1, blocks):
|
118 |
+
layers.append(block(self.inplanes, planes, relu_type = self.relu_type))
|
119 |
+
|
120 |
+
return nn.Sequential(*layers)
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
x = self.layer1(x)
|
124 |
+
x = self.layer2(x)
|
125 |
+
x = self.layer3(x)
|
126 |
+
x = self.layer4(x)
|
127 |
+
x = self.avgpool(x)
|
128 |
+
x = x.view(x.size(0), -1)
|
129 |
+
return x
|
130 |
+
|
131 |
+
class ResEncoder(nn.Module):
|
132 |
+
def __init__(self, relu_type, weights):
|
133 |
+
super(ResEncoder, self).__init__()
|
134 |
+
self.frontend_nout = 64
|
135 |
+
self.backend_out = 512
|
136 |
+
frontend_relu = nn.PReLU(num_parameters=self.frontend_nout) if relu_type == 'prelu' else nn.ReLU()
|
137 |
+
self.frontend3D = nn.Sequential(
|
138 |
+
nn.Conv3d(1, self.frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False),
|
139 |
+
nn.BatchNorm3d(self.frontend_nout),
|
140 |
+
frontend_relu,
|
141 |
+
nn.MaxPool3d( kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))
|
142 |
+
self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
|
143 |
+
if weights is not None:
|
144 |
+
logger.info(f"Load {weights} for resnet")
|
145 |
+
std = torch.load(weights, map_location=torch.device('cpu'))['model_state_dict']
|
146 |
+
frontend_std, trunk_std = OrderedDict(), OrderedDict()
|
147 |
+
for key, val in std.items():
|
148 |
+
new_key = '.'.join(key.split('.')[1:])
|
149 |
+
if 'frontend3D' in key:
|
150 |
+
frontend_std[new_key] = val
|
151 |
+
if 'trunk' in key:
|
152 |
+
trunk_std[new_key] = val
|
153 |
+
self.frontend3D.load_state_dict(frontend_std)
|
154 |
+
self.trunk.load_state_dict(trunk_std)
|
155 |
+
|
156 |
+
def forward(self, x):
|
157 |
+
B, C, T, H, W = x.size()
|
158 |
+
x = self.frontend3D(x)
|
159 |
+
Tnew = x.shape[2]
|
160 |
+
x = self.threeD_to_2D_tensor(x)
|
161 |
+
x = self.trunk(x)
|
162 |
+
x = x.view(B, Tnew, x.size(1))
|
163 |
+
x = x.transpose(1, 2).contiguous()
|
164 |
+
return x
|
165 |
+
|
166 |
+
def threeD_to_2D_tensor(self, x):
|
167 |
+
n_batch, n_channels, s_time, sx, sy = x.shape
|
168 |
+
x = x.transpose(1, 2).contiguous()
|
169 |
+
return x.reshape(n_batch*s_time, n_channels, sx, sy)
|
slam_llm/models/avhubert/sequence_generator.py
ADDED
@@ -0,0 +1,985 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
from typing import Dict, List, Optional
|
9 |
+
import sys
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from fairseq import search, utils
|
14 |
+
from fairseq.data import data_utils
|
15 |
+
from fairseq.models import FairseqIncrementalDecoder
|
16 |
+
from torch import Tensor
|
17 |
+
from fairseq.ngram_repeat_block import NGramRepeatBlock
|
18 |
+
|
19 |
+
|
20 |
+
class SequenceGenerator(nn.Module):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
models,
|
24 |
+
tgt_dict,
|
25 |
+
beam_size=1,
|
26 |
+
max_len_a=0,
|
27 |
+
max_len_b=200,
|
28 |
+
max_len=0,
|
29 |
+
min_len=1,
|
30 |
+
normalize_scores=True,
|
31 |
+
len_penalty=1.0,
|
32 |
+
unk_penalty=0.0,
|
33 |
+
temperature=1.0,
|
34 |
+
match_source_len=False,
|
35 |
+
no_repeat_ngram_size=0,
|
36 |
+
search_strategy=None,
|
37 |
+
eos=None,
|
38 |
+
symbols_to_strip_from_output=None,
|
39 |
+
lm_model=None,
|
40 |
+
lm_weight=1.0,
|
41 |
+
):
|
42 |
+
"""Generates translations of a given source sentence.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models,
|
46 |
+
currently support fairseq.models.TransformerModel for scripting
|
47 |
+
beam_size (int, optional): beam width (default: 1)
|
48 |
+
max_len_a/b (int, optional): generate sequences of maximum length
|
49 |
+
ax + b, where x is the source length
|
50 |
+
max_len (int, optional): the maximum length of the generated output
|
51 |
+
(not including end-of-sentence)
|
52 |
+
min_len (int, optional): the minimum length of the generated output
|
53 |
+
(not including end-of-sentence)
|
54 |
+
normalize_scores (bool, optional): normalize scores by the length
|
55 |
+
of the output (default: True)
|
56 |
+
len_penalty (float, optional): length penalty, where <1.0 favors
|
57 |
+
shorter, >1.0 favors longer sentences (default: 1.0)
|
58 |
+
unk_penalty (float, optional): unknown word penalty, where <0
|
59 |
+
produces more unks, >0 produces fewer (default: 0.0)
|
60 |
+
temperature (float, optional): temperature, where values
|
61 |
+
>1.0 produce more uniform samples and values <1.0 produce
|
62 |
+
sharper samples (default: 1.0)
|
63 |
+
match_source_len (bool, optional): outputs should match the source
|
64 |
+
length (default: False)
|
65 |
+
"""
|
66 |
+
super().__init__()
|
67 |
+
if isinstance(models, EnsembleModel):
|
68 |
+
self.model = models
|
69 |
+
else:
|
70 |
+
self.model = EnsembleModel(models)
|
71 |
+
self.tgt_dict = tgt_dict
|
72 |
+
self.pad = tgt_dict.pad()
|
73 |
+
self.unk = tgt_dict.unk()
|
74 |
+
self.eos = tgt_dict.eos() if eos is None else eos
|
75 |
+
self.symbols_to_strip_from_output = (
|
76 |
+
symbols_to_strip_from_output.union({self.eos})
|
77 |
+
if symbols_to_strip_from_output is not None
|
78 |
+
else {self.eos}
|
79 |
+
)
|
80 |
+
self.vocab_size = len(tgt_dict)
|
81 |
+
self.beam_size = beam_size
|
82 |
+
# the max beam size is the dictionary size - 1, since we never select pad
|
83 |
+
self.beam_size = min(beam_size, self.vocab_size - 1)
|
84 |
+
self.max_len_a = max_len_a
|
85 |
+
self.max_len_b = max_len_b
|
86 |
+
self.min_len = min_len
|
87 |
+
self.max_len = max_len or self.model.max_decoder_positions()
|
88 |
+
|
89 |
+
self.normalize_scores = normalize_scores
|
90 |
+
self.len_penalty = len_penalty
|
91 |
+
self.unk_penalty = unk_penalty
|
92 |
+
self.temperature = temperature
|
93 |
+
self.match_source_len = match_source_len
|
94 |
+
|
95 |
+
if no_repeat_ngram_size > 0:
|
96 |
+
self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
|
97 |
+
else:
|
98 |
+
self.repeat_ngram_blocker = None
|
99 |
+
|
100 |
+
assert temperature > 0, "--temperature must be greater than 0"
|
101 |
+
|
102 |
+
self.search = (
|
103 |
+
search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
|
104 |
+
)
|
105 |
+
# We only need to set src_lengths in LengthConstrainedBeamSearch.
|
106 |
+
# As a module attribute, setting it would break in multithread
|
107 |
+
# settings when the model is shared.
|
108 |
+
self.should_set_src_lengths = (
|
109 |
+
hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
|
110 |
+
)
|
111 |
+
|
112 |
+
self.model.eval()
|
113 |
+
|
114 |
+
self.lm_model = lm_model
|
115 |
+
self.lm_weight = lm_weight
|
116 |
+
if self.lm_model is not None:
|
117 |
+
self.lm_model.eval()
|
118 |
+
|
119 |
+
def cuda(self):
|
120 |
+
self.model.cuda()
|
121 |
+
return self
|
122 |
+
|
123 |
+
@torch.no_grad()
|
124 |
+
def forward(
|
125 |
+
self,
|
126 |
+
sample: Dict[str, Dict[str, Tensor]],
|
127 |
+
prefix_tokens: Optional[Tensor] = None,
|
128 |
+
bos_token: Optional[int] = None,
|
129 |
+
):
|
130 |
+
"""Generate a batch of translations.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
sample (dict): batch
|
134 |
+
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
135 |
+
with these tokens
|
136 |
+
bos_token (int, optional): beginning of sentence token
|
137 |
+
(default: self.eos)
|
138 |
+
"""
|
139 |
+
return self._generate(sample, prefix_tokens, bos_token=bos_token)
|
140 |
+
|
141 |
+
# TODO(myleott): unused, deprecate after pytorch-translate migration
|
142 |
+
def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
|
143 |
+
"""Iterate over a batched dataset and yield individual translations.
|
144 |
+
Args:
|
145 |
+
cuda (bool, optional): use GPU for generation
|
146 |
+
timer (StopwatchMeter, optional): time generations
|
147 |
+
"""
|
148 |
+
for sample in data_itr:
|
149 |
+
s = utils.move_to_cuda(sample) if cuda else sample
|
150 |
+
if "net_input" not in s:
|
151 |
+
continue
|
152 |
+
input = s["net_input"]
|
153 |
+
# model.forward normally channels prev_output_tokens into the decoder
|
154 |
+
# separately, but SequenceGenerator directly calls model.encoder
|
155 |
+
encoder_input = {
|
156 |
+
k: v for k, v in input.items() if k != "prev_output_tokens"
|
157 |
+
}
|
158 |
+
if timer is not None:
|
159 |
+
timer.start()
|
160 |
+
with torch.no_grad():
|
161 |
+
hypos = self.generate(encoder_input)
|
162 |
+
if timer is not None:
|
163 |
+
timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
|
164 |
+
for i, id in enumerate(s["id"].data):
|
165 |
+
# remove padding
|
166 |
+
src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
|
167 |
+
ref = (
|
168 |
+
utils.strip_pad(s["target"].data[i, :], self.pad)
|
169 |
+
if s["target"] is not None
|
170 |
+
else None
|
171 |
+
)
|
172 |
+
yield id, src, ref, hypos[i]
|
173 |
+
|
174 |
+
@torch.no_grad()
|
175 |
+
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
|
176 |
+
"""Generate translations. Match the api of other fairseq generators.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models
|
180 |
+
sample (dict): batch
|
181 |
+
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
182 |
+
with these tokens
|
183 |
+
constraints (torch.LongTensor, optional): force decoder to include
|
184 |
+
the list of constraints
|
185 |
+
bos_token (int, optional): beginning of sentence token
|
186 |
+
(default: self.eos)
|
187 |
+
"""
|
188 |
+
return self._generate(sample, **kwargs)
|
189 |
+
|
190 |
+
def _generate(
|
191 |
+
self,
|
192 |
+
sample: Dict[str, Dict[str, Tensor]],
|
193 |
+
prefix_tokens: Optional[Tensor] = None,
|
194 |
+
constraints: Optional[Tensor] = None,
|
195 |
+
bos_token: Optional[int] = None,
|
196 |
+
):
|
197 |
+
incremental_states = torch.jit.annotate(
|
198 |
+
List[Dict[str, Dict[str, Optional[Tensor]]]],
|
199 |
+
[
|
200 |
+
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
|
201 |
+
for i in range(self.model.models_size)
|
202 |
+
],
|
203 |
+
)
|
204 |
+
net_input = sample["net_input"]
|
205 |
+
|
206 |
+
if "src_tokens" in net_input:
|
207 |
+
src_tokens = net_input["src_tokens"]
|
208 |
+
# length of the source text being the character length except EndOfSentence and pad
|
209 |
+
src_lengths = (
|
210 |
+
(src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
|
211 |
+
)
|
212 |
+
elif "source" in net_input:
|
213 |
+
src_tokens = net_input["source"]
|
214 |
+
src_lengths = (
|
215 |
+
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
|
216 |
+
if net_input["padding_mask"] is not None
|
217 |
+
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
218 |
+
)
|
219 |
+
elif "features" in net_input:
|
220 |
+
src_tokens = net_input["features"]
|
221 |
+
src_lengths = (
|
222 |
+
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
|
223 |
+
if net_input["padding_mask"] is not None
|
224 |
+
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
225 |
+
)
|
226 |
+
else:
|
227 |
+
raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
|
228 |
+
|
229 |
+
# bsz: total number of sentences in beam
|
230 |
+
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
|
231 |
+
if src_tokens['audio'] is not None:
|
232 |
+
bsz, src_len = src_tokens['audio'].size()[:2]
|
233 |
+
src_device = src_tokens['audio'].device
|
234 |
+
else:
|
235 |
+
bsz, src_len = net_input['padding_mask'].size()
|
236 |
+
src_device = src_tokens['video'].device
|
237 |
+
beam_size = self.beam_size
|
238 |
+
if constraints is not None and not self.search.supports_constraints:
|
239 |
+
raise NotImplementedError(
|
240 |
+
"Target-side constraints were provided, but search method doesn't support them"
|
241 |
+
)
|
242 |
+
|
243 |
+
# Initialize constraints, when active
|
244 |
+
self.search.init_constraints(constraints, beam_size)
|
245 |
+
|
246 |
+
max_len: int = -1
|
247 |
+
if self.match_source_len:
|
248 |
+
max_len = src_lengths.max().item()
|
249 |
+
else:
|
250 |
+
max_len = min(
|
251 |
+
int(self.max_len_a * src_len + self.max_len_b),
|
252 |
+
self.max_len - 1,
|
253 |
+
)
|
254 |
+
assert (
|
255 |
+
self.min_len <= max_len
|
256 |
+
), "min_len cannot be larger than max_len, please adjust these!"
|
257 |
+
# compute the encoder output for each beam
|
258 |
+
encoder_outs = self.model.forward_encoder(net_input)
|
259 |
+
|
260 |
+
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
|
261 |
+
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
|
262 |
+
new_order = new_order.to(src_device).long()
|
263 |
+
encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
|
264 |
+
# ensure encoder_outs is a List.
|
265 |
+
assert encoder_outs is not None
|
266 |
+
|
267 |
+
# initialize buffers
|
268 |
+
scores = (
|
269 |
+
torch.zeros(bsz * beam_size, max_len + 1).to(src_device).float()
|
270 |
+
) # +1 for eos; pad is never chosen for scoring
|
271 |
+
tokens = (
|
272 |
+
torch.zeros(bsz * beam_size, max_len + 2)
|
273 |
+
.to(src_device)
|
274 |
+
.long()
|
275 |
+
.fill_(self.pad)
|
276 |
+
) # +2 for eos and pad
|
277 |
+
tokens[:, 0] = self.eos if bos_token is None else bos_token
|
278 |
+
attn: Optional[Tensor] = None
|
279 |
+
|
280 |
+
# A list that indicates candidates that should be ignored.
|
281 |
+
# For example, suppose we're sampling and have already finalized 2/5
|
282 |
+
# samples. Then cands_to_ignore would mark 2 positions as being ignored,
|
283 |
+
# so that we only finalize the remaining 3 samples.
|
284 |
+
cands_to_ignore = (
|
285 |
+
torch.zeros(bsz, beam_size).to(src_device).eq(-1)
|
286 |
+
) # forward and backward-compatible False mask
|
287 |
+
|
288 |
+
# list of completed sentences
|
289 |
+
finalized = torch.jit.annotate(
|
290 |
+
List[List[Dict[str, Tensor]]],
|
291 |
+
[torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
|
292 |
+
) # contains lists of dictionaries of information about the hypothesis being finalized at each step
|
293 |
+
|
294 |
+
# a boolean array indicating if the sentence at the index is finished or not
|
295 |
+
finished = [False for i in range(bsz)]
|
296 |
+
num_remaining_sent = bsz # number of sentences remaining
|
297 |
+
|
298 |
+
# number of candidate hypos per step
|
299 |
+
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
|
300 |
+
|
301 |
+
# offset arrays for converting between different indexing schemes
|
302 |
+
bbsz_offsets = (
|
303 |
+
(torch.arange(0, bsz) * beam_size)
|
304 |
+
.unsqueeze(1)
|
305 |
+
.type_as(tokens)
|
306 |
+
.to(src_device)
|
307 |
+
)
|
308 |
+
cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_device)
|
309 |
+
|
310 |
+
reorder_state: Optional[Tensor] = None
|
311 |
+
batch_idxs: Optional[Tensor] = None
|
312 |
+
|
313 |
+
original_batch_idxs: Optional[Tensor] = None
|
314 |
+
if "id" in sample and isinstance(sample["id"], Tensor):
|
315 |
+
original_batch_idxs = sample["id"]
|
316 |
+
else:
|
317 |
+
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
|
318 |
+
|
319 |
+
for step in range(max_len + 1): # one extra step for EOS marker
|
320 |
+
# reorder decoder internal states based on the prev choice of beams
|
321 |
+
if reorder_state is not None:
|
322 |
+
if batch_idxs is not None:
|
323 |
+
# update beam indices to take into account removed sentences
|
324 |
+
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
|
325 |
+
batch_idxs
|
326 |
+
)
|
327 |
+
reorder_state.view(-1, beam_size).add_(
|
328 |
+
corr.unsqueeze(-1) * beam_size
|
329 |
+
)
|
330 |
+
original_batch_idxs = original_batch_idxs[batch_idxs]
|
331 |
+
self.model.reorder_incremental_state(incremental_states, reorder_state)
|
332 |
+
encoder_outs = self.model.reorder_encoder_out(
|
333 |
+
encoder_outs, reorder_state
|
334 |
+
)
|
335 |
+
|
336 |
+
lprobs, avg_attn_scores = self.model.forward_decoder(
|
337 |
+
tokens[:, : step + 1],
|
338 |
+
encoder_outs,
|
339 |
+
incremental_states,
|
340 |
+
self.temperature,
|
341 |
+
)
|
342 |
+
|
343 |
+
if self.lm_model is not None:
|
344 |
+
lm_out = self.lm_model(tokens[:, : step + 1])
|
345 |
+
probs = self.lm_model.get_normalized_probs(
|
346 |
+
lm_out, log_probs=True, sample=None
|
347 |
+
)
|
348 |
+
probs = probs[:, -1, :] * self.lm_weight
|
349 |
+
lprobs += probs
|
350 |
+
|
351 |
+
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
|
352 |
+
|
353 |
+
lprobs[:, self.pad] = -math.inf # never select pad
|
354 |
+
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
|
355 |
+
|
356 |
+
# handle max length constraint
|
357 |
+
if step >= max_len:
|
358 |
+
lprobs[:, : self.eos] = -math.inf
|
359 |
+
lprobs[:, self.eos + 1 :] = -math.inf
|
360 |
+
|
361 |
+
# handle prefix tokens (possibly with different lengths)
|
362 |
+
if (
|
363 |
+
prefix_tokens is not None
|
364 |
+
and step < prefix_tokens.size(1)
|
365 |
+
and step < max_len
|
366 |
+
):
|
367 |
+
lprobs, tokens, scores = self._prefix_tokens(
|
368 |
+
step, lprobs, scores, tokens, prefix_tokens, beam_size
|
369 |
+
)
|
370 |
+
elif step < self.min_len:
|
371 |
+
# minimum length constraint (does not apply if using prefix_tokens)
|
372 |
+
lprobs[:, self.eos] = -math.inf
|
373 |
+
|
374 |
+
# Record attention scores, only support avg_attn_scores is a Tensor
|
375 |
+
if avg_attn_scores is not None:
|
376 |
+
if attn is None:
|
377 |
+
attn = torch.empty(
|
378 |
+
bsz * beam_size, avg_attn_scores.size(1), max_len + 2
|
379 |
+
).to(scores)
|
380 |
+
attn[:, :, step + 1].copy_(avg_attn_scores)
|
381 |
+
|
382 |
+
scores = scores.type_as(lprobs)
|
383 |
+
eos_bbsz_idx = torch.empty(0).to(
|
384 |
+
tokens
|
385 |
+
) # indices of hypothesis ending with eos (finished sentences)
|
386 |
+
eos_scores = torch.empty(0).to(
|
387 |
+
scores
|
388 |
+
) # scores of hypothesis ending with eos (finished sentences)
|
389 |
+
|
390 |
+
if self.should_set_src_lengths:
|
391 |
+
self.search.set_src_lengths(src_lengths)
|
392 |
+
|
393 |
+
if self.repeat_ngram_blocker is not None:
|
394 |
+
lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
|
395 |
+
|
396 |
+
# Shape: (batch, cand_size)
|
397 |
+
cand_scores, cand_indices, cand_beams = self.search.step(
|
398 |
+
step,
|
399 |
+
lprobs.view(bsz, -1, self.vocab_size),
|
400 |
+
scores.view(bsz, beam_size, -1)[:, :, :step],
|
401 |
+
tokens[:, : step + 1],
|
402 |
+
original_batch_idxs,
|
403 |
+
)
|
404 |
+
|
405 |
+
# cand_bbsz_idx contains beam indices for the top candidate
|
406 |
+
# hypotheses, with a range of values: [0, bsz*beam_size),
|
407 |
+
# and dimensions: [bsz, cand_size]
|
408 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
409 |
+
|
410 |
+
# finalize hypotheses that end in eos
|
411 |
+
# Shape of eos_mask: (batch size, beam size)
|
412 |
+
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
|
413 |
+
eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
|
414 |
+
|
415 |
+
# only consider eos when it's among the top beam_size indices
|
416 |
+
# Now we know what beam item(s) to finish
|
417 |
+
# Shape: 1d list of absolute-numbered (bsz * beam_size) hypothesis indices
|
418 |
+
eos_bbsz_idx = torch.masked_select(
|
419 |
+
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
|
420 |
+
)
|
421 |
+
|
422 |
+
finalized_sents: List[int] = []
|
423 |
+
if eos_bbsz_idx.numel() > 0:
|
424 |
+
eos_scores = torch.masked_select(
|
425 |
+
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
|
426 |
+
)
|
427 |
+
|
428 |
+
finalized_sents = self.finalize_hypos(
|
429 |
+
step,
|
430 |
+
eos_bbsz_idx,
|
431 |
+
eos_scores,
|
432 |
+
tokens,
|
433 |
+
scores,
|
434 |
+
finalized,
|
435 |
+
finished,
|
436 |
+
beam_size,
|
437 |
+
attn,
|
438 |
+
src_lengths,
|
439 |
+
max_len,
|
440 |
+
)
|
441 |
+
num_remaining_sent -= len(finalized_sents)
|
442 |
+
|
443 |
+
assert num_remaining_sent >= 0
|
444 |
+
if num_remaining_sent == 0:
|
445 |
+
break
|
446 |
+
if self.search.stop_on_max_len and step >= max_len:
|
447 |
+
break
|
448 |
+
assert step < max_len, f"{step} < {max_len}"
|
449 |
+
|
450 |
+
# Remove finalized sentences (ones for which {beam_size}
|
451 |
+
# finished hypotheses have been generated) from the batch.
|
452 |
+
if len(finalized_sents) > 0:
|
453 |
+
new_bsz = bsz - len(finalized_sents)
|
454 |
+
|
455 |
+
# construct batch_idxs which holds indices of batches to keep for the next pass
|
456 |
+
batch_mask = torch.ones(
|
457 |
+
bsz, dtype=torch.bool, device=cand_indices.device
|
458 |
+
)
|
459 |
+
batch_mask[finalized_sents] = False
|
460 |
+
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
|
461 |
+
batch_idxs = torch.arange(
|
462 |
+
bsz, device=cand_indices.device
|
463 |
+
).masked_select(batch_mask)
|
464 |
+
|
465 |
+
# Choose the subset of the hypothesized constraints that will continue
|
466 |
+
self.search.prune_sentences(batch_idxs)
|
467 |
+
|
468 |
+
eos_mask = eos_mask[batch_idxs]
|
469 |
+
cand_beams = cand_beams[batch_idxs]
|
470 |
+
bbsz_offsets.resize_(new_bsz, 1)
|
471 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
472 |
+
cand_scores = cand_scores[batch_idxs]
|
473 |
+
cand_indices = cand_indices[batch_idxs]
|
474 |
+
|
475 |
+
if prefix_tokens is not None:
|
476 |
+
prefix_tokens = prefix_tokens[batch_idxs]
|
477 |
+
src_lengths = src_lengths[batch_idxs]
|
478 |
+
cands_to_ignore = cands_to_ignore[batch_idxs]
|
479 |
+
|
480 |
+
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
481 |
+
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
482 |
+
if attn is not None:
|
483 |
+
attn = attn.view(bsz, -1)[batch_idxs].view(
|
484 |
+
new_bsz * beam_size, attn.size(1), -1
|
485 |
+
)
|
486 |
+
bsz = new_bsz
|
487 |
+
else:
|
488 |
+
batch_idxs = None
|
489 |
+
|
490 |
+
# Set active_mask so that values > cand_size indicate eos hypos
|
491 |
+
# and values < cand_size indicate candidate active hypos.
|
492 |
+
# After, the min values per row are the top candidate active hypos
|
493 |
+
|
494 |
+
# Rewrite the operator since the element wise or is not supported in torchscript.
|
495 |
+
|
496 |
+
eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
|
497 |
+
active_mask = torch.add(
|
498 |
+
eos_mask.type_as(cand_offsets) * cand_size,
|
499 |
+
cand_offsets[: eos_mask.size(1)],
|
500 |
+
)
|
501 |
+
|
502 |
+
# get the top beam_size active hypotheses, which are just
|
503 |
+
# the hypos with the smallest values in active_mask.
|
504 |
+
# {active_hypos} indicates which {beam_size} hypotheses
|
505 |
+
# from the list of {2 * beam_size} candidates were
|
506 |
+
# selected. Shapes: (batch size, beam size)
|
507 |
+
new_cands_to_ignore, active_hypos = torch.topk(
|
508 |
+
active_mask, k=beam_size, dim=1, largest=False
|
509 |
+
)
|
510 |
+
|
511 |
+
# update cands_to_ignore to ignore any finalized hypos.
|
512 |
+
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
|
513 |
+
# Make sure there is at least one active item for each sentence in the batch.
|
514 |
+
assert (~cands_to_ignore).any(dim=1).all()
|
515 |
+
|
516 |
+
# update cands_to_ignore to ignore any finalized hypos
|
517 |
+
|
518 |
+
# {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
|
519 |
+
# can be selected more than once).
|
520 |
+
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
|
521 |
+
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
|
522 |
+
|
523 |
+
active_bbsz_idx = active_bbsz_idx.view(-1)
|
524 |
+
active_scores = active_scores.view(-1)
|
525 |
+
|
526 |
+
# copy tokens and scores for active hypotheses
|
527 |
+
|
528 |
+
# Set the tokens for each beam (can select the same row more than once)
|
529 |
+
tokens[:, : step + 1] = torch.index_select(
|
530 |
+
tokens[:, : step + 1], dim=0, index=active_bbsz_idx
|
531 |
+
)
|
532 |
+
# Select the next token for each of them
|
533 |
+
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
|
534 |
+
cand_indices, dim=1, index=active_hypos
|
535 |
+
)
|
536 |
+
if step > 0:
|
537 |
+
scores[:, :step] = torch.index_select(
|
538 |
+
scores[:, :step], dim=0, index=active_bbsz_idx
|
539 |
+
)
|
540 |
+
scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
|
541 |
+
cand_scores, dim=1, index=active_hypos
|
542 |
+
)
|
543 |
+
|
544 |
+
# Update constraints based on which candidates were selected for the next beam
|
545 |
+
self.search.update_constraints(active_hypos)
|
546 |
+
|
547 |
+
# copy attention for active hypotheses
|
548 |
+
if attn is not None:
|
549 |
+
attn[:, :, : step + 2] = torch.index_select(
|
550 |
+
attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
|
551 |
+
)
|
552 |
+
|
553 |
+
# reorder incremental state in decoder
|
554 |
+
reorder_state = active_bbsz_idx
|
555 |
+
|
556 |
+
# sort by score descending
|
557 |
+
for sent in range(len(finalized)):
|
558 |
+
scores = torch.tensor(
|
559 |
+
[float(elem["score"].item()) for elem in finalized[sent]]
|
560 |
+
)
|
561 |
+
_, sorted_scores_indices = torch.sort(scores, descending=True)
|
562 |
+
finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
|
563 |
+
finalized[sent] = torch.jit.annotate(
|
564 |
+
List[Dict[str, Tensor]], finalized[sent]
|
565 |
+
)
|
566 |
+
return finalized
|
567 |
+
|
568 |
+
def _prefix_tokens(
|
569 |
+
self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
|
570 |
+
):
|
571 |
+
"""Handle prefix tokens"""
|
572 |
+
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
|
573 |
+
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
|
574 |
+
prefix_mask = prefix_toks.ne(self.pad)
|
575 |
+
lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)
|
576 |
+
lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
|
577 |
+
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
|
578 |
+
)
|
579 |
+
# if prefix includes eos, then we should make sure tokens and
|
580 |
+
# scores are the same across all beams
|
581 |
+
eos_mask = prefix_toks.eq(self.eos)
|
582 |
+
if eos_mask.any():
|
583 |
+
# validate that the first beam matches the prefix
|
584 |
+
first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
|
585 |
+
:, 0, 1 : step + 1
|
586 |
+
]
|
587 |
+
eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
|
588 |
+
target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
|
589 |
+
assert (first_beam == target_prefix).all()
|
590 |
+
|
591 |
+
# copy tokens, scores and lprobs from the first beam to all beams
|
592 |
+
tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
|
593 |
+
scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
|
594 |
+
lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
|
595 |
+
return lprobs, tokens, scores
|
596 |
+
|
597 |
+
def replicate_first_beam(self, tensor, mask, beam_size: int):
|
598 |
+
tensor = tensor.view(-1, beam_size, tensor.size(-1))
|
599 |
+
tensor[mask] = tensor[mask][:, :1, :]
|
600 |
+
return tensor.view(-1, tensor.size(-1))
|
601 |
+
|
602 |
+
def finalize_hypos(
|
603 |
+
self,
|
604 |
+
step: int,
|
605 |
+
bbsz_idx,
|
606 |
+
eos_scores,
|
607 |
+
tokens,
|
608 |
+
scores,
|
609 |
+
finalized: List[List[Dict[str, Tensor]]],
|
610 |
+
finished: List[bool],
|
611 |
+
beam_size: int,
|
612 |
+
attn: Optional[Tensor],
|
613 |
+
src_lengths,
|
614 |
+
max_len: int,
|
615 |
+
):
|
616 |
+
"""Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
|
617 |
+
A sentence is finalized when {beam_size} finished items have been collected for it.
|
618 |
+
|
619 |
+
Returns number of sentences (not beam items) being finalized.
|
620 |
+
These will be removed from the batch and not processed further.
|
621 |
+
Args:
|
622 |
+
bbsz_idx (Tensor):
|
623 |
+
"""
|
624 |
+
assert bbsz_idx.numel() == eos_scores.numel()
|
625 |
+
|
626 |
+
# clone relevant token and attention tensors.
|
627 |
+
# tokens is (batch * beam, max_len). So the index_select
|
628 |
+
# gets the newly EOS rows, then selects cols 1..{step + 2}
|
629 |
+
tokens_clone = tokens.index_select(0, bbsz_idx)[
|
630 |
+
:, 1 : step + 2
|
631 |
+
] # skip the first index, which is EOS
|
632 |
+
|
633 |
+
tokens_clone[:, step] = self.eos
|
634 |
+
attn_clone = (
|
635 |
+
attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
|
636 |
+
if attn is not None
|
637 |
+
else None
|
638 |
+
)
|
639 |
+
|
640 |
+
# compute scores per token position
|
641 |
+
pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
|
642 |
+
pos_scores[:, step] = eos_scores
|
643 |
+
# convert from cumulative to per-position scores
|
644 |
+
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
|
645 |
+
|
646 |
+
# normalize sentence-level scores
|
647 |
+
if self.normalize_scores:
|
648 |
+
eos_scores /= (step + 1) ** self.len_penalty
|
649 |
+
|
650 |
+
# cum_unfin records which sentences in the batch are finished.
|
651 |
+
# It helps match indexing between (a) the original sentences
|
652 |
+
# in the batch and (b) the current, possibly-reduced set of
|
653 |
+
# sentences.
|
654 |
+
cum_unfin: List[int] = []
|
655 |
+
prev = 0
|
656 |
+
for f in finished:
|
657 |
+
if f:
|
658 |
+
prev += 1
|
659 |
+
else:
|
660 |
+
cum_unfin.append(prev)
|
661 |
+
|
662 |
+
# The keys here are of the form "{sent}_{unfin_idx}", where
|
663 |
+
# "unfin_idx" is the index in the current (possibly reduced)
|
664 |
+
# list of sentences, and "sent" is the index in the original,
|
665 |
+
# unreduced batch
|
666 |
+
# set() is not supported in script export
|
667 |
+
sents_seen: Dict[str, Optional[Tensor]] = {}
|
668 |
+
|
669 |
+
# For every finished beam item
|
670 |
+
for i in range(bbsz_idx.size()[0]):
|
671 |
+
idx = bbsz_idx[i]
|
672 |
+
score = eos_scores[i]
|
673 |
+
# sentence index in the current (possibly reduced) batch
|
674 |
+
unfin_idx = idx // beam_size
|
675 |
+
# sentence index in the original (unreduced) batch
|
676 |
+
sent = unfin_idx + cum_unfin[unfin_idx]
|
677 |
+
# Cannot create dict for key type '(int, int)' in torchscript.
|
678 |
+
# The workaround is to cast int to string
|
679 |
+
seen = str(sent.item()) + "_" + str(unfin_idx.item())
|
680 |
+
if seen not in sents_seen:
|
681 |
+
sents_seen[seen] = None
|
682 |
+
|
683 |
+
if self.match_source_len and step > src_lengths[unfin_idx]:
|
684 |
+
score = torch.tensor(-math.inf).to(score)
|
685 |
+
|
686 |
+
# An input sentence (among those in a batch) is finished when
|
687 |
+
# beam_size hypotheses have been collected for it
|
688 |
+
if len(finalized[sent]) < beam_size:
|
689 |
+
if attn_clone is not None:
|
690 |
+
# remove padding tokens from attn scores
|
691 |
+
hypo_attn = attn_clone[i]
|
692 |
+
else:
|
693 |
+
hypo_attn = torch.empty(0)
|
694 |
+
|
695 |
+
finalized[sent].append(
|
696 |
+
{
|
697 |
+
"tokens": tokens_clone[i],
|
698 |
+
"score": score,
|
699 |
+
"attention": hypo_attn, # src_len x tgt_len
|
700 |
+
"alignment": torch.empty(0),
|
701 |
+
"positional_scores": pos_scores[i],
|
702 |
+
}
|
703 |
+
)
|
704 |
+
|
705 |
+
newly_finished: List[int] = []
|
706 |
+
|
707 |
+
for seen in sents_seen.keys():
|
708 |
+
# check termination conditions for this sentence
|
709 |
+
sent: int = int(float(seen.split("_")[0]))
|
710 |
+
unfin_idx: int = int(float(seen.split("_")[1]))
|
711 |
+
|
712 |
+
if not finished[sent] and self.is_finished(
|
713 |
+
step, unfin_idx, max_len, len(finalized[sent]), beam_size
|
714 |
+
):
|
715 |
+
finished[sent] = True
|
716 |
+
newly_finished.append(unfin_idx)
|
717 |
+
|
718 |
+
return newly_finished
|
719 |
+
|
720 |
+
def is_finished(
|
721 |
+
self,
|
722 |
+
step: int,
|
723 |
+
unfin_idx: int,
|
724 |
+
max_len: int,
|
725 |
+
finalized_sent_len: int,
|
726 |
+
beam_size: int,
|
727 |
+
):
|
728 |
+
"""
|
729 |
+
Check whether decoding for a sentence is finished, which
|
730 |
+
occurs when the list of finalized sentences has reached the
|
731 |
+
beam size, or when we reach the maximum length.
|
732 |
+
"""
|
733 |
+
assert finalized_sent_len <= beam_size
|
734 |
+
if finalized_sent_len == beam_size or step == max_len:
|
735 |
+
return True
|
736 |
+
return False
|
737 |
+
|
738 |
+
|
739 |
+
class EnsembleModel(nn.Module):
|
740 |
+
"""A wrapper around an ensemble of models."""
|
741 |
+
|
742 |
+
def __init__(self, models):
|
743 |
+
super().__init__()
|
744 |
+
self.models_size = len(models)
|
745 |
+
# method '__len__' is not supported in ModuleList for torch script
|
746 |
+
self.single_model = models[0]
|
747 |
+
self.models = nn.ModuleList(models)
|
748 |
+
|
749 |
+
self.has_incremental: bool = False
|
750 |
+
if all(
|
751 |
+
hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
|
752 |
+
for m in models
|
753 |
+
):
|
754 |
+
self.has_incremental = True
|
755 |
+
|
756 |
+
def forward(self):
|
757 |
+
pass
|
758 |
+
|
759 |
+
def has_encoder(self):
|
760 |
+
return hasattr(self.single_model, "encoder")
|
761 |
+
|
762 |
+
def has_incremental_states(self):
|
763 |
+
return self.has_incremental
|
764 |
+
|
765 |
+
def max_decoder_positions(self):
|
766 |
+
return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
|
767 |
+
|
768 |
+
@torch.jit.export
|
769 |
+
def forward_encoder(self, net_input: Dict[str, Tensor]):
|
770 |
+
if not self.has_encoder():
|
771 |
+
return None
|
772 |
+
return [model.encoder.forward_torchscript(net_input) for model in self.models]
|
773 |
+
|
774 |
+
@torch.jit.export
|
775 |
+
def forward_decoder(
|
776 |
+
self,
|
777 |
+
tokens,
|
778 |
+
encoder_outs: List[Dict[str, List[Tensor]]],
|
779 |
+
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
780 |
+
temperature: float = 1.0,
|
781 |
+
):
|
782 |
+
log_probs = []
|
783 |
+
avg_attn: Optional[Tensor] = None
|
784 |
+
encoder_out: Optional[Dict[str, List[Tensor]]] = None
|
785 |
+
for i, model in enumerate(self.models):
|
786 |
+
if self.has_encoder():
|
787 |
+
encoder_out = encoder_outs[i]
|
788 |
+
# decode each model
|
789 |
+
if self.has_incremental_states():
|
790 |
+
decoder_out = model.decoder.forward(
|
791 |
+
tokens,
|
792 |
+
encoder_out=encoder_out,
|
793 |
+
incremental_state=incremental_states[i],
|
794 |
+
)
|
795 |
+
else:
|
796 |
+
if hasattr(model, "decoder"):
|
797 |
+
decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
|
798 |
+
else:
|
799 |
+
decoder_out = model.forward(tokens)
|
800 |
+
|
801 |
+
attn: Optional[Tensor] = None
|
802 |
+
decoder_len = len(decoder_out)
|
803 |
+
if decoder_len > 1 and decoder_out[1] is not None:
|
804 |
+
if isinstance(decoder_out[1], Tensor):
|
805 |
+
attn = decoder_out[1]
|
806 |
+
else:
|
807 |
+
attn_holder = decoder_out[1]["attn"]
|
808 |
+
if isinstance(attn_holder, Tensor):
|
809 |
+
attn = attn_holder
|
810 |
+
elif attn_holder is not None:
|
811 |
+
attn = attn_holder[0]
|
812 |
+
if attn is not None:
|
813 |
+
attn = attn[:, -1, :]
|
814 |
+
|
815 |
+
decoder_out_tuple = (
|
816 |
+
decoder_out[0][:, -1:, :].div_(temperature),
|
817 |
+
None if decoder_len <= 1 else decoder_out[1],
|
818 |
+
)
|
819 |
+
probs = model.get_normalized_probs(
|
820 |
+
decoder_out_tuple, log_probs=True, sample=None
|
821 |
+
)
|
822 |
+
probs = probs[:, -1, :]
|
823 |
+
if self.models_size == 1:
|
824 |
+
return probs, attn
|
825 |
+
|
826 |
+
log_probs.append(probs)
|
827 |
+
if attn is not None:
|
828 |
+
if avg_attn is None:
|
829 |
+
avg_attn = attn
|
830 |
+
else:
|
831 |
+
avg_attn.add_(attn)
|
832 |
+
|
833 |
+
avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
|
834 |
+
self.models_size
|
835 |
+
)
|
836 |
+
|
837 |
+
if avg_attn is not None:
|
838 |
+
avg_attn.div_(self.models_size)
|
839 |
+
return avg_probs, avg_attn
|
840 |
+
|
841 |
+
@torch.jit.export
|
842 |
+
def reorder_encoder_out(
|
843 |
+
self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
|
844 |
+
):
|
845 |
+
"""
|
846 |
+
Reorder encoder output according to *new_order*.
|
847 |
+
|
848 |
+
Args:
|
849 |
+
encoder_out: output from the ``forward()`` method
|
850 |
+
new_order (LongTensor): desired order
|
851 |
+
|
852 |
+
Returns:
|
853 |
+
*encoder_out* rearranged according to *new_order*
|
854 |
+
"""
|
855 |
+
new_outs: List[Dict[str, List[Tensor]]] = []
|
856 |
+
if not self.has_encoder():
|
857 |
+
return new_outs
|
858 |
+
for i, model in enumerate(self.models):
|
859 |
+
assert encoder_outs is not None
|
860 |
+
new_outs.append(
|
861 |
+
model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
|
862 |
+
)
|
863 |
+
return new_outs
|
864 |
+
|
865 |
+
@torch.jit.export
|
866 |
+
def reorder_incremental_state(
|
867 |
+
self,
|
868 |
+
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
869 |
+
new_order,
|
870 |
+
):
|
871 |
+
if not self.has_incremental_states():
|
872 |
+
return
|
873 |
+
for i, model in enumerate(self.models):
|
874 |
+
model.decoder.reorder_incremental_state_scripting(
|
875 |
+
incremental_states[i], new_order
|
876 |
+
)
|
877 |
+
|
878 |
+
|
879 |
+
class SequenceGeneratorWithAlignment(SequenceGenerator):
|
880 |
+
def __init__(
|
881 |
+
self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
|
882 |
+
):
|
883 |
+
"""Generates translations of a given source sentence.
|
884 |
+
|
885 |
+
Produces alignments following "Jointly Learning to Align and
|
886 |
+
Translate with Transformer Models" (Garg et al., EMNLP 2019).
|
887 |
+
|
888 |
+
Args:
|
889 |
+
left_pad_target (bool, optional): Whether or not the
|
890 |
+
hypothesis should be left padded or not when they are
|
891 |
+
teacher forced for generating alignments.
|
892 |
+
"""
|
893 |
+
super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
|
894 |
+
self.left_pad_target = left_pad_target
|
895 |
+
|
896 |
+
if print_alignment == "hard":
|
897 |
+
self.extract_alignment = utils.extract_hard_alignment
|
898 |
+
elif print_alignment == "soft":
|
899 |
+
self.extract_alignment = utils.extract_soft_alignment
|
900 |
+
|
901 |
+
@torch.no_grad()
|
902 |
+
def generate(self, models, sample, **kwargs):
|
903 |
+
finalized = super()._generate(sample, **kwargs)
|
904 |
+
|
905 |
+
src_tokens = sample["net_input"]["src_tokens"]
|
906 |
+
bsz = src_tokens.shape[0]
|
907 |
+
beam_size = self.beam_size
|
908 |
+
(
|
909 |
+
src_tokens,
|
910 |
+
src_lengths,
|
911 |
+
prev_output_tokens,
|
912 |
+
tgt_tokens,
|
913 |
+
) = self._prepare_batch_for_alignment(sample, finalized)
|
914 |
+
if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
|
915 |
+
attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
|
916 |
+
else:
|
917 |
+
attn = [
|
918 |
+
finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
|
919 |
+
for i in range(bsz * beam_size)
|
920 |
+
]
|
921 |
+
|
922 |
+
if src_tokens.device != "cpu":
|
923 |
+
src_tokens = src_tokens.to("cpu")
|
924 |
+
tgt_tokens = tgt_tokens.to("cpu")
|
925 |
+
attn = [i.to("cpu") for i in attn]
|
926 |
+
|
927 |
+
# Process the attn matrix to extract hard alignments.
|
928 |
+
for i in range(bsz * beam_size):
|
929 |
+
alignment = self.extract_alignment(
|
930 |
+
attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
|
931 |
+
)
|
932 |
+
finalized[i // beam_size][i % beam_size]["alignment"] = alignment
|
933 |
+
return finalized
|
934 |
+
|
935 |
+
def _prepare_batch_for_alignment(self, sample, hypothesis):
|
936 |
+
src_tokens = sample["net_input"]["src_tokens"]
|
937 |
+
bsz = src_tokens.shape[0]
|
938 |
+
src_tokens = (
|
939 |
+
src_tokens[:, None, :]
|
940 |
+
.expand(-1, self.beam_size, -1)
|
941 |
+
.contiguous()
|
942 |
+
.view(bsz * self.beam_size, -1)
|
943 |
+
)
|
944 |
+
src_lengths = sample["net_input"]["src_lengths"]
|
945 |
+
src_lengths = (
|
946 |
+
src_lengths[:, None]
|
947 |
+
.expand(-1, self.beam_size)
|
948 |
+
.contiguous()
|
949 |
+
.view(bsz * self.beam_size)
|
950 |
+
)
|
951 |
+
prev_output_tokens = data_utils.collate_tokens(
|
952 |
+
[beam["tokens"] for example in hypothesis for beam in example],
|
953 |
+
self.pad,
|
954 |
+
self.eos,
|
955 |
+
self.left_pad_target,
|
956 |
+
move_eos_to_beginning=True,
|
957 |
+
)
|
958 |
+
tgt_tokens = data_utils.collate_tokens(
|
959 |
+
[beam["tokens"] for example in hypothesis for beam in example],
|
960 |
+
self.pad,
|
961 |
+
self.eos,
|
962 |
+
self.left_pad_target,
|
963 |
+
move_eos_to_beginning=False,
|
964 |
+
)
|
965 |
+
return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
|
966 |
+
|
967 |
+
|
968 |
+
class EnsembleModelWithAlignment(EnsembleModel):
|
969 |
+
"""A wrapper around an ensemble of models."""
|
970 |
+
|
971 |
+
def __init__(self, models):
|
972 |
+
super().__init__(models)
|
973 |
+
|
974 |
+
def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
|
975 |
+
avg_attn = None
|
976 |
+
for model in self.models:
|
977 |
+
decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
|
978 |
+
attn = decoder_out[1]["attn"][0]
|
979 |
+
if avg_attn is None:
|
980 |
+
avg_attn = attn
|
981 |
+
else:
|
982 |
+
avg_attn.add_(attn)
|
983 |
+
if len(self.models) > 1:
|
984 |
+
avg_attn.div_(len(self.models))
|
985 |
+
return avg_attn
|
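Usage sketch (illustrative; the names `models`, `tgt_dict`, and `sample` are assumptions standing in for a loaded fairseq ensemble, its target dictionary, and a collated batch whose "net_input" carries the AV-HuBERT "source" dict of "audio"/"video" tensors plus a "padding_mask"):

import torch

# Assumed to exist already: `models` (list of fairseq models), `tgt_dict`
# (fairseq Dictionary), `sample` (collated batch with a "net_input" entry).
generator = SequenceGenerator(models, tgt_dict, beam_size=5, max_len_b=200)

hypos = generator.generate(models, sample)   # generate() already runs under no_grad
best = hypos[0][0]                           # best-scoring hypothesis for sentence 0
print(tgt_dict.string(best["tokens"].int().cpu()))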
slam_llm/models/avhubert/utils.py
ADDED
@@ -0,0 +1,298 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
import random
|
10 |
+
import numpy as np
|
11 |
+
from typing import Dict, List, Optional, Tuple
|
12 |
+
|
13 |
+
def load_video(path):
|
14 |
+
for i in range(3):
|
15 |
+
try:
|
16 |
+
cap = cv2.VideoCapture(path)
|
17 |
+
frames = []
|
18 |
+
while True:
|
19 |
+
ret, frame = cap.read()
|
20 |
+
if ret:
|
21 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
22 |
+
frames.append(frame)
|
23 |
+
else:
|
24 |
+
break
|
25 |
+
frames = np.stack(frames)
|
26 |
+
return frames
|
27 |
+
except Exception:
|
28 |
+
print(f"failed loading {path} ({i} / 3)")
|
29 |
+
if i == 2:
|
30 |
+
raise ValueError(f"Unable to load {path}")
|
31 |
+
|
32 |
+
|
33 |
+
class Compose(object):
|
34 |
+
"""Compose several preprocess together.
|
35 |
+
Args:
|
36 |
+
preprocess (list of ``Preprocess`` objects): list of preprocessing transforms to compose.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, preprocess):
|
40 |
+
self.preprocess = preprocess
|
41 |
+
|
42 |
+
def __call__(self, sample):
|
43 |
+
for t in self.preprocess:
|
44 |
+
sample = t(sample)
|
45 |
+
return sample
|
46 |
+
|
47 |
+
def __repr__(self):
|
48 |
+
format_string = self.__class__.__name__ + '('
|
49 |
+
for t in self.preprocess:
|
50 |
+
format_string += '\n'
|
51 |
+
format_string += ' {0}'.format(t)
|
52 |
+
format_string += '\n)'
|
53 |
+
return format_string
|
54 |
+
|
55 |
+
|
56 |
+
class Normalize(object):
|
57 |
+
"""Normalize a ndarray image with mean and standard deviation.
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(self, mean, std):
|
61 |
+
self.mean = mean
|
62 |
+
self.std = std
|
63 |
+
|
64 |
+
def __call__(self, frames):
|
65 |
+
"""
|
66 |
+
Args:
|
67 |
+
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
|
68 |
+
Returns:
|
69 |
+
Tensor: Normalized Tensor image.
|
70 |
+
"""
|
71 |
+
frames = (frames - self.mean) / self.std
|
72 |
+
return frames
|
73 |
+
|
74 |
+
def __repr__(self):
|
75 |
+
return self.__class__.__name__+'(mean={0}, std={1})'.format(self.mean, self.std)
|
76 |
+
|
77 |
+
class CenterCrop(object):
|
78 |
+
"""Crop the given image at the center
|
79 |
+
"""
|
80 |
+
def __init__(self, size):
|
81 |
+
self.size = size
|
82 |
+
|
83 |
+
def __call__(self, frames):
|
84 |
+
"""
|
85 |
+
Args:
|
86 |
+
img (numpy.ndarray): Images to be cropped.
|
87 |
+
Returns:
|
88 |
+
numpy.ndarray: Cropped image.
|
89 |
+
"""
|
90 |
+
t, h, w = frames.shape
|
91 |
+
th, tw = self.size
|
92 |
+
delta_w = int(round((w - tw))/2.)
|
93 |
+
delta_h = int(round((h - th))/2.)
|
94 |
+
frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw]
|
95 |
+
return frames
|
96 |
+
|
97 |
+
|
98 |
+
class RandomCrop(object):
|
99 |
+
"""Crop the given image at the center
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(self, size):
|
103 |
+
self.size = size
|
104 |
+
|
105 |
+
def __call__(self, frames):
|
106 |
+
"""
|
107 |
+
Args:
|
108 |
+
img (numpy.ndarray): Images to be cropped.
|
109 |
+
Returns:
|
110 |
+
numpy.ndarray: Cropped image.
|
111 |
+
"""
|
112 |
+
t, h, w = frames.shape
|
113 |
+
th, tw = self.size
|
114 |
+
delta_w = random.randint(0, w-tw)
|
115 |
+
delta_h = random.randint(0, h-th)
|
116 |
+
frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw]
|
117 |
+
return frames
|
118 |
+
|
119 |
+
def __repr__(self):
|
120 |
+
return self.__class__.__name__ + '(size={0})'.format(self.size)
|
121 |
+
|
122 |
+
class HorizontalFlip(object):
|
123 |
+
"""Flip image horizontally.
|
124 |
+
"""
|
125 |
+
|
126 |
+
def __init__(self, flip_ratio):
|
127 |
+
self.flip_ratio = flip_ratio
|
128 |
+
|
129 |
+
def __call__(self, frames):
|
130 |
+
"""
|
131 |
+
Args:
|
132 |
+
img (numpy.ndarray): Images to be flipped with a probability flip_ratio
|
133 |
+
Returns:
|
134 |
+
numpy.ndarray: Flipped image.
|
135 |
+
"""
|
136 |
+
t, h, w = frames.shape
|
137 |
+
if random.random() < self.flip_ratio:
|
138 |
+
for index in range(t):
|
139 |
+
frames[index] = cv2.flip(frames[index], 1)
|
140 |
+
return frames
|
141 |
+
|
142 |
+
def compute_mask_indices(
|
143 |
+
shape: Tuple[int, int],
|
144 |
+
padding_mask: Optional[torch.Tensor],
|
145 |
+
mask_prob: float,
|
146 |
+
mask_length: int,
|
147 |
+
mask_type: str = "static",
|
148 |
+
mask_other: float = 0.0,
|
149 |
+
min_masks: int = 0,
|
150 |
+
no_overlap: bool = False,
|
151 |
+
min_space: int = 0,
|
152 |
+
) -> np.ndarray:
|
153 |
+
"""
|
154 |
+
Computes random mask spans for a given shape
|
155 |
+
Args:
|
156 |
+
shape: the shape for which to compute masks.
|
157 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
158 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
159 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
160 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
161 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
162 |
+
mask_type: how to compute mask lengths
|
163 |
+
static = fixed size
|
164 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
165 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
166 |
+
poisson = sample from poisson distribution with lambda = mask length
|
167 |
+
min_masks: minimum number of masked spans
|
168 |
+
no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
169 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
170 |
+
"""
|
171 |
+
|
172 |
+
bsz, all_sz = shape
|
173 |
+
mask = np.full((bsz, all_sz), False)
|
174 |
+
|
175 |
+
all_num_mask = int(
|
176 |
+
# add a random number for probabilistic rounding
|
177 |
+
mask_prob * all_sz / float(mask_length)
|
178 |
+
+ np.random.rand()
|
179 |
+
)
|
180 |
+
|
181 |
+
all_num_mask = max(min_masks, all_num_mask)
|
182 |
+
|
183 |
+
mask_idcs = []
|
184 |
+
for i in range(bsz):
|
185 |
+
if padding_mask is not None:
|
186 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
187 |
+
num_mask = int(
|
188 |
+
# add a random number for probabilistic rounding
|
189 |
+
mask_prob * sz / float(mask_length)
|
190 |
+
+ np.random.rand()
|
191 |
+
)
|
192 |
+
num_mask = max(min_masks, num_mask)
|
193 |
+
else:
|
194 |
+
sz = all_sz
|
195 |
+
num_mask = all_num_mask
|
196 |
+
|
197 |
+
if mask_type == "static":
|
198 |
+
lengths = np.full(num_mask, mask_length)
|
199 |
+
elif mask_type == "uniform":
|
200 |
+
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
201 |
+
elif mask_type == "normal":
|
202 |
+
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
203 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
204 |
+
elif mask_type == "poisson":
|
205 |
+
lengths = np.random.poisson(mask_length, size=num_mask)
|
206 |
+
lengths = [int(round(x)) for x in lengths]
|
207 |
+
else:
|
208 |
+
raise Exception("unknown mask selection " + mask_type)
|
209 |
+
|
210 |
+
if sum(lengths) == 0:
|
211 |
+
lengths[0] = min(mask_length, sz - 1)
|
212 |
+
|
213 |
+
if no_overlap:
|
214 |
+
mask_idc = []
|
215 |
+
|
216 |
+
def arrange(s, e, length, keep_length):
|
217 |
+
span_start = np.random.randint(s, e - length)
|
218 |
+
mask_idc.extend(span_start + i for i in range(length))
|
219 |
+
|
220 |
+
new_parts = []
|
221 |
+
if span_start - s - min_space >= keep_length:
|
222 |
+
new_parts.append((s, span_start - min_space + 1))
|
223 |
+
if e - span_start - keep_length - min_space > keep_length:
|
224 |
+
new_parts.append((span_start + length + min_space, e))
|
225 |
+
return new_parts
|
226 |
+
|
227 |
+
parts = [(0, sz)]
|
228 |
+
min_length = min(lengths)
|
229 |
+
for length in sorted(lengths, reverse=True):
|
230 |
+
lens = np.fromiter(
|
231 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
232 |
+
int,  # builtin int; np.int was removed in NumPy 1.24
|
233 |
+
)
|
234 |
+
l_sum = np.sum(lens)
|
235 |
+
if l_sum == 0:
|
236 |
+
break
|
237 |
+
probs = lens / np.sum(lens)
|
238 |
+
c = np.random.choice(len(parts), p=probs)
|
239 |
+
s, e = parts.pop(c)
|
240 |
+
parts.extend(arrange(s, e, length, min_length))
|
241 |
+
mask_idc = np.asarray(mask_idc)
|
242 |
+
else:
|
243 |
+
min_len = min(lengths)
|
244 |
+
if sz - min_len <= num_mask:
|
245 |
+
min_len = sz - num_mask - 1
|
246 |
+
|
247 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
248 |
+
|
249 |
+
mask_idc = np.asarray(
|
250 |
+
[
|
251 |
+
mask_idc[j] + offset
|
252 |
+
for j in range(len(mask_idc))
|
253 |
+
for offset in range(lengths[j])
|
254 |
+
]
|
255 |
+
)
|
256 |
+
|
257 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
258 |
+
|
259 |
+
min_len = min([len(m) for m in mask_idcs])
|
260 |
+
batch_indexes, starts, ends = [], [], []
|
261 |
+
for i, mask_idc in enumerate(mask_idcs):
|
262 |
+
if len(mask_idc) > min_len:
|
263 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
264 |
+
mask[i, mask_idc] = True
|
265 |
+
vals, run_starts, run_lengths = find_runs(mask[i])
|
266 |
+
start_indices, lengths = run_starts[vals == True], run_lengths[vals == True]
|
267 |
+
starts.append(start_indices)
|
268 |
+
ends.append(start_indices+lengths)
|
269 |
+
batch_indexes.append(np.zeros([len(start_indices)])+i)
|
270 |
+
return mask, np.concatenate(starts).astype(np.int64), np.concatenate(ends).astype(np.int64), np.concatenate(batch_indexes).astype(np.int64)
|
271 |
+
|
272 |
+
def find_runs(x):
|
273 |
+
"""Find runs of consecutive items in an array."""
|
274 |
+
|
275 |
+
# ensure array
|
276 |
+
x = np.asanyarray(x)
|
277 |
+
if x.ndim != 1:
|
278 |
+
raise ValueError('only 1D array supported')
|
279 |
+
n = x.shape[0]
|
280 |
+
|
281 |
+
# handle empty array
|
282 |
+
if n == 0:
|
283 |
+
return np.array([]), np.array([]), np.array([])
|
284 |
+
|
285 |
+
else:
|
286 |
+
# find run starts
|
287 |
+
loc_run_start = np.empty(n, dtype=bool)
|
288 |
+
loc_run_start[0] = True
|
289 |
+
np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
|
290 |
+
run_starts = np.nonzero(loc_run_start)[0]
|
291 |
+
|
292 |
+
# find run values
|
293 |
+
run_values = x[loc_run_start]
|
294 |
+
|
295 |
+
# find run lengths
|
296 |
+
run_lengths = np.diff(np.append(run_starts, n))
|
297 |
+
|
298 |
+
return run_values, run_starts, run_lengths
|
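Illustrative sketch of calling the compute_mask_indices variant defined above; the batch shape and masking hyper-parameters are example values. Unlike the stock fairseq helper, this variant also returns the start/end offsets and batch index of every masked span (recovered via find_runs).

import torch

# Example inputs: batch of 2 sequences, 100 frames each, nothing padded.
B, T = 2, 100
padding_mask = torch.zeros(B, T, dtype=torch.bool)

mask, starts, ends, batch_idx = compute_mask_indices(
    (B, T), padding_mask, mask_prob=0.65, mask_length=10, mask_type="static"
)
print(mask.shape)                          # (2, 100) boolean span mask
print(list(zip(batch_idx, starts, ends)))  # contiguous masked spans per item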
slam_llm/models/encoder.py
ADDED
@@ -0,0 +1,158 @@
1 |
+
import types
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
class WhisperWrappedEncoder:
|
8 |
+
|
9 |
+
@classmethod
|
10 |
+
def load(cls, model_config):
|
11 |
+
|
12 |
+
def extract_variable_length_features(self, x: torch.Tensor):
|
13 |
+
"""
|
14 |
+
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
15 |
+
the mel spectrogram of the audio
|
16 |
+
"""
|
17 |
+
x = F.gelu(self.conv1(x))
|
18 |
+
x = F.gelu(self.conv2(x))
|
19 |
+
x = x.permute(0, 2, 1)
|
20 |
+
|
21 |
+
# assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
22 |
+
# x = (x + self.positional_embedding).to(x.dtype)
|
23 |
+
x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
|
24 |
+
|
25 |
+
for block in self.blocks:
|
26 |
+
x = block(x)
|
27 |
+
|
28 |
+
x = self.ln_post(x)
|
29 |
+
return x
|
30 |
+
|
31 |
+
import whisper
|
32 |
+
encoder = whisper.load_model(name=model_config.encoder_path, device='cpu').encoder
|
33 |
+
encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder)
|
34 |
+
return encoder
|
35 |
+
|
36 |
+
|
37 |
+
class BEATsEncoder:
|
38 |
+
|
39 |
+
@classmethod
|
40 |
+
def load(cls, model_config):
|
41 |
+
from .BEATs.BEATs import BEATs, BEATsConfig
|
42 |
+
checkpoint = torch.load(model_config.encoder_path)
|
43 |
+
cfg = BEATsConfig(checkpoint['cfg'])
|
44 |
+
BEATs_model = BEATs(cfg)
|
45 |
+
BEATs_model.load_state_dict(checkpoint['model'])
|
46 |
+
|
47 |
+
return BEATs_model
|
48 |
+
|
49 |
+
|
50 |
+
@dataclass
|
51 |
+
class UserDirModule:
|
52 |
+
user_dir: str
|
53 |
+
|
54 |
+
class EATEncoder:
|
55 |
+
|
56 |
+
@classmethod
|
57 |
+
def load(cls, model_config):
|
58 |
+
import fairseq
|
59 |
+
model_path = UserDirModule(model_config.encoder_fairseq_dir)
|
60 |
+
fairseq.utils.import_user_module(model_path)
|
61 |
+
EATEncoder, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
|
62 |
+
EATEncoder = EATEncoder[0]
|
63 |
+
|
64 |
+
return EATEncoder
|
65 |
+
|
66 |
+
def extract_features(self, source, padding_mask):
|
67 |
+
return self.model.extract_features(source, padding_mask = padding_mask, mask=False, remove_extra_tokens = False)['x']
|
68 |
+
|
69 |
+
class SpatialASTEncoder:
|
70 |
+
@classmethod
|
71 |
+
def load(cls, model_config):
|
72 |
+
from functools import partial
|
73 |
+
from .SpatialAST import SpatialAST
|
74 |
+
binaural_encoder = SpatialAST.BinauralEncoder(
|
75 |
+
num_classes=355, drop_path_rate=0.1, num_cls_tokens=3,
|
76 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
77 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
|
78 |
+
)
|
79 |
+
|
80 |
+
checkpoint = torch.load(model_config.encoder_ckpt, map_location='cpu')
|
81 |
+
binaural_encoder.load_state_dict(checkpoint['model'], strict=False)
|
82 |
+
return binaural_encoder
|
83 |
+
|
84 |
+
class WavLMEncoder(nn.Module):
|
85 |
+
def __init__(self, config, model):
|
86 |
+
super().__init__()
|
87 |
+
self.config = config
|
88 |
+
self.model = model
|
89 |
+
|
90 |
+
@classmethod
|
91 |
+
def load(cls, model_config):
|
92 |
+
from .wavlm.WavLM import WavLM, WavLMConfig
|
93 |
+
checkpoint = torch.load(model_config.encoder_path)
|
94 |
+
cfg = WavLMConfig(checkpoint['cfg'])
|
95 |
+
WavLM_model = WavLM(cfg)
|
96 |
+
WavLM_model.load_state_dict(checkpoint['model'])
|
97 |
+
assert model_config.normalize == cfg.normalize, "normalize flag in config and model checkpoint do not match"
|
98 |
+
|
99 |
+
return cls(cfg, WavLM_model)
|
100 |
+
|
101 |
+
def extract_features(self, source, padding_mask):
|
102 |
+
return self.model.extract_features(source, padding_mask)[0]
|
103 |
+
|
104 |
+
class AVHubertEncoder:
|
105 |
+
|
106 |
+
@classmethod
|
107 |
+
def load(cls, model_config):
|
108 |
+
import fairseq
|
109 |
+
from .avhubert import hubert_pretraining, hubert, hubert_asr
|
110 |
+
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
|
111 |
+
model = models[0]
|
112 |
+
return model
|
113 |
+
|
114 |
+
class HubertEncoder:
|
115 |
+
|
116 |
+
@classmethod
|
117 |
+
def load(cls, model_config):
|
118 |
+
import fairseq
|
119 |
+
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
|
120 |
+
model = models[0]
|
121 |
+
if model_config.encoder_type == "pretrain":
|
122 |
+
pass
|
123 |
+
elif model_config.encoder_type == "finetune":
|
124 |
+
model.w2v_encoder.proj = None
|
125 |
+
model.w2v_encoder.apply_mask = False
|
126 |
+
else:
|
127 |
+
assert model_config.encoder_type in ["pretrain", "finetune"], "input_type must be one of [pretrain, finetune]"
|
128 |
+
return model
|
129 |
+
|
130 |
+
|
131 |
+
class HfTextEncoder:
|
132 |
+
|
133 |
+
@classmethod
|
134 |
+
def load(cls, model_config):
|
135 |
+
from transformers import AutoModel
|
136 |
+
model = AutoModel.from_pretrained(model_config.encoder_path)
|
137 |
+
return model
|
138 |
+
|
139 |
+
class MusicFMEncoder(nn.Module):
|
140 |
+
def __init__(self, config, model):
|
141 |
+
super().__init__()
|
142 |
+
self.config = config
|
143 |
+
self.model = model
|
144 |
+
|
145 |
+
@classmethod
|
146 |
+
def load(cls, model_config):
|
147 |
+
from .musicfm.model.musicfm_25hz import MusicFM25Hz
|
148 |
+
model = MusicFM25Hz(
|
149 |
+
stat_path = model_config.encoder_stat_path,
|
150 |
+
model_path = model_config.encoder_path,
|
151 |
+
w2v2_config_path = model_config.get('encoder_config_path', "facebook/wav2vec2-conformer-rope-large-960h-ft")
|
152 |
+
)
|
153 |
+
return cls(model_config, model)
|
154 |
+
|
155 |
+
def extract_features(self, source, padding_mask=None):
|
156 |
+
_, hidden_states = self.model.get_predictions(source)
|
157 |
+
out = hidden_states[self.config.encoder_layer_idx]
|
158 |
+
return out
|
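Illustrative sketch of how these loader classes are meant to be called: each `load` classmethod only reads a few attributes off a config object, so a SimpleNamespace with placeholder values is enough to show the shape of the call (the checkpoint path below is a placeholder, and `normalize` must match the flag stored in the checkpoint's cfg).

import torch
from types import SimpleNamespace

# Placeholder config; encoder_path must point to a real WavLM checkpoint that
# stores 'cfg' and 'model', and normalize must agree with cfg.normalize.
model_config = SimpleNamespace(encoder_path="/path/to/WavLM-Large.pt", normalize=True)
encoder = WavLMEncoder.load(model_config)

wav = torch.randn(1, 16000)                              # 1 s of 16 kHz audio
padding_mask = torch.zeros(1, 16000, dtype=torch.bool)   # no padded samples
features = encoder.extract_features(wav, padding_mask)   # frame-level features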
slam_llm/models/musicfm/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
|
2 |
+
|
slam_llm/models/musicfm/model/musicfm_25hz.py
ADDED
@@ -0,0 +1,253 @@
# MIT License
#
# Copyright 2023 ByteDance Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

import json
import random
import torch
from torch import nn
from einops import rearrange

from ..modules.random_quantizer import RandomProjectionQuantizer
from ..modules.features import MelSTFT
from ..modules.conv import Conv2dSubsampling


class MusicFM25Hz(nn.Module):
    """
    MusicFM

    Input: 128-band mel spectrogram
    Frontend: 2-layer Residual convolution
    Backend: 12-layer Conformer
    Quantizer: a codebook for mel spectrogram
    """

    def __init__(
        self,
        num_codebooks=1,
        codebook_dim=16,
        codebook_size=4096,
        features=["melspec_2048"],
        hop_length=240,
        n_mels=128,
        conv_dim=512,
        encoder_dim=1024,
        encoder_depth=12,
        mask_hop=0.4,
        mask_prob=0.6,
        is_flash=False,
        stat_path="./data/fma_stats.json",
        model_path="./data/pretrained_fma.pt",
        w2v2_config_path="facebook/wav2vec2-conformer-rope-large-960h-ft",
    ):
        super(MusicFM25Hz, self).__init__()

        # global variables
        self.hop_length = hop_length
        self.mask_hop = mask_hop
        self.mask_prob = mask_prob
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.features = features

        # load feature mean / std stats
        with open(stat_path, "r") as f:
            self.stat = json.load(f)

        # feature extractor
        self.preprocessor_melspec_2048 = MelSTFT(
            n_fft=2048, hop_length=hop_length, is_db=True
        )

        # random quantizer
        seed = 142
        for feature in self.features:
            for i in range(num_codebooks):
                setattr(
                    self,
                    f"quantizer_{feature}_{i}",
                    RandomProjectionQuantizer(
                        n_mels * 4, codebook_dim, codebook_size, seed=seed + i
                    ),
                )

        # two residual convolution layers + one projection layer
        self.conv = Conv2dSubsampling(
            1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels
        )

        # Conformer
        if is_flash:
            from modules.flash_conformer import (
                Wav2Vec2ConformerEncoder,
                Wav2Vec2ConformerConfig,
            )
        else:
            from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
                Wav2Vec2ConformerEncoder,
                Wav2Vec2ConformerConfig,
            )
        config = Wav2Vec2ConformerConfig.from_pretrained(
            w2v2_config_path
        )
        config.num_hidden_layers = encoder_depth
        config.hidden_size = encoder_dim

        self.conformer = Wav2Vec2ConformerEncoder(config)

        # projection
        self.linear = nn.Linear(encoder_dim, codebook_size)

        # loss function
        self.loss = nn.CrossEntropyLoss()

        # cls token (used for sequence classification)
        random.seed(seed)
        self.cls_token = nn.Parameter(torch.randn(encoder_dim))

        # load model
        if model_path:
            S = torch.load(model_path)["state_dict"]
            SS = {k[6:]: v for k, v in S.items()}
            self.load_state_dict(SS, strict=True)

    def masking(self, x):
        """random masking of 400ms with given probability"""
        mx = x.clone()
        b, t = mx.shape
        len_masking_raw = int(24000 * self.mask_hop)
        len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)

        # get random mask indices
        start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
        time_domain_masked_indices = torch.nonzero(
            start_indices.repeat_interleave(len_masking_raw, dim=1)
        )
        token_domain_masked_indices = torch.nonzero(
            start_indices.repeat_interleave(len_masking_token, dim=1)
        )

        # mask with random values
        masking_noise = (
            torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
        )  # 0 mean 0.1 std
        mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)

        return mx, token_domain_masked_indices

    @torch.no_grad()
    def preprocessing(self, x, features):
        """extract classic audio features"""
        # check precision
        if x.dtype == torch.float16:
            precision = 16
        else:
            precision = 32

        out = {}
        for key in features:
            layer = getattr(self, "preprocessor_%s" % key)
            out[key] = layer.float()(x.float())[..., :-1]
            if precision == 16:
                out[key] = out[key].half()
        return out

    def encoder(self, x):
        """2-layer conv + w2v-conformer"""
        x = self.conv(x)
        out = self.conformer(x, output_hidden_states=True)
        hidden_emb = out["hidden_states"]
        last_emb = out["last_hidden_state"]
        logits = self.linear(last_emb)
        logits = {
            key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
            for i, key in enumerate(self.features)
        }
        return logits, hidden_emb

    @torch.no_grad()
    def normalize(self, x):
        """normalize the input audio to have zero mean unit variance"""
        for key in x.keys():
            x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
        return x

    @torch.no_grad()
    def rearrange(self, x):
        """rearrange the batch to flatten every 4 steps"""
        for key in x.keys():
            if key == "chromagram":
                x[key] = rearrange(x[key], "b f t -> b t f")
            else:
                x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)
        return x

    @torch.no_grad()
    def tokenize(self, x):
        out = {}
        for key in x.keys():
            layer = getattr(self, "quantizer_%s" % key)
            out[key] = layer(x[key])
        return out

    def get_targets(self, x):
        x = self.preprocessing(x, features=self.features)
        x = self.normalize(x)
        x = self.rearrange(x)
        target_tokens = self.tokenize(x)
        return target_tokens

    def get_predictions(self, x):
        # preprocessing
        x = self.preprocessing(x, features=["melspec_2048"])
        x = self.normalize(x)

        # encoding
        logits, hidden_emb = self.encoder(x["melspec_2048"])

        return logits, hidden_emb

    def get_latent(self, x, layer_ix=12):
        _, hidden_states = self.get_predictions(x)
        emb = hidden_states[layer_ix]
        return emb

    def get_loss(self, logits, target_tokens, masked_indices):
        losses = {}
        accuracies = {}
        for key in logits.keys():
            masked_logits = logits[key][tuple(masked_indices.t())]
            masked_tokens = target_tokens[key][tuple(masked_indices.t())]
            losses[key] = self.loss(masked_logits, masked_tokens)
            accuracies[key] = (
                torch.sum(masked_logits.argmax(-1) == masked_tokens)
                / masked_tokens.numel()
            )
        return losses, accuracies

    def forward(self, x):
        # get target feature tokens
        target_tokens = self.get_targets(x)

        # masking
        x, masked_indices = self.masking(x)

        # forward
        logits, hidden_emb = self.get_predictions(x)

        # get loss
        losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)

        return logits, hidden_emb, losses, accuracies
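Usage note (not part of the uploaded files): a short sketch of pulling 25 Hz latents out of the model defined above. The stats/checkpoint paths are placeholders that must point at real MusicFM files (the checkpoint is expected to carry a `state_dict` whose keys have a 6-character prefix, as stripped in `__init__`); the 10-second mono input at 24 kHz is an arbitrary example.

import torch
from slam_llm.models.musicfm.model.musicfm_25hz import MusicFM25Hz

model = MusicFM25Hz(
    is_flash=False,                          # use the stock HF Conformer encoder
    stat_path="./data/fma_stats.json",       # placeholder: per-feature mean/std stats
    model_path="./data/pretrained_fma.pt",   # placeholder: pretrained MusicFM checkpoint
)
model.eval()

wav = torch.randn(1, 24000 * 10)             # 10 s of 24 kHz audio (random stand-in)
with torch.no_grad():
    emb = model.get_latent(wav, layer_ix=12) # last Conformer layer -> roughly (1, 250, 1024) at 25 Hz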
slam_llm/models/musicfm/modules/__init__.py
ADDED
@@ -0,0 +1,2 @@
slam_llm/models/musicfm/modules/conv.py
ADDED
@@ -0,0 +1,82 @@
# MIT License
#
# Copyright 2023 ByteDance Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

from torch import nn
from einops import rearrange


class Res2dModule(nn.Module):
    def __init__(self, idim, odim, stride=(2, 2)):
        super(Res2dModule, self).__init__()
        self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(odim)
        self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(odim)
        self.relu = nn.ReLU()

        # residual
        self.diff = False
        if (idim != odim) or (stride[0] > 1):
            self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
            self.bn3 = nn.BatchNorm2d(odim)
            self.diff = True

    def forward(self, x):
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        if self.diff:
            x = self.bn3(self.conv3(x))
        out = x + out
        out = self.relu(out)
        return out


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        hdim (int): Hidden dimension.
        odim (int): Output dimension.
        strides (list): Sizes of strides.
        n_bands (int): Number of frequency bands.
    """

    def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
        """Construct an Conv2dSubsampling object."""
        super(Conv2dSubsampling, self).__init__()

        self.conv = nn.Sequential(
            Res2dModule(idim, hdim, (2, strides[0])),
            Res2dModule(hdim, hdim, (2, strides[1])),
        )
        self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)

    def forward(self, x):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, idim, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
        """

        if x.dim() == 3:
            x = x.unsqueeze(1)  # (b, c, f, t)
        x = self.conv(x)
        x = rearrange(x, "b c f t -> b t (c f)")
        x = self.linear(x)
        return x
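For reference, a quick shape check of the frontend above (arbitrary batch and frame count, repo root assumed on PYTHONPATH): a 128-band mel input of length T should come out as T // 4 frames of `odim`-dimensional features.

import torch
from slam_llm.models.musicfm.modules.conv import Conv2dSubsampling

frontend = Conv2dSubsampling(idim=1, hdim=512, odim=1024, strides=[2, 2], n_bands=128)
mel = torch.randn(2, 128, 1000)   # (batch, n_mels, time) -- arbitrary sizes
out = frontend(mel)               # input is unsqueezed to (b, 1, f, t) internally
print(out.shape)                  # torch.Size([2, 250, 1024]) -> time reduced by 4x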
slam_llm/models/musicfm/modules/features.py
ADDED
@@ -0,0 +1,45 @@
# MIT License
#
# Copyright 2023 ByteDance Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

import torchaudio
from torch import nn


class MelSTFT(nn.Module):
    def __init__(
        self,
        sample_rate=24000,
        n_fft=2048,
        hop_length=240,
        n_mels=128,
        is_db=False,
    ):
        super(MelSTFT, self).__init__()

        # spectrogram
        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
        )

        # amplitude to decibel
        self.is_db = is_db
        if is_db:
            self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, waveform):
        if self.is_db:
            return self.amplitude_to_db(self.mel_stft(waveform))
        else:
            return self.mel_stft(waveform)
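For reference, a small sketch of the extractor above at MusicFM's settings (240-sample hop at 24 kHz, i.e. about 100 mel frames per second; the 1-second random input is arbitrary):

import torch
from slam_llm.models.musicfm.modules.features import MelSTFT

mel = MelSTFT(sample_rate=24000, n_fft=2048, hop_length=240, n_mels=128, is_db=True)
wav = torch.randn(1, 24000)   # 1 s of audio (arbitrary stand-in)
spec = mel(wav)               # dB-scaled mel spectrogram, (1, 128, 101) with centered framing
print(spec.shape)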
slam_llm/models/musicfm/modules/flash_conformer.py
ADDED
@@ -0,0 +1,2114 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch Wav2Vec2-Conformer model."""
|
16 |
+
|
17 |
+
import math
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import Optional, Tuple, Union
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
import torch.utils.checkpoint
|
24 |
+
from torch import nn
|
25 |
+
from torch.nn import CrossEntropyLoss
|
26 |
+
from torch.nn import functional as F
|
27 |
+
|
28 |
+
from transformers.activations import ACT2FN
|
29 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
30 |
+
from transformers.modeling_outputs import (
|
31 |
+
BaseModelOutput,
|
32 |
+
CausalLMOutput,
|
33 |
+
SequenceClassifierOutput,
|
34 |
+
TokenClassifierOutput,
|
35 |
+
Wav2Vec2BaseModelOutput,
|
36 |
+
XVectorOutput,
|
37 |
+
)
|
38 |
+
from transformers.modeling_utils import PreTrainedModel
|
39 |
+
from transformers.utils import (
|
40 |
+
ModelOutput,
|
41 |
+
add_code_sample_docstrings,
|
42 |
+
add_start_docstrings,
|
43 |
+
add_start_docstrings_to_model_forward,
|
44 |
+
logging,
|
45 |
+
replace_return_docstrings,
|
46 |
+
)
|
47 |
+
from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
|
48 |
+
|
49 |
+
|
50 |
+
logger = logging.get_logger(__name__)
|
51 |
+
|
52 |
+
|
53 |
+
_HIDDEN_STATES_START_POSITION = 2
|
54 |
+
|
55 |
+
# General docstring
|
56 |
+
_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
|
57 |
+
|
58 |
+
# Base docstring
|
59 |
+
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
|
60 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
|
61 |
+
|
62 |
+
# CTC docstring
|
63 |
+
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
64 |
+
_CTC_EXPECTED_LOSS = 64.21
|
65 |
+
|
66 |
+
|
67 |
+
WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
68 |
+
"facebook/wav2vec2-conformer-rel-pos-large",
|
69 |
+
# See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
|
70 |
+
]
|
71 |
+
|
72 |
+
|
73 |
+
@dataclass
|
74 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
|
75 |
+
class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
|
76 |
+
"""
|
77 |
+
Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
81 |
+
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
|
82 |
+
paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
|
83 |
+
projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
|
84 |
+
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
|
85 |
+
projected quantized states.
|
86 |
+
projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
|
87 |
+
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
|
88 |
+
target vectors for contrastive loss.
|
89 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
90 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
91 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
92 |
+
|
93 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
94 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
95 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
96 |
+
sequence_length)`.
|
97 |
+
|
98 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
99 |
+
heads.
|
100 |
+
contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
101 |
+
The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
|
102 |
+
diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
103 |
+
The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
|
104 |
+
"""
|
105 |
+
|
106 |
+
loss: Optional[torch.FloatTensor] = None
|
107 |
+
projected_states: torch.FloatTensor = None
|
108 |
+
projected_quantized_states: torch.FloatTensor = None
|
109 |
+
codevector_perplexity: torch.FloatTensor = None
|
110 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
111 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
112 |
+
contrastive_loss: Optional[torch.FloatTensor] = None
|
113 |
+
diversity_loss: Optional[torch.FloatTensor] = None
|
114 |
+
|
115 |
+
|
116 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
117 |
+
def _compute_mask_indices(
|
118 |
+
shape: Tuple[int, int],
|
119 |
+
mask_prob: float,
|
120 |
+
mask_length: int,
|
121 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
122 |
+
min_masks: int = 0,
|
123 |
+
) -> np.ndarray:
|
124 |
+
"""
|
125 |
+
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
126 |
+
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
127 |
+
CPU as part of the preprocessing during training.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
131 |
+
the first element is the batch size and the second element is the length of the axis to span.
|
132 |
+
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
133 |
+
independently generated mask spans of length `mask_length` is computed by
|
134 |
+
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
135 |
+
actual percentage will be smaller.
|
136 |
+
mask_length: size of the mask
|
137 |
+
min_masks: minimum number of masked spans
|
138 |
+
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
139 |
+
each batch dimension.
|
140 |
+
"""
|
141 |
+
batch_size, sequence_length = shape
|
142 |
+
|
143 |
+
if mask_length < 1:
|
144 |
+
raise ValueError("`mask_length` has to be bigger than 0.")
|
145 |
+
|
146 |
+
if mask_length > sequence_length:
|
147 |
+
raise ValueError(
|
148 |
+
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
149 |
+
f" and `sequence_length`: {sequence_length}`"
|
150 |
+
)
|
151 |
+
|
152 |
+
# epsilon is used for probabilistic rounding
|
153 |
+
epsilon = np.random.rand(1).item()
|
154 |
+
|
155 |
+
def compute_num_masked_span(input_length):
|
156 |
+
"""Given input length, compute how many spans should be masked"""
|
157 |
+
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
158 |
+
num_masked_span = max(num_masked_span, min_masks)
|
159 |
+
|
160 |
+
# make sure num masked span <= sequence_length
|
161 |
+
if num_masked_span * mask_length > sequence_length:
|
162 |
+
num_masked_span = sequence_length // mask_length
|
163 |
+
|
164 |
+
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
165 |
+
if input_length - (mask_length - 1) < num_masked_span:
|
166 |
+
num_masked_span = max(input_length - (mask_length - 1), 0)
|
167 |
+
|
168 |
+
return num_masked_span
|
169 |
+
|
170 |
+
# compute number of masked spans in batch
|
171 |
+
input_lengths = (
|
172 |
+
attention_mask.sum(-1).detach().tolist()
|
173 |
+
if attention_mask is not None
|
174 |
+
else [sequence_length for _ in range(batch_size)]
|
175 |
+
)
|
176 |
+
|
177 |
+
# SpecAugment mask to fill
|
178 |
+
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
179 |
+
spec_aug_mask_idxs = []
|
180 |
+
|
181 |
+
max_num_masked_span = compute_num_masked_span(sequence_length)
|
182 |
+
|
183 |
+
if max_num_masked_span == 0:
|
184 |
+
return spec_aug_mask
|
185 |
+
|
186 |
+
for input_length in input_lengths:
|
187 |
+
# compute num of masked spans for this input
|
188 |
+
num_masked_span = compute_num_masked_span(input_length)
|
189 |
+
|
190 |
+
# get random indices to mask
|
191 |
+
spec_aug_mask_idx = np.random.choice(
|
192 |
+
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
193 |
+
)
|
194 |
+
|
195 |
+
# pick first sampled index that will serve as a dummy index to pad vector
|
196 |
+
# to ensure same dimension for all batches due to probabilistic rounding
|
197 |
+
# Picking first sample just pads those vectors twice.
|
198 |
+
if len(spec_aug_mask_idx) == 0:
|
199 |
+
# this case can only happen if `input_length` is strictly smaller then
|
200 |
+
# `sequence_length` in which case the last token has to be a padding
|
201 |
+
# token which we can use as a dummy mask id
|
202 |
+
dummy_mask_idx = sequence_length - 1
|
203 |
+
else:
|
204 |
+
dummy_mask_idx = spec_aug_mask_idx[0]
|
205 |
+
|
206 |
+
spec_aug_mask_idx = np.concatenate(
|
207 |
+
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
208 |
+
)
|
209 |
+
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
210 |
+
|
211 |
+
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
212 |
+
|
213 |
+
# expand masked indices to masked spans
|
214 |
+
spec_aug_mask_idxs = np.broadcast_to(
|
215 |
+
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
216 |
+
)
|
217 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
218 |
+
|
219 |
+
# add offset to the starting indexes so that indexes now create a span
|
220 |
+
offsets = np.arange(mask_length)[None, None, :]
|
221 |
+
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
222 |
+
batch_size, max_num_masked_span * mask_length
|
223 |
+
)
|
224 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
225 |
+
|
226 |
+
# ensure that we cannot have indices larger than sequence_length
|
227 |
+
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
228 |
+
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
229 |
+
|
230 |
+
# scatter indices to mask
|
231 |
+
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
232 |
+
|
233 |
+
return spec_aug_mask
|
234 |
+
|
235 |
+
|
236 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
|
237 |
+
def _sample_negative_indices(
|
238 |
+
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
|
239 |
+
):
|
240 |
+
"""
|
241 |
+
Sample `num_negatives` vectors from feature vectors.
|
242 |
+
"""
|
243 |
+
batch_size, sequence_length = features_shape
|
244 |
+
|
245 |
+
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
246 |
+
sequence_length_range = np.arange(sequence_length)
|
247 |
+
|
248 |
+
# get `num_negatives` random vector indices from the same utterance
|
249 |
+
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
|
250 |
+
|
251 |
+
mask_time_indices = (
|
252 |
+
mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
|
253 |
+
)
|
254 |
+
|
255 |
+
for batch_idx in range(batch_size):
|
256 |
+
high = mask_time_indices[batch_idx].sum() - 1
|
257 |
+
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
|
258 |
+
|
259 |
+
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
|
260 |
+
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
|
261 |
+
# avoid sampling the same positive vector, but keep the distribution uniform
|
262 |
+
sampled_indices[sampled_indices >= feature_indices] += 1
|
263 |
+
|
264 |
+
# remap to actual indices
|
265 |
+
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
|
266 |
+
|
267 |
+
# correct for batch size
|
268 |
+
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
|
269 |
+
|
270 |
+
return sampled_negative_indices
|
271 |
+
|
272 |
+
|
273 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
274 |
+
class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
|
275 |
+
def __init__(self, config, layer_id=0):
|
276 |
+
super().__init__()
|
277 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
278 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
279 |
+
|
280 |
+
self.conv = nn.Conv1d(
|
281 |
+
self.in_conv_dim,
|
282 |
+
self.out_conv_dim,
|
283 |
+
kernel_size=config.conv_kernel[layer_id],
|
284 |
+
stride=config.conv_stride[layer_id],
|
285 |
+
bias=config.conv_bias,
|
286 |
+
)
|
287 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
288 |
+
|
289 |
+
def forward(self, hidden_states):
|
290 |
+
hidden_states = self.conv(hidden_states)
|
291 |
+
hidden_states = self.activation(hidden_states)
|
292 |
+
return hidden_states
|
293 |
+
|
294 |
+
|
295 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
296 |
+
class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
|
297 |
+
def __init__(self, config, layer_id=0):
|
298 |
+
super().__init__()
|
299 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
300 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
301 |
+
|
302 |
+
self.conv = nn.Conv1d(
|
303 |
+
self.in_conv_dim,
|
304 |
+
self.out_conv_dim,
|
305 |
+
kernel_size=config.conv_kernel[layer_id],
|
306 |
+
stride=config.conv_stride[layer_id],
|
307 |
+
bias=config.conv_bias,
|
308 |
+
)
|
309 |
+
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
310 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
311 |
+
|
312 |
+
def forward(self, hidden_states):
|
313 |
+
hidden_states = self.conv(hidden_states)
|
314 |
+
|
315 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
316 |
+
hidden_states = self.layer_norm(hidden_states)
|
317 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
318 |
+
|
319 |
+
hidden_states = self.activation(hidden_states)
|
320 |
+
return hidden_states
|
321 |
+
|
322 |
+
|
323 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
324 |
+
class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
|
325 |
+
def __init__(self, config, layer_id=0):
|
326 |
+
super().__init__()
|
327 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
328 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
329 |
+
|
330 |
+
self.conv = nn.Conv1d(
|
331 |
+
self.in_conv_dim,
|
332 |
+
self.out_conv_dim,
|
333 |
+
kernel_size=config.conv_kernel[layer_id],
|
334 |
+
stride=config.conv_stride[layer_id],
|
335 |
+
bias=config.conv_bias,
|
336 |
+
)
|
337 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
338 |
+
|
339 |
+
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
|
340 |
+
|
341 |
+
def forward(self, hidden_states):
|
342 |
+
hidden_states = self.conv(hidden_states)
|
343 |
+
hidden_states = self.layer_norm(hidden_states)
|
344 |
+
hidden_states = self.activation(hidden_states)
|
345 |
+
return hidden_states
|
346 |
+
|
347 |
+
|
348 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
|
349 |
+
class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
|
350 |
+
def __init__(self, config):
|
351 |
+
super().__init__()
|
352 |
+
self.conv = nn.Conv1d(
|
353 |
+
config.hidden_size,
|
354 |
+
config.hidden_size,
|
355 |
+
kernel_size=config.num_conv_pos_embeddings,
|
356 |
+
padding=config.num_conv_pos_embeddings // 2,
|
357 |
+
groups=config.num_conv_pos_embedding_groups,
|
358 |
+
)
|
359 |
+
|
360 |
+
if is_deepspeed_zero3_enabled():
|
361 |
+
import deepspeed
|
362 |
+
|
363 |
+
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
364 |
+
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
365 |
+
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
|
366 |
+
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
|
367 |
+
else:
|
368 |
+
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
369 |
+
|
370 |
+
self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
|
371 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
372 |
+
|
373 |
+
def forward(self, hidden_states):
|
374 |
+
hidden_states = hidden_states.transpose(1, 2)
|
375 |
+
|
376 |
+
hidden_states = self.conv(hidden_states)
|
377 |
+
hidden_states = self.padding(hidden_states)
|
378 |
+
hidden_states = self.activation(hidden_states)
|
379 |
+
|
380 |
+
hidden_states = hidden_states.transpose(1, 2)
|
381 |
+
return hidden_states
|
382 |
+
|
383 |
+
|
384 |
+
class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
|
385 |
+
"""Rotary positional embedding
|
386 |
+
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
|
387 |
+
"""
|
388 |
+
|
389 |
+
def __init__(self, config):
|
390 |
+
super().__init__()
|
391 |
+
dim = config.hidden_size // config.num_attention_heads
|
392 |
+
base = config.rotary_embedding_base
|
393 |
+
|
394 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
395 |
+
self.register_buffer("inv_freq", inv_freq)
|
396 |
+
self.cached_sequence_length = None
|
397 |
+
self.cached_rotary_positional_embedding = None
|
398 |
+
|
399 |
+
def forward(self, hidden_states):
|
400 |
+
sequence_length = hidden_states.shape[1]
|
401 |
+
|
402 |
+
if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
|
403 |
+
return self.cached_rotary_positional_embedding
|
404 |
+
|
405 |
+
self.cached_sequence_length = sequence_length
|
406 |
+
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
|
407 |
+
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
|
408 |
+
embeddings = torch.cat((freqs, freqs), dim=-1)
|
409 |
+
|
410 |
+
cos_embeddings = embeddings.cos()[:, None, None, :]
|
411 |
+
sin_embeddings = embeddings.sin()[:, None, None, :]
|
412 |
+
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
|
413 |
+
return self.cached_rotary_positional_embedding
|
414 |
+
|
415 |
+
|
416 |
+
class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
|
417 |
+
"""Relative positional encoding module."""
|
418 |
+
|
419 |
+
def __init__(self, config):
|
420 |
+
super().__init__()
|
421 |
+
self.max_len = config.max_source_positions
|
422 |
+
self.d_model = config.hidden_size
|
423 |
+
self.pe = None
|
424 |
+
self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
|
425 |
+
|
426 |
+
def extend_pe(self, x):
|
427 |
+
# Reset the positional encodings
|
428 |
+
if self.pe is not None:
|
429 |
+
# self.pe contains both positive and negative parts
|
430 |
+
# the length of self.pe is 2 * input_len - 1
|
431 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
432 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
433 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
434 |
+
return
|
435 |
+
# Suppose `i` is the position of query vector and `j` is the
|
436 |
+
# position of key vector. We use positive relative positions when keys
|
437 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
438 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
439 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
440 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
441 |
+
div_term = torch.exp(
|
442 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
|
443 |
+
)
|
444 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
445 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
446 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
447 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
448 |
+
|
449 |
+
# Reverse the order of positive indices and concat both positive and
|
450 |
+
# negative indices. This is used to support the shifting trick
|
451 |
+
# as in https://arxiv.org/abs/1901.02860
|
452 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
453 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
454 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
455 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
456 |
+
|
457 |
+
def forward(self, hidden_states: torch.Tensor):
|
458 |
+
self.extend_pe(hidden_states)
|
459 |
+
start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
|
460 |
+
end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
|
461 |
+
relative_position_embeddings = self.pe[:, start_idx:end_idx]
|
462 |
+
|
463 |
+
return relative_position_embeddings
|
464 |
+
|
465 |
+
|
466 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
|
467 |
+
class Wav2Vec2ConformerSamePadLayer(nn.Module):
|
468 |
+
def __init__(self, num_conv_pos_embeddings):
|
469 |
+
super().__init__()
|
470 |
+
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
|
471 |
+
|
472 |
+
def forward(self, hidden_states):
|
473 |
+
if self.num_pad_remove > 0:
|
474 |
+
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
|
475 |
+
return hidden_states
|
476 |
+
|
477 |
+
|
478 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
|
479 |
+
class Wav2Vec2ConformerFeatureEncoder(nn.Module):
|
480 |
+
"""Construct the features from raw audio waveform"""
|
481 |
+
|
482 |
+
def __init__(self, config):
|
483 |
+
super().__init__()
|
484 |
+
|
485 |
+
if config.feat_extract_norm == "group":
|
486 |
+
conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
|
487 |
+
Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
|
488 |
+
for i in range(config.num_feat_extract_layers - 1)
|
489 |
+
]
|
490 |
+
elif config.feat_extract_norm == "layer":
|
491 |
+
conv_layers = [
|
492 |
+
Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
|
493 |
+
]
|
494 |
+
else:
|
495 |
+
raise ValueError(
|
496 |
+
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
|
497 |
+
)
|
498 |
+
self.conv_layers = nn.ModuleList(conv_layers)
|
499 |
+
self.gradient_checkpointing = False
|
500 |
+
self._requires_grad = True
|
501 |
+
|
502 |
+
def _freeze_parameters(self):
|
503 |
+
for param in self.parameters():
|
504 |
+
param.requires_grad = False
|
505 |
+
self._requires_grad = False
|
506 |
+
|
507 |
+
def forward(self, input_values):
|
508 |
+
hidden_states = input_values[:, None]
|
509 |
+
|
510 |
+
# make sure hidden_states require grad for gradient_checkpointing
|
511 |
+
if self._requires_grad and self.training:
|
512 |
+
hidden_states.requires_grad = True
|
513 |
+
|
514 |
+
for conv_layer in self.conv_layers:
|
515 |
+
if self._requires_grad and self.gradient_checkpointing and self.training:
|
516 |
+
|
517 |
+
def create_custom_forward(module):
|
518 |
+
def custom_forward(*inputs):
|
519 |
+
return module(*inputs)
|
520 |
+
|
521 |
+
return custom_forward
|
522 |
+
|
523 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
524 |
+
create_custom_forward(conv_layer),
|
525 |
+
hidden_states,
|
526 |
+
)
|
527 |
+
else:
|
528 |
+
hidden_states = conv_layer(hidden_states)
|
529 |
+
|
530 |
+
return hidden_states
|
531 |
+
|
532 |
+
|
533 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
|
534 |
+
class Wav2Vec2ConformerFeatureProjection(nn.Module):
|
535 |
+
def __init__(self, config):
|
536 |
+
super().__init__()
|
537 |
+
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
|
538 |
+
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
539 |
+
self.dropout = nn.Dropout(config.feat_proj_dropout)
|
540 |
+
|
541 |
+
def forward(self, hidden_states):
|
542 |
+
# non-projected hidden states are needed for quantization
|
543 |
+
norm_hidden_states = self.layer_norm(hidden_states)
|
544 |
+
hidden_states = self.projection(norm_hidden_states)
|
545 |
+
hidden_states = self.dropout(hidden_states)
|
546 |
+
return hidden_states, norm_hidden_states
|
547 |
+
|
548 |
+
|
549 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
|
550 |
+
class Wav2Vec2ConformerFeedForward(nn.Module):
|
551 |
+
def __init__(self, config):
|
552 |
+
super().__init__()
|
553 |
+
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
|
554 |
+
|
555 |
+
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
556 |
+
if isinstance(config.hidden_act, str):
|
557 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
558 |
+
else:
|
559 |
+
self.intermediate_act_fn = config.hidden_act
|
560 |
+
|
561 |
+
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
562 |
+
self.output_dropout = nn.Dropout(config.hidden_dropout)
|
563 |
+
|
564 |
+
def forward(self, hidden_states):
|
565 |
+
hidden_states = self.intermediate_dense(hidden_states)
|
566 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
567 |
+
hidden_states = self.intermediate_dropout(hidden_states)
|
568 |
+
|
569 |
+
hidden_states = self.output_dense(hidden_states)
|
570 |
+
hidden_states = self.output_dropout(hidden_states)
|
571 |
+
return hidden_states
|
572 |
+
|
573 |
+
|
574 |
+
class Wav2Vec2ConformerConvolutionModule(nn.Module):
|
575 |
+
"""Convolution block used in the conformer block"""
|
576 |
+
|
577 |
+
def __init__(self, config):
|
578 |
+
super().__init__()
|
579 |
+
if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
|
580 |
+
raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
|
581 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size)
|
582 |
+
self.pointwise_conv1 = torch.nn.Conv1d(
|
583 |
+
config.hidden_size,
|
584 |
+
2 * config.hidden_size,
|
585 |
+
kernel_size=1,
|
586 |
+
stride=1,
|
587 |
+
padding=0,
|
588 |
+
bias=False,
|
589 |
+
)
|
590 |
+
self.glu = torch.nn.GLU(dim=1)
|
591 |
+
self.depthwise_conv = torch.nn.Conv1d(
|
592 |
+
config.hidden_size,
|
593 |
+
config.hidden_size,
|
594 |
+
config.conv_depthwise_kernel_size,
|
595 |
+
stride=1,
|
596 |
+
padding=(config.conv_depthwise_kernel_size - 1) // 2,
|
597 |
+
groups=config.hidden_size,
|
598 |
+
bias=False,
|
599 |
+
)
|
600 |
+
self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
|
601 |
+
self.activation = ACT2FN[config.hidden_act]
|
602 |
+
self.pointwise_conv2 = torch.nn.Conv1d(
|
603 |
+
config.hidden_size,
|
604 |
+
config.hidden_size,
|
605 |
+
kernel_size=1,
|
606 |
+
stride=1,
|
607 |
+
padding=0,
|
608 |
+
bias=False,
|
609 |
+
)
|
610 |
+
self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
|
611 |
+
|
612 |
+
def forward(self, hidden_states):
|
613 |
+
hidden_states = self.layer_norm(hidden_states)
|
614 |
+
# exchange the temporal dimension and the feature dimension
|
615 |
+
hidden_states = hidden_states.transpose(1, 2)
|
616 |
+
|
617 |
+
# GLU mechanism
|
618 |
+
# => (batch, 2*channel, dim)
|
619 |
+
hidden_states = self.pointwise_conv1(hidden_states)
|
620 |
+
# => (batch, channel, dim)
|
621 |
+
hidden_states = self.glu(hidden_states)
|
622 |
+
|
623 |
+
# 1D Depthwise Conv
|
624 |
+
hidden_states = self.depthwise_conv(hidden_states)
|
625 |
+
hidden_states = self.batch_norm(hidden_states)
|
626 |
+
hidden_states = self.activation(hidden_states)
|
627 |
+
|
628 |
+
hidden_states = self.pointwise_conv2(hidden_states)
|
629 |
+
hidden_states = self.dropout(hidden_states)
|
630 |
+
hidden_states = hidden_states.transpose(1, 2)
|
631 |
+
return hidden_states
|
632 |
+
|
633 |
+
|
634 |
+
class Wav2Vec2ConformerSelfAttention(nn.Module):
|
635 |
+
"""Construct an Wav2Vec2ConformerSelfAttention object.
|
636 |
+
Can be enhanced with rotary or relative position embeddings.
|
637 |
+
"""
|
638 |
+
|
639 |
+
def __init__(self, config):
|
640 |
+
super().__init__()
|
641 |
+
|
642 |
+
self.head_size = config.hidden_size // config.num_attention_heads
|
643 |
+
self.num_heads = config.num_attention_heads
|
644 |
+
self.position_embeddings_type = config.position_embeddings_type
|
645 |
+
|
646 |
+
self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
|
647 |
+
self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
|
648 |
+
self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
|
649 |
+
self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
|
650 |
+
|
651 |
+
self.dropout = nn.Dropout(p=config.attention_dropout)
|
652 |
+
self.dropout_p = config.attention_dropout
|
653 |
+
|
654 |
+
self.is_causal = config.is_causal
|
655 |
+
|
656 |
+
if self.position_embeddings_type == "relative":
|
657 |
+
# linear transformation for positional encoding
|
658 |
+
self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
659 |
+
# these two learnable bias are used in matrix c and matrix d
|
660 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
661 |
+
self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
|
662 |
+
self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
|
663 |
+
|
664 |
+
def forward(
|
665 |
+
self,
|
666 |
+
hidden_states: torch.Tensor,
|
667 |
+
attention_mask: Optional[torch.Tensor] = None,
|
668 |
+
relative_position_embeddings: Optional[torch.Tensor] = None,
|
669 |
+
output_attentions: bool = False,
|
670 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
671 |
+
# self-attention mechanism
|
672 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
673 |
+
|
674 |
+
# make sure query/key states can be != value states
|
675 |
+
query_key_states = hidden_states
|
676 |
+
value_states = hidden_states
|
677 |
+
|
678 |
+
if self.position_embeddings_type == "rotary":
|
679 |
+
if relative_position_embeddings is None:
|
680 |
+
raise ValueError(
|
681 |
+
"`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
|
682 |
+
)
|
683 |
+
query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
|
684 |
+
|
685 |
+
# project query_key_states and value_states
|
686 |
+
query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
|
687 |
+
key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
|
688 |
+
value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
|
689 |
+
|
690 |
+
# => (batch, head, time1, d_k)
|
691 |
+
query = query.transpose(1, 2)
|
692 |
+
key = key.transpose(1, 2)
|
693 |
+
value = value.transpose(1, 2)
|
694 |
+
|
695 |
+
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
|
696 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal)
|
697 |
+
probs = None
|
698 |
+
|
699 |
+
# # apply attention_mask if necessary
|
700 |
+
# if attention_mask is not None:
|
701 |
+
# scores = scores + attention_mask
|
702 |
+
|
703 |
+
# # => (batch, head, time1, time2)
|
704 |
+
# probs = torch.softmax(scores, dim=-1)
|
705 |
+
# probs = self.dropout(probs)
|
706 |
+
|
707 |
+
# # => (batch, head, time1, d_k)
|
708 |
+
# hidden_states = torch.matmul(probs, value)
|
709 |
+
|
710 |
+
# => (batch, time1, hidden_size)
|
711 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
|
712 |
+
hidden_states = self.linear_out(hidden_states)
|
713 |
+
|
714 |
+
return hidden_states, probs
|
715 |
+
|
716 |
+
def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
|
717 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
718 |
+
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
|
719 |
+
|
720 |
+
cos = relative_position_embeddings[0, :sequence_length, ...]
|
721 |
+
sin = relative_position_embeddings[1, :sequence_length, ...]
|
722 |
+
|
723 |
+
# rotate hidden_states with rotary embeddings
|
724 |
+
hidden_states = hidden_states.transpose(0, 1)
|
725 |
+
rotated_states_begin = hidden_states[..., : self.head_size // 2]
|
726 |
+
rotated_states_end = hidden_states[..., self.head_size // 2 :]
|
727 |
+
rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
|
728 |
+
hidden_states = (hidden_states * cos) + (rotated_states * sin)
|
729 |
+
hidden_states = hidden_states.transpose(0, 1)
|
730 |
+
|
731 |
+
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
|
732 |
+
|
733 |
+
return hidden_states
|
734 |
+
|
735 |
+
def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
|
736 |
+
# 1. project positional embeddings
|
737 |
+
# => (batch, head, 2*time1-1, d_k)
|
738 |
+
proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
|
739 |
+
proj_relative_position_embeddings = proj_relative_position_embeddings.view(
|
740 |
+
relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
|
741 |
+
)
|
742 |
+
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
|
743 |
+
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
|
744 |
+
|
745 |
+
# 2. Add bias to query
|
746 |
+
# => (batch, head, time1, d_k)
|
747 |
+
query = query.transpose(1, 2)
|
748 |
+
q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
|
749 |
+
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
|
750 |
+
|
751 |
+
# 3. attention score: first compute matrix a and matrix c
|
752 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
753 |
+
# => (batch, head, time1, time2)
|
754 |
+
scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
|
755 |
+
|
756 |
+
# 4. then compute matrix b and matrix d
|
757 |
+
# => (batch, head, time1, 2*time1-1)
|
758 |
+
scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
|
759 |
+
|
760 |
+
# 5. shift matrix b and matrix d
|
761 |
+
zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
|
762 |
+
scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
|
763 |
+
scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
|
764 |
+
scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
|
765 |
+
scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
|
766 |
+
scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
|
767 |
+
|
768 |
+
# 6. sum matrices
|
769 |
+
# => (batch, head, time1, time2)
|
770 |
+
scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
|
771 |
+
|
772 |
+
return scores
|
773 |
+
|
774 |
+
|
775 |
+
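# --- Editor's sketch (not part of the uploaded file) -------------------------
# The zero-pad / reshape "relative shift" in step 5 of
# `_apply_relative_embeddings` above is easiest to see on a toy tensor. The
# helper below is hypothetical and added purely for illustration; it assumes
# only the `torch` import already used in this module.
def _sketch_relative_shift(time1: int = 4) -> torch.Tensor:
    """Shift a (1, 1, time1, 2*time1-1) score block into a (1, 1, time1, time1) one."""
    scores_bd = torch.arange(time1 * (2 * time1 - 1), dtype=torch.float32)
    scores_bd = scores_bd.view(1, 1, time1, 2 * time1 - 1)

    # pad one zero column, fold rows and columns, drop the first row, then crop
    zero_pad = torch.zeros((*scores_bd.size()[:3], 1), dtype=scores_bd.dtype)
    padded = torch.cat([zero_pad, scores_bd], dim=-1)
    padded = padded.view(1, 1, scores_bd.shape[3] + 1, scores_bd.shape[2])
    shifted = padded[:, :, 1:].view_as(scores_bd)
    return shifted[:, :, :, : scores_bd.size(-1) // 2 + 1]
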
class Wav2Vec2ConformerEncoderLayer(nn.Module):
|
776 |
+
"""Conformer block based on https://arxiv.org/abs/2005.08100."""
|
777 |
+
|
778 |
+
def __init__(self, config):
|
779 |
+
super().__init__()
|
780 |
+
embed_dim = config.hidden_size
|
781 |
+
dropout = config.attention_dropout
|
782 |
+
|
783 |
+
# Feed-forward 1
|
784 |
+
self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
|
785 |
+
self.ffn1 = Wav2Vec2ConformerFeedForward(config)
|
786 |
+
|
787 |
+
# Self-Attention
|
788 |
+
self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
|
789 |
+
self.self_attn_dropout = torch.nn.Dropout(dropout)
|
790 |
+
self.self_attn = Wav2Vec2ConformerSelfAttention(config)
|
791 |
+
|
792 |
+
# Conformer Convolution
|
793 |
+
self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
|
794 |
+
|
795 |
+
# Feed-forward 2
|
796 |
+
self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
|
797 |
+
self.ffn2 = Wav2Vec2ConformerFeedForward(config)
|
798 |
+
self.final_layer_norm = nn.LayerNorm(embed_dim)
|
799 |
+
|
800 |
+
def forward(
|
801 |
+
self,
|
802 |
+
hidden_states,
|
803 |
+
attention_mask: Optional[torch.Tensor] = None,
|
804 |
+
relative_position_embeddings: Optional[torch.Tensor] = None,
|
805 |
+
output_attentions: bool = False,
|
806 |
+
):
|
807 |
+
hidden_states = hidden_states
|
808 |
+
|
809 |
+
# 1. Feed-Forward 1 layer
|
810 |
+
residual = hidden_states
|
811 |
+
hidden_states = self.ffn1_layer_norm(hidden_states)
|
812 |
+
hidden_states = self.ffn1(hidden_states)
|
813 |
+
hidden_states = hidden_states * 0.5 + residual
|
814 |
+
residual = hidden_states
|
815 |
+
|
816 |
+
# 2. Self-Attention layer
|
817 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
818 |
+
hidden_states, attn_weights = self.self_attn(
|
819 |
+
hidden_states=hidden_states,
|
820 |
+
attention_mask=attention_mask,
|
821 |
+
relative_position_embeddings=relative_position_embeddings,
|
822 |
+
output_attentions=output_attentions,
|
823 |
+
)
|
824 |
+
hidden_states = self.self_attn_dropout(hidden_states)
|
825 |
+
hidden_states = hidden_states + residual
|
826 |
+
|
827 |
+
# 3. Convolutional Layer
|
828 |
+
residual = hidden_states
|
829 |
+
hidden_states = self.conv_module(hidden_states)
|
830 |
+
hidden_states = residual + hidden_states
|
831 |
+
|
832 |
+
# 4. Feed-Forward 2 Layer
|
833 |
+
residual = hidden_states
|
834 |
+
hidden_states = self.ffn2_layer_norm(hidden_states)
|
835 |
+
hidden_states = self.ffn2(hidden_states)
|
836 |
+
hidden_states = hidden_states * 0.5 + residual
|
837 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
838 |
+
|
839 |
+
return hidden_states, attn_weights
|
840 |
+
|
841 |
+
|
842 |
+
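# --- Editor's sketch (not part of the uploaded file) -------------------------
# The encoder layer above follows the Conformer "macaron" layout: half-step
# feed-forward -> self-attention -> convolution module -> half-step
# feed-forward -> final LayerNorm, each sub-module wrapped in a residual
# connection (the real block also layer-normalizes before every sub-module,
# omitted here). The helper below is hypothetical and only illustrates the
# ordering and the 0.5 feed-forward scaling.
def _sketch_conformer_block(x, ffn1, attn, conv, ffn2, final_norm):
    x = x + 0.5 * ffn1(x)   # macaron feed-forward 1 (half-step residual)
    x = x + attn(x)         # self-attention block with residual
    x = x + conv(x)         # convolution module with residual
    x = x + 0.5 * ffn2(x)   # macaron feed-forward 2 (half-step residual)
    return final_norm(x)
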
class Wav2Vec2ConformerEncoder(nn.Module):
|
843 |
+
def __init__(self, config, is_causal=False):
|
844 |
+
super().__init__()
|
845 |
+
config.is_causal = is_causal
|
846 |
+
self.config = config
|
847 |
+
|
848 |
+
if config.position_embeddings_type == "relative":
|
849 |
+
self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
|
850 |
+
elif config.position_embeddings_type == "rotary":
|
851 |
+
self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
|
852 |
+
else:
|
853 |
+
self.embed_positions = None
|
854 |
+
|
855 |
+
self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
|
856 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
857 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
858 |
+
self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
859 |
+
self.gradient_checkpointing = False
|
860 |
+
|
861 |
+
def forward(
|
862 |
+
self,
|
863 |
+
hidden_states,
|
864 |
+
attention_mask=None,
|
865 |
+
output_attentions=False,
|
866 |
+
output_hidden_states=False,
|
867 |
+
return_dict=True,
|
868 |
+
):
|
869 |
+
all_hidden_states = () if output_hidden_states else None
|
870 |
+
all_self_attentions = () if output_attentions else None
|
871 |
+
|
872 |
+
if attention_mask is not None:
|
873 |
+
# make sure padded tokens output 0
|
874 |
+
hidden_states[~attention_mask] = 0.0
|
875 |
+
|
876 |
+
# extend attention_mask
|
877 |
+
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
|
878 |
+
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
|
879 |
+
attention_mask = attention_mask.expand(
|
880 |
+
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
881 |
+
)
|
882 |
+
|
883 |
+
hidden_states = self.dropout(hidden_states)
|
884 |
+
|
885 |
+
if self.embed_positions is not None:
|
886 |
+
relative_position_embeddings = self.embed_positions(hidden_states)
|
887 |
+
else:
|
888 |
+
relative_position_embeddings = None
|
889 |
+
|
890 |
+
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
891 |
+
|
892 |
+
for i, layer in enumerate(self.layers):
|
893 |
+
if output_hidden_states:
|
894 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
895 |
+
|
896 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
897 |
+
dropout_probability = np.random.uniform(0, 1)
|
898 |
+
|
899 |
+
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
900 |
+
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
901 |
+
# under deepspeed zero3 all gpus must run in sync
|
902 |
+
if self.gradient_checkpointing and self.training:
|
903 |
+
# create gradient checkpointing function
|
904 |
+
def create_custom_forward(module):
|
905 |
+
def custom_forward(*inputs):
|
906 |
+
return module(*inputs, output_attentions)
|
907 |
+
|
908 |
+
return custom_forward
|
909 |
+
|
910 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
911 |
+
create_custom_forward(layer),
|
912 |
+
hidden_states,
|
913 |
+
attention_mask,
|
914 |
+
relative_position_embeddings,
|
915 |
+
)
|
916 |
+
else:
|
917 |
+
layer_outputs = layer(
|
918 |
+
hidden_states,
|
919 |
+
attention_mask=attention_mask,
|
920 |
+
relative_position_embeddings=relative_position_embeddings,
|
921 |
+
output_attentions=output_attentions,
|
922 |
+
)
|
923 |
+
hidden_states = layer_outputs[0]
|
924 |
+
|
925 |
+
if skip_the_layer:
|
926 |
+
layer_outputs = (None, None)
|
927 |
+
|
928 |
+
if output_attentions:
|
929 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
930 |
+
|
931 |
+
hidden_states = self.layer_norm(hidden_states)
|
932 |
+
if output_hidden_states:
|
933 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
934 |
+
|
935 |
+
if not return_dict:
|
936 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
937 |
+
return BaseModelOutput(
|
938 |
+
last_hidden_state=hidden_states,
|
939 |
+
hidden_states=all_hidden_states,
|
940 |
+
attentions=all_self_attentions,
|
941 |
+
)
|
942 |
+
|
943 |
+
|
944 |
+
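# --- Editor's sketch (not part of the uploaded file) -------------------------
# The encoder above converts a (batch, time) padding mask into an additive
# float mask of shape (batch, 1, time, time): padded positions are pushed to
# the most negative representable value so they vanish after the softmax.
# Hypothetical stand-alone helper for illustration only.
def _sketch_extend_attention_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # attention_mask: 1 (or True) for real frames, 0 (or False) for padding
    mask = 1.0 - attention_mask[:, None, None, :].to(dtype=dtype)
    mask = mask * torch.finfo(dtype).min
    return mask.expand(mask.shape[0], 1, mask.shape[-1], mask.shape[-1])
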
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
|
945 |
+
class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
|
946 |
+
"""
|
947 |
+
Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
|
948 |
+
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
|
949 |
+
"""
|
950 |
+
|
951 |
+
def __init__(self, config):
|
952 |
+
super().__init__()
|
953 |
+
self.num_groups = config.num_codevector_groups
|
954 |
+
self.num_vars = config.num_codevectors_per_group
|
955 |
+
|
956 |
+
if config.codevector_dim % self.num_groups != 0:
|
957 |
+
raise ValueError(
|
958 |
+
f"`config.codevector_dim {config.codevector_dim} must be divisible "
|
959 |
+
f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
|
960 |
+
)
|
961 |
+
|
962 |
+
# storage for codebook variables (codewords)
|
963 |
+
self.codevectors = nn.Parameter(
|
964 |
+
torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
|
965 |
+
)
|
966 |
+
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
|
967 |
+
|
968 |
+
# can be decayed for training
|
969 |
+
self.temperature = 2
|
970 |
+
|
971 |
+
@staticmethod
|
972 |
+
def _compute_perplexity(probs, mask=None):
|
973 |
+
if mask is not None:
|
974 |
+
mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
|
975 |
+
probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
|
976 |
+
marginal_probs = probs.sum(dim=0) / mask.sum()
|
977 |
+
else:
|
978 |
+
marginal_probs = probs.mean(dim=0)
|
979 |
+
|
980 |
+
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
|
981 |
+
return perplexity
|
982 |
+
|
983 |
+
def forward(self, hidden_states, mask_time_indices=None):
|
984 |
+
batch_size, sequence_length, hidden_size = hidden_states.shape
|
985 |
+
|
986 |
+
# project to codevector dim
|
987 |
+
hidden_states = self.weight_proj(hidden_states)
|
988 |
+
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
989 |
+
|
990 |
+
if self.training:
|
991 |
+
# sample code vector probs via gumbel softmax in a differentiable way
|
992 |
+
codevector_probs = nn.functional.gumbel_softmax(
|
993 |
+
hidden_states.float(), tau=self.temperature, hard=True
|
994 |
+
).type_as(hidden_states)
|
995 |
+
|
996 |
+
# compute perplexity
|
997 |
+
codevector_soft_dist = torch.softmax(
|
998 |
+
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
|
999 |
+
)
|
1000 |
+
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
|
1001 |
+
else:
|
1002 |
+
# take argmax in non-differentiable way
|
1003 |
+
# compute hard codevector distribution (one hot)
|
1004 |
+
codevector_idx = hidden_states.argmax(dim=-1)
|
1005 |
+
codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
|
1006 |
+
-1, codevector_idx.view(-1, 1), 1.0
|
1007 |
+
)
|
1008 |
+
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
|
1009 |
+
|
1010 |
+
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
|
1011 |
+
|
1012 |
+
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
|
1013 |
+
# use probs to retrieve codevectors
|
1014 |
+
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
|
1015 |
+
codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
|
1016 |
+
codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
|
1017 |
+
|
1018 |
+
return codevectors, perplexity
|
1019 |
+
|
1020 |
+
|
1021 |
+
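# --- Editor's sketch (not part of the uploaded file) -------------------------
# At train time the quantizer above draws hard one-hot codes with a
# straight-through Gumbel-softmax and measures the perplexity of the soft
# marginal code distribution (later used as the diversity objective). The
# hypothetical helper below shows the single-group case.
def _sketch_gumbel_codes(logits: torch.Tensor, temperature: float = 2.0):
    # logits: (num_frames, num_codes) unnormalized codebook scores
    hard_probs = nn.functional.gumbel_softmax(logits.float(), tau=temperature, hard=True)
    soft_probs = torch.softmax(logits.float(), dim=-1)
    marginal = soft_probs.mean(dim=0)
    perplexity = torch.exp(-(marginal * torch.log(marginal + 1e-7)).sum())
    return hard_probs, perplexity
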
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
|
1022 |
+
class Wav2Vec2ConformerAdapter(nn.Module):
|
1023 |
+
def __init__(self, config):
|
1024 |
+
super().__init__()
|
1025 |
+
|
1026 |
+
# feature dim might need to be down-projected
|
1027 |
+
if config.output_hidden_size != config.hidden_size:
|
1028 |
+
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
|
1029 |
+
self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
|
1030 |
+
else:
|
1031 |
+
self.proj = self.proj_layer_norm = None
|
1032 |
+
|
1033 |
+
self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
|
1034 |
+
self.layerdrop = config.layerdrop
|
1035 |
+
|
1036 |
+
def forward(self, hidden_states):
|
1037 |
+
# down project hidden_states if necessary
|
1038 |
+
if self.proj is not None and self.proj_layer_norm is not None:
|
1039 |
+
hidden_states = self.proj(hidden_states)
|
1040 |
+
hidden_states = self.proj_layer_norm(hidden_states)
|
1041 |
+
|
1042 |
+
hidden_states = hidden_states.transpose(1, 2)
|
1043 |
+
|
1044 |
+
for layer in self.layers:
|
1045 |
+
layerdrop_prob = np.random.random()
|
1046 |
+
if not self.training or (layerdrop_prob > self.layerdrop):
|
1047 |
+
hidden_states = layer(hidden_states)
|
1048 |
+
|
1049 |
+
hidden_states = hidden_states.transpose(1, 2)
|
1050 |
+
return hidden_states
|
1051 |
+
|
1052 |
+
|
1053 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
|
1054 |
+
class Wav2Vec2ConformerAdapterLayer(nn.Module):
|
1055 |
+
def __init__(self, config):
|
1056 |
+
super().__init__()
|
1057 |
+
self.conv = nn.Conv1d(
|
1058 |
+
config.output_hidden_size,
|
1059 |
+
2 * config.output_hidden_size,
|
1060 |
+
config.adapter_kernel_size,
|
1061 |
+
stride=config.adapter_stride,
|
1062 |
+
padding=1,
|
1063 |
+
)
|
1064 |
+
|
1065 |
+
def forward(self, hidden_states):
|
1066 |
+
hidden_states = self.conv(hidden_states)
|
1067 |
+
hidden_states = nn.functional.glu(hidden_states, dim=1)
|
1068 |
+
|
1069 |
+
return hidden_states
|
1070 |
+
|
1071 |
+
|
1072 |
+
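# --- Editor's sketch (not part of the uploaded file) -------------------------
# Each adapter layer above doubles the channel count with a strided Conv1d and
# halves it again with a GLU, so only the time axis is downsampled. The toy
# shapes below are hypothetical and for illustration only.
def _sketch_adapter_shapes(hidden_size: int = 16, stride: int = 2, time: int = 10):
    conv = nn.Conv1d(hidden_size, 2 * hidden_size, kernel_size=3, stride=stride, padding=1)
    x = torch.randn(1, hidden_size, time)        # (batch, channels, time)
    y = nn.functional.glu(conv(x), dim=1)        # (1, hidden_size, 5) for the defaults
    return y.shape
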
class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
|
1073 |
+
"""
|
1074 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
1075 |
+
models.
|
1076 |
+
"""
|
1077 |
+
|
1078 |
+
config_class = Wav2Vec2ConformerConfig
|
1079 |
+
base_model_prefix = "wav2vec2_conformer"
|
1080 |
+
main_input_name = "input_values"
|
1081 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
1082 |
+
supports_gradient_checkpointing = True
|
1083 |
+
|
1084 |
+
def _init_weights(self, module):
|
1085 |
+
"""Initialize the weights"""
|
1086 |
+
# Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
|
1087 |
+
if isinstance(module, Wav2Vec2ConformerForPreTraining):
|
1088 |
+
module.project_hid.reset_parameters()
|
1089 |
+
module.project_q.reset_parameters()
|
1090 |
+
module.project_hid._is_hf_initialized = True
|
1091 |
+
module.project_q._is_hf_initialized = True
|
1092 |
+
# gumbel softmax requires special init
|
1093 |
+
elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
|
1094 |
+
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
|
1095 |
+
module.weight_proj.bias.data.zero_()
|
1096 |
+
nn.init.uniform_(module.codevectors)
|
1097 |
+
elif isinstance(module, Wav2Vec2ConformerSelfAttention):
|
1098 |
+
if hasattr(module, "pos_bias_u"):
|
1099 |
+
nn.init.xavier_uniform_(module.pos_bias_u)
|
1100 |
+
if hasattr(module, "pos_bias_v"):
|
1101 |
+
nn.init.xavier_uniform_(module.pos_bias_v)
|
1102 |
+
elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
|
1103 |
+
nn.init.normal_(
|
1104 |
+
module.conv.weight,
|
1105 |
+
mean=0,
|
1106 |
+
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
|
1107 |
+
)
|
1108 |
+
nn.init.constant_(module.conv.bias, 0)
|
1109 |
+
elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
|
1110 |
+
k = math.sqrt(1 / module.projection.in_features)
|
1111 |
+
nn.init.uniform_(module.projection.weight, a=-k, b=k)
|
1112 |
+
nn.init.uniform_(module.projection.bias, a=-k, b=k)
|
1113 |
+
elif isinstance(module, nn.Linear):
|
1114 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
1115 |
+
|
1116 |
+
if module.bias is not None:
|
1117 |
+
module.bias.data.zero_()
|
1118 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
1119 |
+
module.bias.data.zero_()
|
1120 |
+
module.weight.data.fill_(1.0)
|
1121 |
+
elif isinstance(module, nn.Conv1d):
|
1122 |
+
nn.init.kaiming_normal_(module.weight)
|
1123 |
+
|
1124 |
+
if module.bias is not None:
|
1125 |
+
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
|
1126 |
+
nn.init.uniform_(module.bias, a=-k, b=k)
|
1127 |
+
|
1128 |
+
def _get_feat_extract_output_lengths(
|
1129 |
+
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
|
1130 |
+
):
|
1131 |
+
"""
|
1132 |
+
Computes the output length of the convolutional layers
|
1133 |
+
"""
|
1134 |
+
|
1135 |
+
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
|
1136 |
+
|
1137 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
1138 |
+
# 1D convolutional layer output length formula taken
|
1139 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
1140 |
+
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
|
1141 |
+
|
1142 |
+
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
1143 |
+
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
1144 |
+
|
1145 |
+
if add_adapter:
|
1146 |
+
for _ in range(self.config.num_adapter_layers):
|
1147 |
+
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
|
1148 |
+
|
1149 |
+
return input_lengths
|
1150 |
+
|
1151 |
+
def _get_feature_vector_attention_mask(
|
1152 |
+
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
|
1153 |
+
):
|
1154 |
+
# Effectively attention_mask.sum(-1), but not inplace to be able to run
|
1155 |
+
# on inference mode.
|
1156 |
+
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
|
1157 |
+
|
1158 |
+
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
|
1159 |
+
output_lengths = output_lengths.to(torch.long)
|
1160 |
+
|
1161 |
+
batch_size = attention_mask.shape[0]
|
1162 |
+
|
1163 |
+
attention_mask = torch.zeros(
|
1164 |
+
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
|
1165 |
+
)
|
1166 |
+
# these two operations make sure that all values before the output length indices are attended to
|
1167 |
+
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
|
1168 |
+
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
1169 |
+
return attention_mask
|
1170 |
+
|
1171 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
1172 |
+
if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
|
1173 |
+
module.gradient_checkpointing = value
|
1174 |
+
|
1175 |
+
|
1176 |
+
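# --- Editor's sketch (not part of the uploaded file) -------------------------
# `_get_feat_extract_output_lengths` above chains the standard Conv1d length
# formula floor((L - kernel) / stride) + 1 over the feature-encoder layers.
# The stand-alone helper below is hypothetical; with the usual wav2vec2-style
# kernels (10, 3, 3, 3, 3, 2, 2) and strides (5, 2, 2, 2, 2, 2, 2) it maps
# 16000 samples (one second at 16 kHz) to 49 frames.
def _sketch_feat_extract_output_length(input_length: int, conv_kernel, conv_stride) -> int:
    for kernel_size, stride in zip(conv_kernel, conv_stride):
        input_length = (input_length - kernel_size) // stride + 1
    return input_length
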
WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
|
1177 |
+
Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
|
1178 |
+
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
|
1179 |
+
Auli.
|
1180 |
+
|
1181 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
1182 |
+
library implements for all its models (such as downloading or saving, etc.).
|
1183 |
+
|
1184 |
+
This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
|
1185 |
+
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
|
1186 |
+
|
1187 |
+
Parameters:
|
1188 |
+
config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
|
1189 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
1190 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
1191 |
+
"""
|
1192 |
+
|
1193 |
+
|
1194 |
+
WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
|
1195 |
+
Args:
|
1196 |
+
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
1197 |
+
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
|
1198 |
+
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
|
1199 |
+
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
|
1200 |
+
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
|
1201 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1202 |
+
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
1203 |
+
1]`:
|
1204 |
+
|
1205 |
+
- 1 for tokens that are **not masked**,
|
1206 |
+
- 0 for tokens that are **masked**.
|
1207 |
+
|
1208 |
+
[What are attention masks?](../glossary#attention-mask)
|
1209 |
+
|
1210 |
+
<Tip warning={true}>
|
1211 |
+
|
1212 |
+
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
|
1213 |
+
True`. For all models whose processor has `config.return_attention_mask == False`, such as
|
1214 |
+
[wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
|
1215 |
+
`attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
|
1216 |
+
such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
|
1217 |
+
that these models also yield slightly different results depending on whether `input_values` is padded or
|
1218 |
+
not.
|
1219 |
+
|
1220 |
+
</Tip>
|
1221 |
+
|
1222 |
+
output_attentions (`bool`, *optional*):
|
1223 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
1224 |
+
tensors for more detail.
|
1225 |
+
output_hidden_states (`bool`, *optional*):
|
1226 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
1227 |
+
more detail.
|
1228 |
+
return_dict (`bool`, *optional*):
|
1229 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
1230 |
+
"""
|
1231 |
+
|
1232 |
+
|
1233 |
+
@add_start_docstrings(
|
1234 |
+
"The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
|
1235 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
1236 |
+
)
|
1237 |
+
class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
|
1238 |
+
def __init__(self, config: Wav2Vec2ConformerConfig):
|
1239 |
+
super().__init__(config)
|
1240 |
+
self.config = config
|
1241 |
+
self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
|
1242 |
+
self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
|
1243 |
+
|
1244 |
+
# model only needs masking vector if mask prob is > 0.0
|
1245 |
+
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
1246 |
+
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
1247 |
+
|
1248 |
+
self.encoder = Wav2Vec2ConformerEncoder(config)
|
1249 |
+
|
1250 |
+
self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
|
1251 |
+
|
1252 |
+
# Initialize weights and apply final processing
|
1253 |
+
self.post_init()
|
1254 |
+
|
1255 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
|
1256 |
+
def freeze_feature_encoder(self):
|
1257 |
+
"""
|
1258 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
1259 |
+
not be updated during training.
|
1260 |
+
"""
|
1261 |
+
self.feature_extractor._freeze_parameters()
|
1262 |
+
|
1263 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
|
1264 |
+
def _mask_hidden_states(
|
1265 |
+
self,
|
1266 |
+
hidden_states: torch.FloatTensor,
|
1267 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
1268 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
1269 |
+
):
|
1270 |
+
"""
|
1271 |
+
Masks extracted features along time axis and/or along feature axis according to
|
1272 |
+
[SpecAugment](https://arxiv.org/abs/1904.08779).
|
1273 |
+
"""
|
1274 |
+
|
1275 |
+
# `config.apply_spec_augment` can set masking to False
|
1276 |
+
if not getattr(self.config, "apply_spec_augment", True):
|
1277 |
+
return hidden_states
|
1278 |
+
|
1279 |
+
# generate indices & apply SpecAugment along time axis
|
1280 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
1281 |
+
|
1282 |
+
if mask_time_indices is not None:
|
1283 |
+
# apply SpecAugment along time axis with given mask_time_indices
|
1284 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
1285 |
+
elif self.config.mask_time_prob > 0 and self.training:
|
1286 |
+
mask_time_indices = _compute_mask_indices(
|
1287 |
+
(batch_size, sequence_length),
|
1288 |
+
mask_prob=self.config.mask_time_prob,
|
1289 |
+
mask_length=self.config.mask_time_length,
|
1290 |
+
attention_mask=attention_mask,
|
1291 |
+
min_masks=self.config.mask_time_min_masks,
|
1292 |
+
)
|
1293 |
+
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
|
1294 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
1295 |
+
|
1296 |
+
if self.config.mask_feature_prob > 0 and self.training:
|
1297 |
+
# generate indices & apply SpecAugment along feature axis
|
1298 |
+
mask_feature_indices = _compute_mask_indices(
|
1299 |
+
(batch_size, hidden_size),
|
1300 |
+
mask_prob=self.config.mask_feature_prob,
|
1301 |
+
mask_length=self.config.mask_feature_length,
|
1302 |
+
min_masks=self.config.mask_feature_min_masks,
|
1303 |
+
)
|
1304 |
+
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
|
1305 |
+
mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
|
1306 |
+
hidden_states[mask_feature_indices] = 0
|
1307 |
+
|
1308 |
+
return hidden_states
|
1309 |
+
|
1310 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
1311 |
+
@add_code_sample_docstrings(
|
1312 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1313 |
+
output_type=Wav2Vec2BaseModelOutput,
|
1314 |
+
config_class=_CONFIG_FOR_DOC,
|
1315 |
+
modality="audio",
|
1316 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
1317 |
+
)
|
1318 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
|
1319 |
+
def forward(
|
1320 |
+
self,
|
1321 |
+
input_values: Optional[torch.Tensor],
|
1322 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1323 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
1324 |
+
output_attentions: Optional[bool] = None,
|
1325 |
+
output_hidden_states: Optional[bool] = None,
|
1326 |
+
return_dict: Optional[bool] = None,
|
1327 |
+
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
|
1328 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1329 |
+
output_hidden_states = (
|
1330 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1331 |
+
)
|
1332 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1333 |
+
|
1334 |
+
extract_features = self.feature_extractor(input_values)
|
1335 |
+
extract_features = extract_features.transpose(1, 2)
|
1336 |
+
|
1337 |
+
if attention_mask is not None:
|
1338 |
+
# compute reduced attention_mask corresponding to feature vectors
|
1339 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
1340 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
1341 |
+
)
|
1342 |
+
|
1343 |
+
hidden_states, extract_features = self.feature_projection(extract_features)
|
1344 |
+
hidden_states = self._mask_hidden_states(
|
1345 |
+
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
1346 |
+
)
|
1347 |
+
|
1348 |
+
encoder_outputs = self.encoder(
|
1349 |
+
hidden_states,
|
1350 |
+
attention_mask=attention_mask,
|
1351 |
+
output_attentions=output_attentions,
|
1352 |
+
output_hidden_states=output_hidden_states,
|
1353 |
+
return_dict=return_dict,
|
1354 |
+
)
|
1355 |
+
|
1356 |
+
hidden_states = encoder_outputs[0]
|
1357 |
+
|
1358 |
+
if self.adapter is not None:
|
1359 |
+
hidden_states = self.adapter(hidden_states)
|
1360 |
+
|
1361 |
+
if not return_dict:
|
1362 |
+
return (hidden_states, extract_features) + encoder_outputs[1:]
|
1363 |
+
|
1364 |
+
return Wav2Vec2BaseModelOutput(
|
1365 |
+
last_hidden_state=hidden_states,
|
1366 |
+
extract_features=extract_features,
|
1367 |
+
hidden_states=encoder_outputs.hidden_states,
|
1368 |
+
attentions=encoder_outputs.attentions,
|
1369 |
+
)
|
1370 |
+
|
1371 |
+
|
1372 |
+
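# --- Editor's sketch (not part of the uploaded file) -------------------------
# Typical feature-extraction usage of the bare model above, mirroring the
# docstring examples elsewhere in this file. The checkpoint name, the random
# 16 kHz waveform, and the helper itself are assumptions for illustration.
def _sketch_model_usage():
    from transformers import AutoFeatureExtractor

    checkpoint = "facebook/wav2vec2-conformer-rel-pos-large"
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = Wav2Vec2ConformerModel.from_pretrained(checkpoint)

    speech = torch.randn(16000).numpy()  # one second of (fake) 16 kHz audio
    inputs = feature_extractor(speech, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state  # (1, ~49 frames, hidden_size)
    return hidden.shape
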
@add_start_docstrings(
|
1373 |
+
"""Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
|
1374 |
+
)
|
1375 |
+
class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
1376 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
1377 |
+
def __init__(self, config: Wav2Vec2ConformerConfig):
|
1378 |
+
super().__init__(config)
|
1379 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
1380 |
+
self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
|
1381 |
+
|
1382 |
+
self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
|
1383 |
+
|
1384 |
+
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
|
1385 |
+
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
|
1386 |
+
|
1387 |
+
# Initialize weights and apply final processing
|
1388 |
+
self.post_init()
|
1389 |
+
|
1390 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
|
1391 |
+
def set_gumbel_temperature(self, temperature: int):
|
1392 |
+
"""
|
1393 |
+
Set the Gumbel softmax temperature to a given value. Only necessary for training.
|
1394 |
+
"""
|
1395 |
+
self.quantizer.temperature = temperature
|
1396 |
+
|
1397 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
1398 |
+
def freeze_feature_encoder(self):
|
1399 |
+
"""
|
1400 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
1401 |
+
not be updated during training.
|
1402 |
+
"""
|
1403 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
1404 |
+
|
1405 |
+
@staticmethod
|
1406 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
|
1407 |
+
def compute_contrastive_logits(
|
1408 |
+
target_features: torch.FloatTensor,
|
1409 |
+
negative_features: torch.FloatTensor,
|
1410 |
+
predicted_features: torch.FloatTensor,
|
1411 |
+
temperature: int = 0.1,
|
1412 |
+
):
|
1413 |
+
"""
|
1414 |
+
Compute logits for contrastive loss using cosine similarity as the distance measure between
|
1415 |
+
`[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
|
1416 |
+
"""
|
1417 |
+
target_features = torch.cat([target_features, negative_features], dim=0)
|
1418 |
+
|
1419 |
+
logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
|
1420 |
+
target_features
|
1421 |
+
)
|
1422 |
+
|
1423 |
+
# apply temperature
|
1424 |
+
logits = logits / temperature
|
1425 |
+
return logits
|
1426 |
+
|
1427 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
1428 |
+
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
1429 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
|
1430 |
+
def forward(
|
1431 |
+
self,
|
1432 |
+
input_values: Optional[torch.Tensor],
|
1433 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1434 |
+
mask_time_indices: Optional[torch.BoolTensor] = None,
|
1435 |
+
sampled_negative_indices: Optional[torch.BoolTensor] = None,
|
1436 |
+
output_attentions: Optional[bool] = None,
|
1437 |
+
output_hidden_states: Optional[bool] = None,
|
1438 |
+
return_dict: Optional[bool] = None,
|
1439 |
+
) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
|
1440 |
+
r"""
|
1441 |
+
mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1442 |
+
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
|
1443 |
+
masked extracted features in *config.proj_codevector_dim* space.
|
1444 |
+
sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
|
1445 |
+
Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
|
1446 |
+
Required input for pre-training.
|
1447 |
+
|
1448 |
+
Returns:
|
1449 |
+
|
1450 |
+
Example:
|
1451 |
+
|
1452 |
+
```python
|
1453 |
+
>>> import torch
|
1454 |
+
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
|
1455 |
+
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
|
1456 |
+
... _compute_mask_indices,
|
1457 |
+
... _sample_negative_indices,
|
1458 |
+
... )
|
1459 |
+
>>> from datasets import load_dataset
|
1460 |
+
|
1461 |
+
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
1462 |
+
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
1463 |
+
|
1464 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
1465 |
+
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
|
1466 |
+
|
1467 |
+
>>> # compute masked indices
|
1468 |
+
>>> batch_size, raw_sequence_length = input_values.shape
|
1469 |
+
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
|
1470 |
+
>>> mask_time_indices = _compute_mask_indices(
|
1471 |
+
... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
|
1472 |
+
... )
|
1473 |
+
>>> sampled_negative_indices = _sample_negative_indices(
|
1474 |
+
... features_shape=(batch_size, sequence_length),
|
1475 |
+
... num_negatives=model.config.num_negatives,
|
1476 |
+
... mask_time_indices=mask_time_indices,
|
1477 |
+
... )
|
1478 |
+
>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
|
1479 |
+
>>> sampled_negative_indices = torch.tensor(
|
1480 |
+
... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
|
1481 |
+
... )
|
1482 |
+
|
1483 |
+
>>> with torch.no_grad():
|
1484 |
+
... outputs = model(input_values, mask_time_indices=mask_time_indices)
|
1485 |
+
|
1486 |
+
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
|
1487 |
+
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
1488 |
+
|
1489 |
+
>>> # show that cosine similarity is much higher than random
|
1490 |
+
>>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
|
1491 |
+
tensor(True)
|
1492 |
+
|
1493 |
+
>>> # for contrastive loss training model should be put into train mode
|
1494 |
+
>>> model = model.train()
|
1495 |
+
>>> loss = model(
|
1496 |
+
... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
|
1497 |
+
... ).loss
|
1498 |
+
```"""
|
1499 |
+
|
1500 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1501 |
+
|
1502 |
+
if mask_time_indices is not None:
|
1503 |
+
mask_time_indices = mask_time_indices.to(torch.bool)
|
1504 |
+
|
1505 |
+
outputs = self.wav2vec2_conformer(
|
1506 |
+
input_values,
|
1507 |
+
attention_mask=attention_mask,
|
1508 |
+
output_attentions=output_attentions,
|
1509 |
+
output_hidden_states=output_hidden_states,
|
1510 |
+
mask_time_indices=mask_time_indices,
|
1511 |
+
return_dict=return_dict,
|
1512 |
+
)
|
1513 |
+
|
1514 |
+
# 1. project all transformed features (including masked) to final vq dim
|
1515 |
+
transformer_features = self.project_hid(outputs[0])
|
1516 |
+
|
1517 |
+
# 2. quantize all (unmasked) extracted features and project to final vq dim
|
1518 |
+
extract_features = self.dropout_features(outputs[1])
|
1519 |
+
|
1520 |
+
if attention_mask is not None:
|
1521 |
+
# compute reduced attention_mask corresponding to feature vectors
|
1522 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
1523 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
1524 |
+
)
|
1525 |
+
|
1526 |
+
quantized_features, codevector_perplexity = self.quantizer(
|
1527 |
+
extract_features, mask_time_indices=mask_time_indices
|
1528 |
+
)
|
1529 |
+
quantized_features = self.project_q(quantized_features)
|
1530 |
+
|
1531 |
+
loss = contrastive_loss = diversity_loss = None
|
1532 |
+
if sampled_negative_indices is not None:
|
1533 |
+
batch_size, sequence_length, hidden_size = quantized_features.shape
|
1534 |
+
|
1535 |
+
# for training, we sample negatives
|
1536 |
+
# 3. sample K negatives (distractors) quantized states for contrastive loss
|
1537 |
+
# if attention_mask is passed, make sure that padded feature vectors cannot be sampled
|
1538 |
+
# sample negative quantized vectors BTC => (BxT)C
|
1539 |
+
negative_quantized_features = quantized_features.view(-1, hidden_size)[
|
1540 |
+
sampled_negative_indices.long().view(-1)
|
1541 |
+
]
|
1542 |
+
negative_quantized_features = negative_quantized_features.view(
|
1543 |
+
batch_size, sequence_length, -1, hidden_size
|
1544 |
+
).permute(2, 0, 1, 3)
|
1545 |
+
|
1546 |
+
# 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
|
1547 |
+
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
|
1548 |
+
logits = self.compute_contrastive_logits(
|
1549 |
+
quantized_features[None, :],
|
1550 |
+
negative_quantized_features,
|
1551 |
+
transformer_features,
|
1552 |
+
self.config.contrastive_logits_temperature,
|
1553 |
+
)
|
1554 |
+
|
1555 |
+
# 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
|
1556 |
+
# its cosine similarity will be masked
|
1557 |
+
neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
|
1558 |
+
|
1559 |
+
if neg_is_pos.any():
|
1560 |
+
logits[1:][neg_is_pos] = float("-inf")
|
1561 |
+
|
1562 |
+
# 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logits) =
|
1563 |
+
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
|
1564 |
+
logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
|
1565 |
+
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
|
1566 |
+
|
1567 |
+
contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
|
1568 |
+
# 7. compute diversity loss: \mathbf{L}_d
|
1569 |
+
num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
|
1570 |
+
diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
|
1571 |
+
|
1572 |
+
# 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
|
1573 |
+
loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
|
1574 |
+
|
1575 |
+
if not return_dict:
|
1576 |
+
if loss is not None:
|
1577 |
+
return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
1578 |
+
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
1579 |
+
|
1580 |
+
return Wav2Vec2ConformerForPreTrainingOutput(
|
1581 |
+
loss=loss,
|
1582 |
+
projected_states=transformer_features,
|
1583 |
+
projected_quantized_states=quantized_features,
|
1584 |
+
codevector_perplexity=codevector_perplexity,
|
1585 |
+
hidden_states=outputs.hidden_states,
|
1586 |
+
attentions=outputs.attentions,
|
1587 |
+
contrastive_loss=contrastive_loss,
|
1588 |
+
diversity_loss=diversity_loss,
|
1589 |
+
)
|
1590 |
+
|
1591 |
+
|
1592 |
+
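# --- Editor's sketch (not part of the uploaded file) -------------------------
# The pre-training head above scores the true quantized target against K
# sampled distractors with cosine similarity divided by a temperature and
# treats the result as a (K + 1)-way classification. Hypothetical minimal
# version of that scoring step:
def _sketch_contrastive_logits(predicted, positive, negatives, temperature: float = 0.1):
    # predicted, positive: (time, dim); negatives: (num_negatives, time, dim)
    targets = torch.cat([positive[None, :], negatives], dim=0)   # (1 + K, time, dim)
    logits = torch.cosine_similarity(predicted[None, :].float(), targets.float(), dim=-1)
    return logits / temperature                                  # (1 + K, time)
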
@add_start_docstrings(
|
1593 |
+
"""Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
|
1594 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
1595 |
+
)
|
1596 |
+
class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
|
1597 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
1598 |
+
def __init__(self, config):
|
1599 |
+
super().__init__(config)
|
1600 |
+
|
1601 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
1602 |
+
self.dropout = nn.Dropout(config.final_dropout)
|
1603 |
+
|
1604 |
+
if config.vocab_size is None:
|
1605 |
+
raise ValueError(
|
1606 |
+
f"You are trying to instantiate {self.__class__} with a configuration that "
|
1607 |
+
"does not define the vocabulary size of the language model head. Please "
|
1608 |
+
"instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
|
1609 |
+
"or define `vocab_size` of your model's configuration."
|
1610 |
+
)
|
1611 |
+
output_hidden_size = (
|
1612 |
+
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
|
1613 |
+
)
|
1614 |
+
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
1615 |
+
|
1616 |
+
# Initialize weights and apply final processing
|
1617 |
+
self.post_init()
|
1618 |
+
|
1619 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
1620 |
+
def freeze_feature_encoder(self):
|
1621 |
+
"""
|
1622 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
1623 |
+
not be updated during training.
|
1624 |
+
"""
|
1625 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
1626 |
+
|
1627 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
1628 |
+
@add_code_sample_docstrings(
|
1629 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1630 |
+
output_type=CausalLMOutput,
|
1631 |
+
config_class=_CONFIG_FOR_DOC,
|
1632 |
+
expected_output=_CTC_EXPECTED_OUTPUT,
|
1633 |
+
expected_loss=_CTC_EXPECTED_LOSS,
|
1634 |
+
)
|
1635 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
1636 |
+
def forward(
|
1637 |
+
self,
|
1638 |
+
input_values: Optional[torch.Tensor],
|
1639 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1640 |
+
output_attentions: Optional[bool] = None,
|
1641 |
+
output_hidden_states: Optional[bool] = None,
|
1642 |
+
return_dict: Optional[bool] = None,
|
1643 |
+
labels: Optional[torch.Tensor] = None,
|
1644 |
+
) -> Union[Tuple, CausalLMOutput]:
|
1645 |
+
r"""
|
1646 |
+
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
1647 |
+
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
1648 |
+
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
|
1649 |
+
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
1650 |
+
config.vocab_size - 1]`.
|
1651 |
+
"""
|
1652 |
+
|
1653 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1654 |
+
|
1655 |
+
outputs = self.wav2vec2_conformer(
|
1656 |
+
input_values,
|
1657 |
+
attention_mask=attention_mask,
|
1658 |
+
output_attentions=output_attentions,
|
1659 |
+
output_hidden_states=output_hidden_states,
|
1660 |
+
return_dict=return_dict,
|
1661 |
+
)
|
1662 |
+
|
1663 |
+
hidden_states = outputs[0]
|
1664 |
+
hidden_states = self.dropout(hidden_states)
|
1665 |
+
|
1666 |
+
logits = self.lm_head(hidden_states)
|
1667 |
+
|
1668 |
+
loss = None
|
1669 |
+
if labels is not None:
|
1670 |
+
if labels.max() >= self.config.vocab_size:
|
1671 |
+
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
1672 |
+
|
1673 |
+
# retrieve loss input_lengths from attention_mask
|
1674 |
+
attention_mask = (
|
1675 |
+
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
1676 |
+
)
|
1677 |
+
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
1678 |
+
|
1679 |
+
# assuming that padded tokens are filled with -100
|
1680 |
+
# when not being attended to
|
1681 |
+
labels_mask = labels >= 0
|
1682 |
+
target_lengths = labels_mask.sum(-1)
|
1683 |
+
flattened_targets = labels.masked_select(labels_mask)
|
1684 |
+
|
1685 |
+
# ctc_loss doesn't support fp16
|
1686 |
+
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
1687 |
+
|
1688 |
+
with torch.backends.cudnn.flags(enabled=False):
|
1689 |
+
loss = nn.functional.ctc_loss(
|
1690 |
+
log_probs,
|
1691 |
+
flattened_targets,
|
1692 |
+
input_lengths,
|
1693 |
+
target_lengths,
|
1694 |
+
blank=self.config.pad_token_id,
|
1695 |
+
reduction=self.config.ctc_loss_reduction,
|
1696 |
+
zero_infinity=self.config.ctc_zero_infinity,
|
1697 |
+
)
|
1698 |
+
|
1699 |
+
if not return_dict:
|
1700 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
1701 |
+
return ((loss,) + output) if loss is not None else output
|
1702 |
+
|
1703 |
+
return CausalLMOutput(
|
1704 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
1705 |
+
)
|
1706 |
+
|
1707 |
+
|
1708 |
+
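# --- Editor's sketch (not part of the uploaded file) -------------------------
# The CTC head above derives per-example input lengths from the attention mask
# (after feature-encoder downsampling), derives target lengths from the
# -100-padded labels, and feeds float32 log-probabilities to `ctc_loss`.
# Hypothetical minimal version of that plumbing:
def _sketch_ctc_loss(logits, labels, input_lengths, blank_id: int = 0):
    # logits: (batch, time, vocab); labels: (batch, max_target_len), padded with -100
    labels_mask = labels >= 0
    target_lengths = labels_mask.sum(-1)
    flattened_targets = labels.masked_select(labels_mask)
    log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
    return nn.functional.ctc_loss(
        log_probs, flattened_targets, input_lengths, target_lengths,
        blank=blank_id, reduction="sum", zero_infinity=True,
    )
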
@add_start_docstrings(
|
1709 |
+
"""
|
1710 |
+
Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
|
1711 |
+
tasks like SUPERB Keyword Spotting.
|
1712 |
+
""",
|
1713 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
1714 |
+
)
|
1715 |
+
class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
|
1716 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
1717 |
+
def __init__(self, config):
|
1718 |
+
super().__init__(config)
|
1719 |
+
|
1720 |
+
if hasattr(config, "add_adapter") and config.add_adapter:
|
1721 |
+
raise ValueError(
|
1722 |
+
"Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
|
1723 |
+
)
|
1724 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
1725 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
1726 |
+
if config.use_weighted_layer_sum:
|
1727 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
1728 |
+
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
1729 |
+
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
1730 |
+
|
1731 |
+
# Initialize weights and apply final processing
|
1732 |
+
self.post_init()
|
1733 |
+
|
1734 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
1735 |
+
def freeze_feature_encoder(self):
|
1736 |
+
"""
|
1737 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
1738 |
+
not be updated during training.
|
1739 |
+
"""
|
1740 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
1741 |
+
|
1742 |
+
def freeze_base_model(self):
|
1743 |
+
"""
|
1744 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
1745 |
+
be updated during training. Only the classification head will be updated.
|
1746 |
+
"""
|
1747 |
+
for param in self.wav2vec2_conformer.parameters():
|
1748 |
+
param.requires_grad = False
|
1749 |
+
|
1750 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
1751 |
+
@add_code_sample_docstrings(
|
1752 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1753 |
+
output_type=SequenceClassifierOutput,
|
1754 |
+
config_class=_CONFIG_FOR_DOC,
|
1755 |
+
modality="audio",
|
1756 |
+
)
|
1757 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
|
1758 |
+
def forward(
|
1759 |
+
self,
|
1760 |
+
input_values: Optional[torch.Tensor],
|
1761 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1762 |
+
output_attentions: Optional[bool] = None,
|
1763 |
+
output_hidden_states: Optional[bool] = None,
|
1764 |
+
return_dict: Optional[bool] = None,
|
1765 |
+
labels: Optional[torch.Tensor] = None,
|
1766 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
1767 |
+
r"""
|
1768 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1769 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1770 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1771 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1772 |
+
"""
|
1773 |
+
|
1774 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1775 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
1776 |
+
|
1777 |
+
outputs = self.wav2vec2_conformer(
|
1778 |
+
input_values,
|
1779 |
+
attention_mask=attention_mask,
|
1780 |
+
output_attentions=output_attentions,
|
1781 |
+
output_hidden_states=output_hidden_states,
|
1782 |
+
return_dict=return_dict,
|
1783 |
+
)
|
1784 |
+
|
1785 |
+
if self.config.use_weighted_layer_sum:
|
1786 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
1787 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
1788 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
1789 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
1790 |
+
else:
|
1791 |
+
hidden_states = outputs[0]
|
1792 |
+
|
1793 |
+
hidden_states = self.projector(hidden_states)
|
1794 |
+
if attention_mask is None:
|
1795 |
+
pooled_output = hidden_states.mean(dim=1)
|
1796 |
+
else:
|
1797 |
+
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
1798 |
+
hidden_states[~padding_mask] = 0.0
|
1799 |
+
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
1800 |
+
|
1801 |
+
logits = self.classifier(pooled_output)
|
1802 |
+
|
1803 |
+
loss = None
|
1804 |
+
if labels is not None:
|
1805 |
+
loss_fct = CrossEntropyLoss()
|
1806 |
+
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
1807 |
+
|
1808 |
+
if not return_dict:
|
1809 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
1810 |
+
return ((loss,) + output) if loss is not None else output
|
1811 |
+
|
1812 |
+
return SequenceClassifierOutput(
|
1813 |
+
loss=loss,
|
1814 |
+
logits=logits,
|
1815 |
+
hidden_states=outputs.hidden_states,
|
1816 |
+
attentions=outputs.attentions,
|
1817 |
+
)
|
1818 |
+
|
1819 |
+
|
1820 |
+
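# --- Editor's sketch (not part of the uploaded file) -------------------------
# When an attention mask is available, the classification head above pools by
# zeroing padded frames and dividing by the true frame count per example
# instead of taking a plain mean. Hypothetical stand-alone version:
def _sketch_masked_mean_pool(hidden_states: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, time, dim); padding_mask: (batch, time) bool, True = real frame
    hidden_states = hidden_states.masked_fill(~padding_mask[..., None], 0.0)
    return hidden_states.sum(dim=1) / padding_mask.sum(dim=1, keepdim=True)
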
@add_start_docstrings(
|
1821 |
+
"""
|
1822 |
+
    Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
            )
        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.num_labels = config.num_labels

        self.init_weights()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2_conformer.feature_extractor._freeze_parameters()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2_conformer.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2_conformer(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super(AMSoftmaxLoss, self).__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels):
        labels = labels.flatten()
        weight = nn.functional.normalize(self.weight, dim=0)
        hidden_states = nn.functional.normalize(hidden_states, dim=1)
        cos_theta = torch.mm(hidden_states, weight)
        psi = cos_theta - self.margin

        onehot = nn.functional.one_hot(labels, self.num_labels)
        logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
        loss = self.loss(logits, labels)

        return loss


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states):
        hidden_states = hidden_states.unsqueeze(1)
        hidden_states = nn.functional.unfold(
            hidden_states,
            (self.kernel_size, self.in_conv_dim),
            stride=(1, self.in_conv_dim),
            dilation=(self.dilation, 1),
        )
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = self.kernel(hidden_states)

        hidden_states = self.activation(hidden_states)
        return hidden_states


@add_start_docstrings(
    """
    Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        self.init_weights()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2_conformer.feature_extractor._freeze_parameters()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2_conformer.parameters():
            param.requires_grad = False

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the TDNN layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size in self.config.tdnn_kernel:
            input_lengths = _conv_out_length(input_lengths, kernel_size, 1)

        return input_lengths

    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, XVectorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2_conformer(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
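As a quick illustration of the additive-margin softmax head defined above, here is a minimal sketch (not part of the upload) that feeds random x-vector embeddings through AMSoftmaxLoss; the dimensions (512-d embeddings, 10 speakers, batch of 8) are arbitrary assumptions made only for the example.

import torch

# assumes the AMSoftmaxLoss class defined in the file above is in scope
criterion = AMSoftmaxLoss(input_dim=512, num_labels=10, scale=30.0, margin=0.4)

embeddings = torch.randn(8, 512)      # a batch of 8 x-vectors (random, for illustration)
labels = torch.randint(0, 10, (8,))   # speaker ids in [0, 10)

# cosine logits are scaled by 30 and the target class gets the 0.4 margin subtracted
loss = criterion(embeddings, labels)
loss.backward()                        # the class-weight matrix receives gradients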
slam_llm/models/musicfm/modules/random_quantizer.py
ADDED
@@ -0,0 +1,83 @@
# MIT License
#
# Copyright 2023 ByteDance Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

import torch
from torch import nn, einsum
from einops import rearrange


class RandomProjectionQuantizer(nn.Module):
    """
    Random projection and codebook lookup module

    Some code is borrowed from:
     https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
    But I did normalization using pre-computed global mean & variance instead of using layer norm.
    """

    def __init__(
        self,
        input_dim,
        codebook_dim,
        codebook_size,
        seed=142,
    ):
        super().__init__()

        # random seed
        torch.manual_seed(seed)

        # randomly initialized projection
        random_projection = torch.empty(input_dim, codebook_dim)
        nn.init.xavier_normal_(random_projection)
        self.register_buffer("random_projection", random_projection)

        # randomly initialized codebook
        codebook = torch.empty(codebook_size, codebook_dim)
        nn.init.normal_(codebook)
        self.register_buffer("codebook", codebook)

    def codebook_lookup(self, x):
        # reshape
        b = x.shape[0]
        x = rearrange(x, "b n e -> (b n) e")

        # L2 normalization
        normalized_x = nn.functional.normalize(x, dim=1, p=2)
        normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)

        # compute distances
        distances = torch.cdist(normalized_codebook, normalized_x)

        # get nearest
        nearest_indices = torch.argmin(distances, dim=0)

        # reshape
        xq = rearrange(nearest_indices, "(b n) -> b n", b=b)

        return xq

    @torch.no_grad()
    def forward(self, x):
        # always eval
        self.eval()

        # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
        x = einsum("b n d, d e -> b n e", x, self.random_projection)

        # codebook lookup
        xq = self.codebook_lookup(x)

        return xq
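For orientation, a minimal sketch of how this quantizer maps continuous features to discrete token ids; the shapes (1024-d features, 16-d codes, 4096 codebook entries) are illustrative assumptions and not values taken from the upload.

import torch

# assumes the RandomProjectionQuantizer class defined in the file above is in scope
quantizer = RandomProjectionQuantizer(input_dim=1024, codebook_dim=16, codebook_size=4096)

features = torch.randn(2, 250, 1024)   # [batch, frames, feature_dim]
tokens = quantizer(features)           # [batch, frames] integer codebook indices in [0, 4096)
print(tokens.shape, tokens.dtype)      # torch.Size([2, 250]) torch.int64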
slam_llm/models/projector.py
ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn


class EncoderProjectorConcat(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.k = config.encoder_projector_ds_rate
        self.encoder_dim = config.encoder_dim
        self.llm_dim = config.llm_dim
        self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(2048, config.llm_dim)

    def forward(self, x):
        batch_size, seq_len, dim = x.size()
        num_frames_to_discard = seq_len % self.k
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)

        x = x.contiguous()
        x = x.view(batch_size, seq_len // self.k, dim * self.k)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


class EncoderProjectorCov1d(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.k = config.encoder_projector_ds_rate
        self.encoder_dim = config.encoder_dim
        self.llm_dim = config.llm_dim
        self.conv1d = nn.Conv1d(in_channels=self.encoder_dim, out_channels=self.encoder_dim, kernel_size=self.k, stride=self.k, padding=0)
        self.linear1 = nn.Linear(self.encoder_dim, 2048)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(2048, self.llm_dim)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.conv1d(x)
        x = x.transpose(1, 2)
        x = self.relu1(x)
        x = self.linear1(x)
        x = self.relu2(x)
        x = self.linear2(x)
        return x


class EncoderProjectorQFormer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder_dim = config.encoder_dim
        self.llm_dim = config.llm_dim
        from transformers import Blip2QFormerConfig, Blip2QFormerModel
        configuration = Blip2QFormerConfig()
        configuration.encoder_hidden_size = self.encoder_dim
        configuration.num_hidden_layers = 8

        self.query_len = 64
        self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
        self.query.data.normal_(mean=0.0, std=1.0)
        self.qformer = Blip2QFormerModel(configuration)

        self.linear = nn.Linear(configuration.hidden_size, self.llm_dim)
        self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)

    def forward(self, x, atts):
        query = self.query.expand(x.shape[0], -1, -1)

        query_output = self.qformer(
            query_embeds=query,
            encoder_hidden_states=x,
            encoder_attention_mask=atts,
            return_dict=True,
        )

        query_proj = self.norm(self.linear(query_output.last_hidden_state))

        return query_proj
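To make the frame-stacking behaviour of EncoderProjectorConcat concrete, here is a small hedged sketch; the config values (downsample rate 5, 1280-d encoder, 4096-d LLM) are illustrative assumptions, not values read from this repository's configs.

import torch
from types import SimpleNamespace

# hypothetical config object for illustration only
config = SimpleNamespace(encoder_projector_ds_rate=5, encoder_dim=1280, llm_dim=4096)
projector = EncoderProjectorConcat(config)   # class defined in the file above

x = torch.randn(1, 103, 1280)   # 103 encoder frames
y = projector(x)                # 3 leftover frames are discarded, then every 5 frames are concatenated
print(y.shape)                  # torch.Size([1, 20, 4096])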
slam_llm/models/slam_model.py
ADDED
@@ -0,0 +1,443 @@
import os
import types
import torch
import soundfile as sf
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from typing import List, Optional, Tuple, Union
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

from slam_llm.utils.config_utils import generate_peft_config
from slam_llm.utils.train_utils import print_module_size, print_model_size
from peft import PeftModel, PeftConfig
from torch.nn import CrossEntropyLoss
from slam_llm.utils.metric import compute_accuracy

import logging
logger = logging.getLogger(__name__)


def model_factory(train_config, model_config, **kwargs):
    # return necessary components for training
    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)

    encoder = setup_encoder(train_config, model_config, **kwargs)

    # llm
    llm = setup_llm(train_config, model_config, **kwargs)

    # projector
    encoder_projector = setup_encoder_projector(
        train_config, model_config, **kwargs
    )
    model = slam_model(
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    )

    ckpt_path = kwargs.get("ckpt_path", None)  # FIX(MZY): load model ckpt (mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
    if ckpt_path is not None:
        logger.info("loading other parts from: {}".format(ckpt_path))
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt_dict, strict=False)

    print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
    return model, tokenizer


def setup_tokenizer(train_config, model_config, **kwargs):
    # Load the tokenizer and add special tokens
    if "vallex" in model_config.llm_name.lower():
        return None
    elif "mupt" in model_config.llm_name.lower():
        tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path,
                                                  trust_remote_code=True,
                                                  use_fast=False)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


def setup_encoder(train_config, model_config, **kwargs):
    encoder_list = model_config.encoder_name.split(",") if model_config.encoder_name else []
    if len(encoder_list) == 0:
        return None
    if len(encoder_list) == 1:
        encoder_name = encoder_list[0]
        if encoder_name == "whisper" or encoder_name == "qwen-audio":
            from slam_llm.models.encoder import WhisperWrappedEncoder
            encoder = WhisperWrappedEncoder.load(model_config)
        if encoder_name == "beats":
            from slam_llm.models.encoder import BEATsEncoder
            encoder = BEATsEncoder.load(model_config)
        if encoder_name == "eat":
            from slam_llm.models.encoder import EATEncoder
            encoder = EATEncoder.load(model_config)
        if encoder_name == "SpatialAST":
            from slam_llm.models.encoder import SpatialASTEncoder
            encoder = SpatialASTEncoder.load(model_config)
        if encoder_name == "wavlm":
            from slam_llm.models.encoder import WavLMEncoder
            encoder = WavLMEncoder.load(model_config)
        if encoder_name == "av_hubert":
            from slam_llm.models.encoder import AVHubertEncoder
            encoder = AVHubertEncoder.load(model_config)
        if encoder_name == "hubert":
            from slam_llm.models.encoder import HubertEncoder
            encoder = HubertEncoder.load(model_config)
        if encoder_name == "musicfm":
            from slam_llm.models.encoder import MusicFMEncoder
            encoder = MusicFMEncoder.load(model_config)

        if "llama" in encoder_name.lower():
            from slam_llm.models.encoder import HfTextEncoder
            encoder = HfTextEncoder.load(model_config)
    print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)

    if train_config.freeze_encoder:
        for name, param in encoder.named_parameters():
            param.requires_grad = False
        encoder.eval()
    print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)

    return encoder


def setup_llm(train_config, model_config, **kwargs):
    from pkg_resources import packaging
    use_cache = False if train_config.enable_fsdp or train_config.enable_ddp else None
    if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.low_cpu_fsdp:
        """
        for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
        this avoids cpu oom when loading large models like llama 70B, in which case
        model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
        overhead and currently requires latest nightly.
        """
        # v = packaging.version.parse(torch.__version__)
        # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
        # if not verify_latest_nightly:
        #     raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
        #                     "please install latest nightly.")
        rank = int(os.environ["RANK"])
        if rank == 0:
            if "vallex" in model_config.llm_name.lower():
                from src.slam_llm.models.vallex.vallex_config import VallexConfig
                from src.slam_llm.models.vallex.vallex_model import VALLE
                vallex_config = VallexConfig(
                    **model_config
                )
                model = VALLE(vallex_config)
            elif "aya" in model_config.llm_name.lower():
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_config.llm_path,
                    load_in_8bit=True if train_config.quantization else None,
                    device_map="auto" if train_config.quantization else None,
                    use_cache=use_cache,
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    model_config.llm_path,
                    load_in_8bit=True if train_config.quantization else None,
                    device_map="auto" if train_config.quantization else None,
                    use_cache=use_cache,
                )
        else:
            llama_config = AutoConfig.from_pretrained(model_config.llm_path)
            llama_config.use_cache = use_cache
            # with torch.device("meta"):
            if "aya" in model_config.llm_name.lower():
                model = AutoModelForSeq2SeqLM(llama_config)
            else:
                model = AutoModelForCausalLM(llama_config)  # (FIX:MZY): torch 2.0.1 does not support `meta`

    else:
        if "vallex" in model_config.llm_name.lower():
            from src.slam_llm.models.vallex.vallex_config import VallexConfig
            from src.slam_llm.models.vallex.vallex_model import VALLE
            vallex_config = VallexConfig(
                **model_config
            )
            model = VALLE(vallex_config)
        elif "aya" in model_config.llm_name.lower():
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_config.llm_path,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_config.llm_path,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
            )
    if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.use_fast_kernels:
        """
        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
        using of Flash Attention or Xformer memory-efficient kernels
        based on the hardware being used. This would speed up fine-tuning.
        """
        try:
            from optimum.bettertransformer import BetterTransformer
            model = BetterTransformer.transform(model)
        except ImportError:
            logger.warning("Module 'optimum' not found. Please install 'optimum' before proceeding.")

    print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)

    # Prepare the model for int8 training if quantization is enabled
    if train_config.quantization:
        model = prepare_model_for_kbit_training(model)

    if train_config.freeze_llm:  # TODO: to test official `freeze_layers` and `num_freeze_layers`
        for name, param in model.named_parameters():
            param.requires_grad = False
        model.eval()

    if kwargs.get("peft_ckpt", None):  # (FIX:MZY): reload will get wrong results when decoding
        logger.info("loading peft_ckpt from: {}".format(kwargs.get("peft_ckpt")))
        model = PeftModel.from_pretrained(model=model, model_id=kwargs.get("peft_ckpt"), is_trainable=True)
        model.print_trainable_parameters()
    elif train_config.use_peft:
        logger.info("setup peft...")
        peft_config = generate_peft_config(train_config)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
    return model


def setup_encoder_projector(train_config, model_config, **kwargs):
    if model_config.encoder_projector == "linear":
        from slam_llm.models.projector import EncoderProjectorConcat
        encoder_projector = EncoderProjectorConcat(model_config)
    elif model_config.encoder_projector == "cov1d-linear":
        from slam_llm.models.projector import EncoderProjectorCov1d
        encoder_projector = EncoderProjectorCov1d(model_config)
    elif model_config.encoder_projector == "q-former":
        from slam_llm.models.projector import EncoderProjectorQFormer
        encoder_projector = EncoderProjectorQFormer(model_config)
    else:
        return None
    print_module_size(encoder_projector, model_config.encoder_projector, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
    return encoder_projector


class slam_model(nn.Module):
    def __init__(
        self,
        encoder: nn.Module,
        llm: nn.Module,
        encoder_projector: nn.Module,
        tokenizer,
        train_config,
        model_config,
        **kwargs
    ):
        super().__init__()
        # modality encoder
        self.encoder = encoder

        # llm
        self.llm = llm

        # projector
        self.encoder_projector = encoder_projector

        # tokenizer
        self.tokenizer = tokenizer
        self.metric = kwargs.get("metric", "acc")

        self.train_config = train_config
        self.model_config = model_config

        if train_config.get("enable_deepspeed", False):
            def new_forward(self, input):
                output = F.layer_norm(
                    input.float(),
                    self.normalized_shape,
                    self.weight.float() if self.weight is not None else None,
                    self.bias.float() if self.bias is not None else None,
                    self.eps,
                )
                return output.type_as(input)
            for item in self.modules():
                if isinstance(item, nn.LayerNorm):
                    item.forward = types.MethodType(new_forward, item)

    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                **kwargs,
                ):
        audio_mel = kwargs.get("audio_mel", None)
        audio_mel_mask = kwargs.get("audio_mel_mask", None)
        audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None)  # 2x downsample for whisper

        audio = kwargs.get("audio", None)
        audio_mask = kwargs.get("audio_mask", None)
        visual = kwargs.get("visual", None)
        visual_mask = kwargs.get("visual_mask", None)

        # for text encoder
        instruct_ids = kwargs.get("instruct_ids", None)
        instruct_mask = kwargs.get("instruct_mask", None)

        modality_mask = kwargs.get("modality_mask", None)

        zh_data = kwargs.get("zh", None)
        en_data = kwargs.get("en", None)

        encoder_outs = None
        if audio_mel is not None or audio is not None or visual is not None:
            if self.train_config.freeze_encoder:  # freeze encoder
                self.encoder.eval()

            if self.model_config.encoder_name == "whisper":
                encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1))  # bs*seq*dim
            if self.model_config.encoder_name == "beats":
                encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask)  # bs*seq*dim
            if self.model_config.encoder_name == "eat":
                encoder_outs = self.encoder.model.extract_features(audio_mel.unsqueeze(dim=1), padding_mask=None, mask=False, remove_extra_tokens=False)['x']
            if self.model_config.encoder_name == "SpatialAST":
                encoder_outs = self.encoder(audio)  # output: [bs, seq_len=3+512, dim=768]
            if self.model_config.encoder_name == "wavlm":
                encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask)  # (FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask
            if self.model_config.encoder_name == "hubert":
                results = self.encoder(source=audio, padding_mask=1-audio_mask)
                if self.model_config.encoder_type == "pretrain":
                    encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"]
                if self.model_config.encoder_type == "finetune":
                    encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
                    encoder_outs = encoder_outs.transpose(0, 1)
            if self.model_config.encoder_name == "av_hubert":
                results = self.encoder(source={'video': visual, 'audio': audio}, padding_mask=visual_mask)  # bs*seq*dim
                encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
                encoder_outs = encoder_outs.transpose(0, 1)
                audio_mel_post_mask = (~audio_mel_post_mask).float()
            if self.model_config.encoder_name == 'musicfm':
                encoder_outs = self.encoder.extract_features(audio, padding_mask=None)  # MusicFM doesn't support padding mask
            if self.encoder is None:
                encoder_outs = audio_mel if audio_mel is not None else audio

            if self.model_config.encoder_projector == "q-former":
                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)
            if self.model_config.encoder_projector == "cov1d-linear":
                encoder_outs = self.encoder_projector(encoder_outs)

        if instruct_ids is not None:
            if self.encoder is not None:
                encoder_outs = self.encoder(input_ids=instruct_ids, attention_mask=instruct_mask).last_hidden_state

            if self.model_config.encoder_projector == "q-former":
                encoder_outs = self.encoder_projector(encoder_outs, instruct_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)

        if input_ids is not None:
            input_ids[input_ids == -1] = 0
            if isinstance(self.llm, T5ForConditionalGeneration):
                inputs_embeds = self.llm.shared(input_ids)
            else:
                if hasattr(self.llm.model, "embed_tokens"):
                    inputs_embeds = self.llm.model.embed_tokens(input_ids)
                elif hasattr(self.llm.model.model, "embed_tokens"):
                    inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
                else:
                    inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)

        if modality_mask is not None:
            modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1)
            modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()

            encoder_outs_pad = torch.zeros_like(inputs_embeds)
            for i in range(encoder_outs.shape[0]):
                encoder_outs_pad[
                    i, modality_mask_start_indices[i]:modality_mask_start_indices[i]+modality_lengths[i]
                ] = encoder_outs[i][:modality_lengths[i]]

            inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])

        if kwargs.get("inference_mode", False):
            return inputs_embeds, attention_mask

        if zh_data is not None and en_data is not None:
            model_outputs, acc = self.llm(zh=zh_data, en=en_data)
        else:
            model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
            acc = -1
            if self.metric:
                with torch.no_grad():
                    preds = torch.argmax(model_outputs.logits, -1)
                    acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)

        return model_outputs, acc

    @torch.no_grad()
    def generate(self,
                 input_ids: torch.LongTensor = None,
                 attention_mask: Optional[torch.Tensor] = None,
                 position_ids: Optional[torch.LongTensor] = None,
                 past_key_values: Optional[List[torch.FloatTensor]] = None,
                 inputs_embeds: Optional[torch.FloatTensor] = None,
                 labels: Optional[torch.LongTensor] = None,
                 use_cache: Optional[bool] = None,
                 output_attentions: Optional[bool] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 **kwargs,
                 ):
        kwargs["inference_mode"] = True

        inputs_embeds, attention_mask = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        model_outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            # max_length=kwargs.get("max_length", 200),
            max_new_tokens=kwargs.get("max_new_tokens", 200),
            num_beams=kwargs.get("num_beams", 4),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            attention_mask=attention_mask,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
        )

        return model_outputs
slam_llm/models/vallex/__init__.py
ADDED
File without changes
slam_llm/models/vallex/activation.py
ADDED
@@ -0,0 +1,179 @@
from typing import Optional, Tuple, List
import math

import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear


class MultiheadAttention(Module):
    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = False

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.k_proj = Linear(self.kdim, embed_dim)
        self.v_proj = Linear(self.kdim, embed_dim)
        self.q_proj = Linear(self.kdim, embed_dim)

        self.out_proj = NonDynamicallyQuantizableLinear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs
        )

        self.add_zero_attn = add_zero_attn
        self.scaling = self.head_dim**-0.5

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Tuple[Tensor, Optional[Tensor]]:

        # T,B,C
        B, T, C = query.size()

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        q *= self.scaling

        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)

        attn_weights = q @ k.transpose(-2, -1)  # B, nh, T, T

        if attn_mask is not None:
            # attn_mask is inf
            # attn_mask = attn_mask.unsqueeze(0)
            # attn_weights += attn_mask
            if torch.is_floating_point(attn_mask):
                # print(attn_weights.size(), attn_mask.size())
                attn_weights += attn_mask.unsqueeze(0).unsqueeze(1)
            else:
                attn_weights = attn_weights.masked_fill(attn_mask, float('-inf'))

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(B, self.num_heads, T, T)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1)
                .unsqueeze(2)
                .to(torch.bool),
                float("-inf"),
            )
        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn = attn_weights_float @ v

        y = attn.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        y = self.out_proj(y)
        return y, attn_weights

    def infer(self,
              x: Tensor,
              key_padding_mask: Optional[Tensor] = None,
              need_weights: bool = True,
              attn_mask: Optional[Tensor] = None,
              average_attn_weights: bool = True,
              past_kv=None,
              use_cache=False):

        # print("debug:"+str(x.size()))

        B, T, C = x.size()

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q *= self.scaling

        # k = k.view(T, B*self.num_heads, self.head_dim).transpose(0, 1) # (B, nh, T, hs)
        # q = q.view(T, B*self.num_heads, self.head_dim).transpose(0, 1) # (B, nh, T, hs)
        # v = v.view(T, B*self.num_heads, self.head_dim).transpose(0, 1) # (B, nh, T, hs)

        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)

        if past_kv is not None:
            past_key = past_kv[0]
            past_value = past_kv[1]
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        FULL_T = k.shape[-2]

        if use_cache is True:
            present = (k, v)
        else:
            present = None

        # print(q.size(), k.size())
        attn_weights = q @ k.transpose(-2, -1)
        # print(attn_mask.size())
        attn_weights = attn_weights.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf'))

        # if key_padding_mask is not None:
        #     # don't attend to padding symbols
        #     attn_weights = attn_weights.view(B, self.num_heads, T, T)
        #     attn_weights = attn_weights.view(B, -1, self.num_heads, T, T)
        #     attn_weights = attn_weights.masked_fill(
        #         key_padding_mask.unsqueeze(1)
        #         .unsqueeze(2)
        #         .unsqueeze(3)
        #         .to(torch.bool),
        #         float("-inf"),
        #     )
        attn_weights_float = F.softmax(attn_weights, dim=-1)
        # attn_weights = attn_weights_float.type_as(attn_weights)
        # attn = torch.bmm(attn_weights, v)
        attn = attn_weights_float @ v

        y = attn.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        y = self.out_proj(y)
        return (y, present)
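A brief sketch of the incremental-decoding path (infer) above, showing how the (k, v) cache grows one token at a time; the embedding size, head count, maximum length, and the boolean causal-mask construction are assumptions made only for this illustration.

import torch

attn = MultiheadAttention(embed_dim=256, num_heads=4)   # class defined in the file above
max_len = 16
# boolean mask: True above the diagonal means "do not attend to future positions"
causal_mask = torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1)

past_kv = None
for step in range(3):
    x = torch.randn(1, 1, 256)                           # one new token per step
    y, past_kv = attn.infer(x, attn_mask=causal_mask, past_kv=past_kv, use_cache=True)
    print(step, y.shape, past_kv[0].shape)               # cached keys grow to [1, 4, step+1, 64]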
slam_llm/models/vallex/scaling.py
ADDED
@@ -0,0 +1,1404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
2 |
+
#
|
3 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
|
18 |
+
import collections
|
19 |
+
import logging
|
20 |
+
import random
|
21 |
+
import math
|
22 |
+
from functools import reduce
|
23 |
+
from itertools import repeat
|
24 |
+
from typing import Optional, Tuple, Union
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
import torch.nn.functional as F
|
29 |
+
from torch import Tensor
|
30 |
+
from torch.nn import Embedding as ScaledEmbedding
|
31 |
+
|
32 |
+
class Transpose(nn.Identity):
|
33 |
+
"""(N, T, D) -> (N, D, T)"""
|
34 |
+
|
35 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
36 |
+
return input.transpose(1, 2)
|
37 |
+
|
38 |
+
class ActivationBalancerFunction(torch.autograd.Function):
|
39 |
+
@staticmethod
|
40 |
+
def forward(
|
41 |
+
ctx,
|
42 |
+
x: Tensor,
|
43 |
+
scale_factor: Tensor,
|
44 |
+
sign_factor: Optional[Tensor],
|
45 |
+
channel_dim: int,
|
46 |
+
) -> Tensor:
|
47 |
+
if channel_dim < 0:
|
48 |
+
channel_dim += x.ndim
|
49 |
+
ctx.channel_dim = channel_dim
|
50 |
+
xgt0 = x > 0
|
51 |
+
if sign_factor is None:
|
52 |
+
ctx.save_for_backward(xgt0, scale_factor)
|
53 |
+
else:
|
54 |
+
ctx.save_for_backward(xgt0, scale_factor, sign_factor)
|
55 |
+
return x
|
56 |
+
|
57 |
+
@staticmethod
|
58 |
+
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
|
59 |
+
if len(ctx.saved_tensors) == 3:
|
60 |
+
xgt0, scale_factor, sign_factor = ctx.saved_tensors
|
61 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
62 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
63 |
+
sign_factor = sign_factor.unsqueeze(-1)
|
64 |
+
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
65 |
+
else:
|
66 |
+
xgt0, scale_factor = ctx.saved_tensors
|
67 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
68 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
69 |
+
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
70 |
+
neg_delta_grad = x_grad.abs() * factor
|
71 |
+
return (
|
72 |
+
x_grad - neg_delta_grad,
|
73 |
+
None,
|
74 |
+
None,
|
75 |
+
None,
|
76 |
+
)
|
77 |
+
|
78 |
+
|
79 |
+
def _compute_scale_factor(
|
80 |
+
x: Tensor,
|
81 |
+
channel_dim: int,
|
82 |
+
min_abs: float,
|
83 |
+
max_abs: float,
|
84 |
+
gain_factor: float,
|
85 |
+
max_factor: float,
|
86 |
+
) -> Tensor:
|
87 |
+
if channel_dim < 0:
|
88 |
+
channel_dim += x.ndim
|
89 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
90 |
+
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
|
91 |
+
|
92 |
+
if min_abs == 0.0:
|
93 |
+
below_threshold = 0.0
|
94 |
+
else:
|
95 |
+
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
|
96 |
+
# x_abs_mean << min_abs.
|
97 |
+
below_threshold = (
|
98 |
+
(min_abs - x_abs_mean) * (gain_factor / min_abs)
|
99 |
+
).clamp(min=0, max=max_factor)
|
100 |
+
|
101 |
+
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
|
102 |
+
min=0, max=max_factor
|
103 |
+
)
|
104 |
+
|
105 |
+
return below_threshold - above_threshold
|
106 |
+
|
107 |
+
|
108 |
+
def _compute_sign_factor(
|
109 |
+
x: Tensor,
|
110 |
+
channel_dim: int,
|
111 |
+
min_positive: float,
|
112 |
+
max_positive: float,
|
113 |
+
gain_factor: float,
|
114 |
+
max_factor: float,
|
115 |
+
) -> Tensor:
|
116 |
+
if channel_dim < 0:
|
117 |
+
channel_dim += x.ndim
|
118 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
119 |
+
proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
|
120 |
+
if min_positive == 0.0:
|
121 |
+
factor1 = 0.0
|
122 |
+
else:
|
123 |
+
# 0 if proportion_positive >= min_positive, else can be
|
124 |
+
# as large as max_factor.
|
125 |
+
factor1 = (
|
126 |
+
(min_positive - proportion_positive) * (gain_factor / min_positive)
|
127 |
+
).clamp_(min=0, max=max_factor)
|
128 |
+
|
129 |
+
if max_positive == 1.0:
|
130 |
+
factor2 = 0.0
|
131 |
+
else:
|
132 |
+
# 0 if proportion_positive <= max_positive, else can be
|
133 |
+
# as large as -max_factor.
|
134 |
+
factor2 = (
|
135 |
+
(proportion_positive - max_positive)
|
136 |
+
* (gain_factor / (1.0 - max_positive))
|
137 |
+
).clamp_(min=0, max=max_factor)
|
138 |
+
sign_factor = factor1 - factor2
|
139 |
+
# require min_positive != 0 or max_positive != 1:
|
140 |
+
assert not isinstance(sign_factor, float)
|
141 |
+
return sign_factor
|
142 |
+
|
143 |
+
|
144 |
+
class ActivationScaleBalancerFunction(torch.autograd.Function):
|
145 |
+
"""
|
146 |
+
This object is used in class ActivationBalancer when the user specified
|
147 |
+
min_positive=0, max_positive=1, so there are no constraints on the signs
|
148 |
+
of the activations and only the absolute value has a constraint.
|
149 |
+
"""
|
150 |
+
|
151 |
+
@staticmethod
|
152 |
+
def forward(
|
153 |
+
ctx,
|
154 |
+
x: Tensor,
|
155 |
+
sign_factor: Tensor,
|
156 |
+
scale_factor: Tensor,
|
157 |
+
channel_dim: int,
|
158 |
+
) -> Tensor:
|
159 |
+
if channel_dim < 0:
|
160 |
+
channel_dim += x.ndim
|
161 |
+
ctx.channel_dim = channel_dim
|
162 |
+
xgt0 = x > 0
|
163 |
+
ctx.save_for_backward(xgt0, sign_factor, scale_factor)
|
164 |
+
return x
|
165 |
+
|
166 |
+
@staticmethod
|
167 |
+
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
|
168 |
+
xgt0, sign_factor, scale_factor = ctx.saved_tensors
|
169 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
170 |
+
sign_factor = sign_factor.unsqueeze(-1)
|
171 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
172 |
+
|
173 |
+
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
174 |
+
neg_delta_grad = x_grad.abs() * factor
|
175 |
+
return (
|
176 |
+
x_grad - neg_delta_grad,
|
177 |
+
None,
|
178 |
+
None,
|
179 |
+
None,
|
180 |
+
)
|
181 |
+
|
182 |
+
|
183 |
+
class RandomClampFunction(torch.autograd.Function):
|
184 |
+
@staticmethod
|
185 |
+
def forward(
|
186 |
+
ctx,
|
187 |
+
x: Tensor,
|
188 |
+
min: Optional[float],
|
189 |
+
max: Optional[float],
|
190 |
+
prob: float,
|
191 |
+
reflect: float,
|
192 |
+
) -> Tensor:
|
193 |
+
x_clamped = torch.clamp(x, min=min, max=max)
|
194 |
+
mask = torch.rand_like(x) < prob
|
195 |
+
ans = torch.where(mask, x_clamped, x)
|
196 |
+
if x.requires_grad:
|
197 |
+
ctx.save_for_backward(ans == x)
|
198 |
+
ctx.reflect = reflect
|
199 |
+
if reflect != 0.0:
|
200 |
+
ans = ans * (1.0 + reflect) - (x * reflect)
|
201 |
+
return ans
|
202 |
+
|
203 |
+
@staticmethod
|
204 |
+
def backward(
|
205 |
+
ctx, ans_grad: Tensor
|
206 |
+
) -> Tuple[Tensor, None, None, None, None]:
|
207 |
+
(is_same,) = ctx.saved_tensors
|
208 |
+
x_grad = ans_grad * is_same.to(ans_grad.dtype)
|
209 |
+
reflect = ctx.reflect
|
210 |
+
if reflect != 0.0:
|
211 |
+
x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
|
212 |
+
return x_grad, None, None, None, None
|
213 |
+
|
214 |
+
|
215 |
+
def random_clamp(
|
216 |
+
x: Tensor,
|
217 |
+
min: Optional[float] = None,
|
218 |
+
max: Optional[float] = None,
|
219 |
+
prob: float = 0.5,
|
220 |
+
reflect: float = 0.0,
|
221 |
+
):
|
222 |
+
return RandomClampFunction.apply(x, min, max, prob, reflect)
|
223 |
+
|
224 |
+
|
225 |
+
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
|
226 |
+
"""
|
227 |
+
A randomized way of casting a floating point value to half precision.
|
228 |
+
"""
|
229 |
+
if x.dtype == torch.float16:
|
230 |
+
return x
|
231 |
+
x_abs = x.abs()
|
232 |
+
is_too_small = x_abs < min_abs
|
233 |
+
# for elements where is_too_small is true, random_val will contain +-min_abs with
|
234 |
+
# probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
|
235 |
+
# for those elements].
|
236 |
+
random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
|
237 |
+
return torch.where(is_too_small, random_val, x).to(torch.float16)
|
238 |
+
|
239 |
+
|
240 |
+
class RandomGradFunction(torch.autograd.Function):
|
241 |
+
"""
|
242 |
+
Does nothing in forward pass; in backward pass, gets rid of very small grads using
|
243 |
+
a randomized approach that preserves expectations (intended to reduce roundoff).
|
244 |
+
"""
|
245 |
+
|
246 |
+
@staticmethod
|
247 |
+
def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
|
248 |
+
ctx.min_abs = min_abs
|
249 |
+
return x
|
250 |
+
|
251 |
+
@staticmethod
|
252 |
+
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
|
253 |
+
if ans_grad.dtype == torch.float16:
|
254 |
+
return (
|
255 |
+
random_cast_to_half(
|
256 |
+
ans_grad.to(torch.float32), min_abs=ctx.min_abs
|
257 |
+
),
|
258 |
+
None,
|
259 |
+
)
|
260 |
+
else:
|
261 |
+
return ans_grad, None
|
262 |
+
|
263 |
+
|
264 |
+
class RandomGrad(torch.nn.Module):
|
265 |
+
"""
|
266 |
+
Gets rid of very small gradients using an expectation-preserving method, intended to increase
|
267 |
+
accuracy of training when using amp (automatic mixed precision)
|
268 |
+
"""
|
269 |
+
|
270 |
+
def __init__(self, min_abs: float = 5.0e-06):
|
271 |
+
super(RandomGrad, self).__init__()
|
272 |
+
self.min_abs = min_abs
|
273 |
+
|
274 |
+
def forward(self, x: Tensor):
|
275 |
+
if (
|
276 |
+
torch.jit.is_scripting()
|
277 |
+
or not self.training
|
278 |
+
or torch.jit.is_tracing()
|
279 |
+
):
|
280 |
+
return x
|
281 |
+
else:
|
282 |
+
return RandomGradFunction.apply(x, self.min_abs)
|
283 |
+
|
284 |
+
|
285 |
+
class SoftmaxFunction(torch.autograd.Function):
|
286 |
+
"""
|
287 |
+
Tries to handle half-precision derivatives in a randomized way that should
|
288 |
+
be more accurate for training than the default behavior.
|
289 |
+
"""
|
290 |
+
|
291 |
+
@staticmethod
|
292 |
+
def forward(ctx, x: Tensor, dim: int):
|
293 |
+
ans = x.softmax(dim=dim)
|
294 |
+
# if x dtype is float16, x.softmax() returns a float32 because
|
295 |
+
# (presumably) that op does not support float16, and autocast
|
296 |
+
# is enabled.
|
297 |
+
if torch.is_autocast_enabled():
|
298 |
+
ans = ans.to(torch.float16)
|
299 |
+
ctx.save_for_backward(ans)
|
300 |
+
ctx.x_dtype = x.dtype
|
301 |
+
ctx.dim = dim
|
302 |
+
return ans
|
303 |
+
|
304 |
+
@staticmethod
|
305 |
+
def backward(ctx, ans_grad: Tensor):
|
306 |
+
(ans,) = ctx.saved_tensors
|
307 |
+
with torch.cuda.amp.autocast(enabled=False):
|
308 |
+
ans_grad = ans_grad.to(torch.float32)
|
309 |
+
ans = ans.to(torch.float32)
|
310 |
+
x_grad = ans_grad * ans
|
311 |
+
x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
|
312 |
+
return x_grad, None
|
313 |
+
|
314 |
+
|
315 |
+
def softmax(x: Tensor, dim: int):
|
316 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
317 |
+
return x.softmax(dim)
|
318 |
+
|
319 |
+
return SoftmaxFunction.apply(x, dim)
|
320 |
+
|
321 |
+
|
322 |
+
class MaxEigLimiterFunction(torch.autograd.Function):
|
323 |
+
@staticmethod
|
324 |
+
def forward(
|
325 |
+
ctx,
|
326 |
+
x: Tensor,
|
327 |
+
coeffs: Tensor,
|
328 |
+
direction: Tensor,
|
329 |
+
channel_dim: int,
|
330 |
+
grad_scale: float,
|
331 |
+
) -> Tensor:
|
332 |
+
ctx.channel_dim = channel_dim
|
333 |
+
ctx.grad_scale = grad_scale
|
334 |
+
ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
|
335 |
+
return x
|
336 |
+
|
337 |
+
@staticmethod
|
338 |
+
def backward(ctx, x_grad, *args):
|
339 |
+
with torch.enable_grad():
|
340 |
+
(x_orig, coeffs, new_direction) = ctx.saved_tensors
|
341 |
+
x_orig.requires_grad = True
|
342 |
+
num_channels = x_orig.shape[ctx.channel_dim]
|
343 |
+
x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
|
344 |
+
new_direction.requires_grad = False
|
345 |
+
x = x - x.mean(dim=0)
|
346 |
+
x_var = (x ** 2).mean()
|
347 |
+
x_residual = x - coeffs * new_direction
|
348 |
+
x_residual_var = (x_residual ** 2).mean()
|
349 |
+
# `variance_proportion` is the proportion of the variance accounted for
|
350 |
+
# by the top eigen-direction. This is to be minimized.
|
351 |
+
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
|
352 |
+
variance_proportion.backward()
|
353 |
+
x_orig_grad = x_orig.grad
|
354 |
+
x_extra_grad = (
|
355 |
+
x_orig.grad
|
356 |
+
* ctx.grad_scale
|
357 |
+
* x_grad.norm()
|
358 |
+
/ (x_orig_grad.norm() + 1.0e-20)
|
359 |
+
)
|
360 |
+
return x_grad + x_extra_grad.detach(), None, None, None, None
|
361 |
+
|
362 |
+
|
363 |
+
class BasicNorm(torch.nn.Module):
|
364 |
+
"""
|
365 |
+
This is intended to be a simpler, and hopefully cheaper, replacement for
|
366 |
+
LayerNorm. The observation this is based on, is that Transformer-type
|
367 |
+
networks, especially with pre-norm, sometimes seem to set one of the
|
368 |
+
feature dimensions to a large constant value (e.g. 50), which "defeats"
|
369 |
+
the LayerNorm because the output magnitude is then not strongly dependent
|
370 |
+
on the other (useful) features. Presumably the weight and bias of the
|
371 |
+
LayerNorm are required to allow it to do this.
|
372 |
+
|
373 |
+
So the idea is to introduce this large constant value as an explicit
|
374 |
+
parameter, that takes the role of the "eps" in LayerNorm, so the network
|
375 |
+
doesn't have to do this trick. We make the "eps" learnable.
|
376 |
+
|
377 |
+
Args:
|
378 |
+
num_channels: the number of channels, e.g. 512.
|
379 |
+
channel_dim: the axis/dimension corresponding to the channel,
|
380 |
+
interpreted as an offset from the input's ndim if negative.
|
381 |
+
this is NOT the num_channels; it should typically be one of
|
382 |
+
{-2, -1, 0, 1, 2, 3}.
|
383 |
+
eps: the initial "epsilon" that we add as ballast in:
|
384 |
+
scale = ((input_vec**2).mean() + epsilon)**-0.5
|
385 |
+
Note: our epsilon is actually large, but we keep the name
|
386 |
+
to indicate the connection with conventional LayerNorm.
|
387 |
+
learn_eps: if true, we learn epsilon; if false, we keep it
|
388 |
+
at the initial value.
|
389 |
+
eps_min: float
|
390 |
+
eps_max: float
|
391 |
+
"""
|
392 |
+
|
393 |
+
def __init__(
|
394 |
+
self,
|
395 |
+
num_channels: int,
|
396 |
+
channel_dim: int = -1, # CAUTION: see documentation.
|
397 |
+
eps: float = 0.25,
|
398 |
+
learn_eps: bool = True,
|
399 |
+
eps_min: float = -3.0,
|
400 |
+
eps_max: float = 3.0,
|
401 |
+
) -> None:
|
402 |
+
super(BasicNorm, self).__init__()
|
403 |
+
self.num_channels = num_channels
|
404 |
+
self.channel_dim = channel_dim
|
405 |
+
if learn_eps:
|
406 |
+
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
|
407 |
+
else:
|
408 |
+
self.register_buffer("eps", torch.tensor(eps).log().detach())
|
409 |
+
self.eps_min = eps_min
|
410 |
+
self.eps_max = eps_max
|
411 |
+
|
412 |
+
def forward(self, x: Tensor) -> Tensor:
|
413 |
+
assert x.shape[self.channel_dim] == self.num_channels
|
414 |
+
eps = self.eps
|
415 |
+
if self.training and random.random() < 0.25:
|
416 |
+
# with probability 0.25, in training mode, clamp eps between the min
|
417 |
+
# and max; this will encourage it to learn parameters within the
|
418 |
+
# allowed range by making parameters that are outside the allowed
|
419 |
+
# range noisy.
|
420 |
+
|
421 |
+
# On the unclamped iterations, gradients still allow the parameter to get back into the allowed
|
422 |
+
# region if it happens to exit it.
|
423 |
+
eps = eps.clamp(min=self.eps_min, max=self.eps_max)
|
424 |
+
scales = (
|
425 |
+
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
|
426 |
+
) ** -0.5
|
427 |
+
return x * scales
|
428 |
+
|
429 |
+
|
430 |
+
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
|
431 |
+
"""
|
432 |
+
Behaves like a constructor of a modified version of nn.Linear
|
433 |
+
that gives an easy way to set the default initial parameter scale.
|
434 |
+
|
435 |
+
Args:
|
436 |
+
Accepts the standard args and kwargs that nn.Linear accepts
|
437 |
+
e.g. in_features, out_features, bias=False.
|
438 |
+
|
439 |
+
initial_scale: you can override this if you want to increase
|
440 |
+
or decrease the initial magnitude of the module's output
|
441 |
+
(it rescales the initial weight and bias values).
|
442 |
+
Another option, if you want to do something like this, is
|
443 |
+
to re-initialize the parameters.
|
444 |
+
"""
|
445 |
+
ans = nn.Linear(*args, **kwargs)
|
446 |
+
with torch.no_grad():
|
447 |
+
ans.weight[:] *= initial_scale
|
448 |
+
if ans.bias is not None:
|
449 |
+
torch.nn.init.uniform_(
|
450 |
+
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
|
451 |
+
)
|
452 |
+
return ans
|
453 |
+
|
454 |
+
|
455 |
+
def ScaledConv1d(
|
456 |
+
*args,
|
457 |
+
initial_scale: float = 1.0,
|
458 |
+
kernel_size: int = 3,
|
459 |
+
padding: str = "same",
|
460 |
+
**kwargs,
|
461 |
+
) -> nn.Conv1d:
|
462 |
+
"""
|
463 |
+
Behaves like a constructor of a modified version of nn.Conv1d
|
464 |
+
that gives an easy way to set the default initial parameter scale.
|
465 |
+
|
466 |
+
Args:
|
467 |
+
Accepts the standard args and kwargs that nn.Conv1d accepts
|
468 |
+
e.g. in_channels, out_channels, bias=False.
|
469 |
+
|
470 |
+
initial_scale: you can override this if you want to increase
|
471 |
+
or decrease the initial magnitude of the module's output
|
472 |
+
(it rescales the initial weight and bias values).
|
473 |
+
Another option, if you want to do something like this, is
|
474 |
+
to re-initialize the parameters.
|
475 |
+
"""
|
476 |
+
ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
|
477 |
+
with torch.no_grad():
|
478 |
+
ans.weight[:] *= initial_scale
|
479 |
+
if ans.bias is not None:
|
480 |
+
torch.nn.init.uniform_(
|
481 |
+
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
|
482 |
+
)
|
483 |
+
return ans
|
484 |
+
|
485 |
+
|
486 |
+
def TransposeScaledConv1d(
|
487 |
+
*args,
|
488 |
+
initial_scale: float = 1.0,
|
489 |
+
kernel_size: int = 3,
|
490 |
+
padding: str = "same",
|
491 |
+
**kwargs,
|
492 |
+
) -> nn.Sequential:
|
493 |
+
"""
|
494 |
+
Transpose -> ScaledConv1d
|
495 |
+
"""
|
496 |
+
return nn.Sequential(
|
497 |
+
Transpose(),
|
498 |
+
ScaledConv1d(
|
499 |
+
*args,
|
500 |
+
initial_scale=initial_scale,
|
501 |
+
kernel_size=kernel_size,
|
502 |
+
padding=padding,
|
503 |
+
**kwargs,
|
504 |
+
),
|
505 |
+
)
|
506 |
+
|
507 |
+
|
508 |
+
def ScaledConv1dTranspose(
|
509 |
+
*args,
|
510 |
+
initial_scale: float = 1.0,
|
511 |
+
kernel_size: int = 3,
|
512 |
+
padding: str = "same",
|
513 |
+
**kwargs,
|
514 |
+
) -> nn.Sequential:
|
515 |
+
"""
|
516 |
+
ScaledConv1d -> Transpose
|
517 |
+
"""
|
518 |
+
return nn.Sequential(
|
519 |
+
ScaledConv1d(
|
520 |
+
*args,
|
521 |
+
initial_scale=initial_scale,
|
522 |
+
kernel_size=kernel_size,
|
523 |
+
padding=padding,
|
524 |
+
**kwargs,
|
525 |
+
),
|
526 |
+
Transpose(),
|
527 |
+
)
|
528 |
+
|
529 |
+
|
530 |
+
def TransposeConv1d(
|
531 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
532 |
+
) -> nn.Sequential:
|
533 |
+
"""
|
534 |
+
Transpose -> Conv1d
|
535 |
+
"""
|
536 |
+
return nn.Sequential(
|
537 |
+
Transpose(),
|
538 |
+
nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
539 |
+
)
|
540 |
+
|
541 |
+
|
542 |
+
def Conv1dTranspose(
|
543 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
544 |
+
) -> nn.Sequential:
|
545 |
+
"""
|
546 |
+
Conv1d -> Transpose
|
547 |
+
"""
|
548 |
+
return nn.Sequential(
|
549 |
+
nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
550 |
+
Transpose(),
|
551 |
+
)
|
552 |
+
|
553 |
+
|
554 |
+
class SRLinear(nn.Linear):
|
555 |
+
"""https://arxiv.org/abs/2303.06296
|
556 |
+
Stabilizing Transformer Training by Preventing Attention Entropy Collapse
|
557 |
+
"""
|
558 |
+
|
559 |
+
def __init__(self, in_features, out_features, bias=True, **kwargs):
|
560 |
+
super().__init__(in_features, out_features, bias=bias, **kwargs)
|
561 |
+
self.register_buffer(
|
562 |
+
"u", nn.functional.normalize(torch.randn(in_features), dim=0)
|
563 |
+
)
|
564 |
+
with torch.no_grad():
|
565 |
+
sigma = self.get_sigma()
|
566 |
+
self.register_buffer("spectral_norm", sigma)
|
567 |
+
self.sigma = nn.Parameter(torch.ones(1))
|
568 |
+
|
569 |
+
def get_sigma(self):
|
570 |
+
with torch.no_grad():
|
571 |
+
u = self.u
|
572 |
+
v = self.weight.mv(u)
|
573 |
+
v = nn.functional.normalize(v, dim=0)
|
574 |
+
u = self.weight.T.mv(v)
|
575 |
+
u = nn.functional.normalize(u, dim=0)
|
576 |
+
self.u.data.copy_(u)
|
577 |
+
return torch.einsum("c,cd,d->", v, self.weight, u)
|
578 |
+
|
579 |
+
def get_weight(self):
|
580 |
+
sigma = self.get_sigma()
|
581 |
+
if self.training:
|
582 |
+
self.spectral_norm.data.copy_(sigma)
|
583 |
+
weight = (self.sigma / sigma) * self.weight
|
584 |
+
return weight
|
585 |
+
|
586 |
+
def forward(self, x):
|
587 |
+
return nn.functional.linear(x, self.get_weight(), self.bias)
|
588 |
+
|
589 |
+
|
590 |
+
class SRConv1d(SRLinear):
|
591 |
+
def __init__(
|
592 |
+
self,
|
593 |
+
in_features,
|
594 |
+
out_features,
|
595 |
+
kernel_size,
|
596 |
+
stride: int = 1,
|
597 |
+
padding: str = "same",
|
598 |
+
bias: bool = True,
|
599 |
+
**kwargs,
|
600 |
+
):
|
601 |
+
in_features = in_features * kernel_size
|
602 |
+
super().__init__(in_features, out_features, bias=bias, **kwargs)
|
603 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
604 |
+
self.kernel_size = kernel_size
|
605 |
+
self.stride = stride
|
606 |
+
self.padding = padding
|
607 |
+
|
608 |
+
def forward(self, x):
|
609 |
+
in_features = self.in_features // self.kernel_size
|
610 |
+
weight = self.get_weight().view(
|
611 |
+
self.out_features, in_features, self.kernel_size
|
612 |
+
)
|
613 |
+
return nn.functional.conv1d(
|
614 |
+
x, weight, bias=self.bias, stride=self.stride, padding=self.padding
|
615 |
+
)
|
616 |
+
|
617 |
+
|
618 |
+
def TransposeSRConv1d(
|
619 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
620 |
+
) -> nn.Sequential:
|
621 |
+
"""
|
622 |
+
Transpose -> SRConv1d
|
623 |
+
"""
|
624 |
+
return nn.Sequential(
|
625 |
+
Transpose(),
|
626 |
+
SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
627 |
+
)
|
628 |
+
|
629 |
+
|
630 |
+
def SRConv1dTranspose(
|
631 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
632 |
+
) -> nn.Sequential:
|
633 |
+
"""
|
634 |
+
SRConv1d -> Transpose
|
635 |
+
"""
|
636 |
+
return nn.Sequential(
|
637 |
+
SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
638 |
+
Transpose(),
|
639 |
+
)
|
640 |
+
|
641 |
+
|
642 |
+
class ActivationBalancer(torch.nn.Module):
|
643 |
+
"""
|
644 |
+
Modifies the backpropped derivatives of a function to try to encourage, for
|
645 |
+
each channel, that it is positive at least a proportion `threshold` of the
|
646 |
+
time. It does this by multiplying negative derivative values by up to
|
647 |
+
(1+max_factor), and positive derivative values by up to (1-max_factor),
|
648 |
+
interpolated from 1 at the threshold to those extremal values when none
|
649 |
+
of the inputs are positive.
|
650 |
+
|
651 |
+
Args:
|
652 |
+
num_channels: the number of channels
|
653 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
654 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
655 |
+
min_positive: the minimum, per channel, of the proportion of the time
|
656 |
+
that (x > 0), below which we start to modify the derivatives.
|
657 |
+
max_positive: the maximum, per channel, of the proportion of the time
|
658 |
+
that (x > 0), above which we start to modify the derivatives.
|
659 |
+
max_factor: the maximum factor by which we modify the derivatives for
|
660 |
+
either the sign constraint or the magnitude constraint;
|
661 |
+
e.g. with max_factor=0.02, the derivatives would be multiplied by
|
662 |
+
values in the range [0.98..1.02].
|
663 |
+
sign_gain_factor: determines the 'gain' with which we increase the
|
664 |
+
change in gradient once the constraints on min_positive and max_positive
|
665 |
+
are violated.
|
666 |
+
scale_gain_factor: determines the 'gain' with which we increase the
|
667 |
+
change in gradient once the constraints on min_abs and max_abs
|
668 |
+
are violated.
|
669 |
+
min_abs: the minimum average-absolute-value difference from the mean
|
670 |
+
value per channel, which we allow, before we start to modify
|
671 |
+
the derivatives to prevent this.
|
672 |
+
max_abs: the maximum average-absolute-value difference from the mean
|
673 |
+
value per channel, which we allow, before we start to modify
|
674 |
+
the derivatives to prevent this.
|
675 |
+
min_prob: determines the minimum probability with which we modify the
|
676 |
+
gradients for the {min,max}_positive and {min,max}_abs constraints,
|
677 |
+
on each forward(). This is done randomly to prevent all layers
|
678 |
+
from doing it at the same time. Early in training we may use
|
679 |
+
higher probabilities than this; it will decay to this value.
|
680 |
+
"""
|
681 |
+
|
682 |
+
def __init__(
|
683 |
+
self,
|
684 |
+
num_channels: int,
|
685 |
+
channel_dim: int,
|
686 |
+
min_positive: float = 0.05,
|
687 |
+
max_positive: float = 0.95,
|
688 |
+
max_factor: float = 0.04,
|
689 |
+
sign_gain_factor: float = 0.01,
|
690 |
+
scale_gain_factor: float = 0.02,
|
691 |
+
min_abs: float = 0.2,
|
692 |
+
max_abs: float = 100.0,
|
693 |
+
min_prob: float = 0.1,
|
694 |
+
):
|
695 |
+
super(ActivationBalancer, self).__init__()
|
696 |
+
self.num_channels = num_channels
|
697 |
+
self.channel_dim = channel_dim
|
698 |
+
self.min_positive = min_positive
|
699 |
+
self.max_positive = max_positive
|
700 |
+
self.max_factor = max_factor
|
701 |
+
self.min_abs = min_abs
|
702 |
+
self.max_abs = max_abs
|
703 |
+
self.min_prob = min_prob
|
704 |
+
self.sign_gain_factor = sign_gain_factor
|
705 |
+
self.scale_gain_factor = scale_gain_factor
|
706 |
+
|
707 |
+
# count measures how many times the forward() function has been called.
|
708 |
+
# We occasionally sync this to a tensor called `count`, that exists to
|
709 |
+
# make sure it is synced to disk when we load and save the model.
|
710 |
+
self.cpu_count = 0
|
711 |
+
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
|
712 |
+
|
713 |
+
def forward(self, x: Tensor) -> Tensor:
|
714 |
+
if (
|
715 |
+
torch.jit.is_scripting()
|
716 |
+
or not x.requires_grad
|
717 |
+
or torch.jit.is_tracing()
|
718 |
+
):
|
719 |
+
return _no_op(x)
|
720 |
+
|
721 |
+
count = self.cpu_count
|
722 |
+
self.cpu_count += 1
|
723 |
+
|
724 |
+
if random.random() < 0.01:
|
725 |
+
# Occasionally sync self.cpu_count with self.count.
|
726 |
+
# count affects the decay of 'prob'. don't do this on every iter,
|
727 |
+
# because syncing with the GPU is slow.
|
728 |
+
self.cpu_count = max(self.cpu_count, self.count.item())
|
729 |
+
self.count.fill_(self.cpu_count)
|
730 |
+
|
731 |
+
# the prob of doing some work exponentially decreases from 0.5 till it hits
|
732 |
+
# a floor at min_prob (==0.1, by default)
|
733 |
+
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
|
734 |
+
|
735 |
+
if random.random() < prob:
|
736 |
+
sign_gain_factor = 0.5
|
737 |
+
if self.min_positive != 0.0 or self.max_positive != 1.0:
|
738 |
+
sign_factor = _compute_sign_factor(
|
739 |
+
x,
|
740 |
+
self.channel_dim,
|
741 |
+
self.min_positive,
|
742 |
+
self.max_positive,
|
743 |
+
gain_factor=self.sign_gain_factor / prob,
|
744 |
+
max_factor=self.max_factor,
|
745 |
+
)
|
746 |
+
else:
|
747 |
+
sign_factor = None
|
748 |
+
|
749 |
+
scale_factor = _compute_scale_factor(
|
750 |
+
x.detach(),
|
751 |
+
self.channel_dim,
|
752 |
+
min_abs=self.min_abs,
|
753 |
+
max_abs=self.max_abs,
|
754 |
+
gain_factor=self.scale_gain_factor / prob,
|
755 |
+
max_factor=self.max_factor,
|
756 |
+
)
|
757 |
+
return ActivationBalancerFunction.apply(
|
758 |
+
x,
|
759 |
+
scale_factor,
|
760 |
+
sign_factor,
|
761 |
+
self.channel_dim,
|
762 |
+
)
|
763 |
+
else:
|
764 |
+
return _no_op(x)
|
765 |
+
|
766 |
+
|
767 |
+
def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
|
768 |
+
"""
|
769 |
+
Returns x unmodified, but in backprop will put a penalty for the excess of
|
770 |
+
the absolute values of elements of x over the limit "limit". E.g. if
|
771 |
+
limit == 10.0, then if x has any values over 10 it will get a penalty.
|
772 |
+
|
773 |
+
Caution: the value of this penalty will be affected by grad scaling used
|
774 |
+
in automatic mixed precision training. For this reasons we use this,
|
775 |
+
it shouldn't really matter, or may even be helpful; we just use this
|
776 |
+
to disallow really implausible values of scores to be given to softmax.
|
777 |
+
"""
|
778 |
+
x_sign = x.sign()
|
779 |
+
over_limit = (x.abs() - limit) > 0
|
780 |
+
# The following is a memory efficient way to penalize the absolute values of
|
781 |
+
# x that's over the limit. (The memory efficiency comes when you think
|
782 |
+
# about which items torch needs to cache for the autograd, and which ones it
|
783 |
+
# can throw away). The numerical value of aux_loss as computed here will
|
784 |
+
# actually be larger than it should be, by limit * over_limit.sum(), but it
|
785 |
+
# has the same derivative as the real aux_loss which is penalty * (x.abs() -
|
786 |
+
# limit).relu().
|
787 |
+
aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
|
788 |
+
# note: we don't do sum() here on aux_loss, but it's as if we had done
|
789 |
+
# sum() due to how with_loss() works.
|
790 |
+
x = with_loss(x, aux_loss)
|
791 |
+
# you must use x for something, or this will be ineffective.
|
792 |
+
return x
|
793 |
+
|
794 |
+
|
795 |
+
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
|
796 |
+
if x.ndim == 2:
|
797 |
+
return x.diag()
|
798 |
+
else:
|
799 |
+
(batch, dim, dim) = x.shape
|
800 |
+
x = x.reshape(batch, dim * dim)
|
801 |
+
x = x[:, :: dim + 1]
|
802 |
+
assert x.shape == (batch, dim)
|
803 |
+
return x
|
804 |
+
|
805 |
+
|
806 |
+
def _whitening_metric(x: Tensor, num_groups: int):
|
807 |
+
"""
|
808 |
+
Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
|
809 |
+
of the centered feature covariance are the same within each group's covariance matrix
|
810 |
+
and also between groups.
|
811 |
+
Args:
|
812 |
+
x: a Tensor of shape (*, num_channels)
|
813 |
+
num_groups: the number of groups of channels, a number >=1 that divides num_channels
|
814 |
+
Returns:
|
815 |
+
Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
|
816 |
+
greater than 1.0 otherwise.
|
817 |
+
"""
|
818 |
+
assert x.dtype != torch.float16
|
819 |
+
x = x.reshape(-1, x.shape[-1])
|
820 |
+
(num_frames, num_channels) = x.shape
|
821 |
+
assert num_channels % num_groups == 0
|
822 |
+
channels_per_group = num_channels // num_groups
|
823 |
+
x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
|
824 |
+
# x now has shape (num_groups, num_frames, channels_per_group)
|
825 |
+
# subtract the mean so we use the centered, not uncentered, covariance.
|
826 |
+
# My experience has been that when we "mess with the gradients" like this,
|
827 |
+
# it's better not do anything that tries to move the mean around, because
|
828 |
+
# that can easily cause instability.
|
829 |
+
x = x - x.mean(dim=1, keepdim=True)
|
830 |
+
# x_covar: (num_groups, channels_per_group, channels_per_group)
|
831 |
+
x_covar = torch.matmul(x.transpose(1, 2), x)
|
832 |
+
x_covar_mean_diag = _diag(x_covar).mean()
|
833 |
+
# the following expression is what we'd get if we took the matrix product
|
834 |
+
# of each covariance and measured the mean of its trace, i.e.
|
835 |
+
# the same as _diag(torch.matmul(x_covar, x_covar)).mean().
|
836 |
+
x_covarsq_mean_diag = (x_covar ** 2).sum() / (
|
837 |
+
num_groups * channels_per_group
|
838 |
+
)
|
839 |
+
# this metric will be >= 1.0; the larger it is, the less 'white' the data was.
|
840 |
+
metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
|
841 |
+
return metric
|
842 |
+
|
843 |
+
|
844 |
+
class WhiteningPenaltyFunction(torch.autograd.Function):
|
845 |
+
@staticmethod
|
846 |
+
def forward(
|
847 |
+
ctx,
|
848 |
+
x: Tensor,
|
849 |
+
num_groups: int,
|
850 |
+
whitening_limit: float,
|
851 |
+
grad_scale: float,
|
852 |
+
) -> Tensor:
|
853 |
+
ctx.save_for_backward(x)
|
854 |
+
ctx.num_groups = num_groups
|
855 |
+
ctx.whitening_limit = whitening_limit
|
856 |
+
ctx.grad_scale = grad_scale
|
857 |
+
return x
|
858 |
+
|
859 |
+
@staticmethod
|
860 |
+
def backward(ctx, x_grad: Tensor):
|
861 |
+
(x_orig,) = ctx.saved_tensors
|
862 |
+
with torch.enable_grad():
|
863 |
+
with torch.cuda.amp.autocast(enabled=False):
|
864 |
+
x_detached = x_orig.to(torch.float32).detach()
|
865 |
+
x_detached.requires_grad = True
|
866 |
+
|
867 |
+
metric = _whitening_metric(x_detached, ctx.num_groups)
|
868 |
+
|
869 |
+
if random.random() < 0.005 or __name__ == "__main__":
|
870 |
+
logging.info(
|
871 |
+
f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
|
872 |
+
f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
|
873 |
+
)
|
874 |
+
|
875 |
+
(metric - ctx.whitening_limit).relu().backward()
|
876 |
+
penalty_grad = x_detached.grad
|
877 |
+
scale = ctx.grad_scale * (
|
878 |
+
x_grad.to(torch.float32).norm()
|
879 |
+
/ (penalty_grad.norm() + 1.0e-20)
|
880 |
+
)
|
881 |
+
penalty_grad = penalty_grad * scale
|
882 |
+
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
|
883 |
+
|
884 |
+
|
885 |
+
class Whiten(nn.Module):
|
886 |
+
def __init__(
|
887 |
+
self,
|
888 |
+
num_groups: int,
|
889 |
+
whitening_limit: float,
|
890 |
+
prob: Union[float, Tuple[float, float]],
|
891 |
+
grad_scale: float,
|
892 |
+
):
|
893 |
+
"""
|
894 |
+
Args:
|
895 |
+
num_groups: the number of groups to divide the channel dim into before
|
896 |
+
whitening. We will attempt to make the feature covariance
|
897 |
+
within each group, after mean subtraction, as "white" as possible,
|
898 |
+
while having the same trace across all groups.
|
899 |
+
whitening_limit: a value greater than 1.0, that dictates how much
|
900 |
+
freedom we have to violate the constraints. 1.0 would mean perfectly
|
901 |
+
white, with exactly the same trace across groups; larger values
|
902 |
+
give more freedom. E.g. 2.0.
|
903 |
+
prob: the probability with which we apply the gradient modification
|
904 |
+
(also affects the grad scale). May be supplied as a float,
|
905 |
+
or as a pair (min_prob, max_prob)
|
906 |
+
|
907 |
+
grad_scale: determines the scale on the gradient term from this object,
|
908 |
+
relative to the rest of the gradient on the attention weights.
|
909 |
+
E.g. 0.02 (you may want to use smaller values than this if prob is large)
|
910 |
+
"""
|
911 |
+
super(Whiten, self).__init__()
|
912 |
+
assert num_groups >= 1
|
913 |
+
assert whitening_limit >= 1
|
914 |
+
assert grad_scale >= 0
|
915 |
+
self.num_groups = num_groups
|
916 |
+
self.whitening_limit = whitening_limit
|
917 |
+
if isinstance(prob, float):
|
918 |
+
assert 0 < prob <= 1
|
919 |
+
self.prob = prob
|
920 |
+
else:
|
921 |
+
(self.min_prob, self.max_prob) = prob
|
922 |
+
assert 0 < self.min_prob < self.max_prob <= 1
|
923 |
+
self.prob = self.max_prob
|
924 |
+
|
925 |
+
self.grad_scale = grad_scale
|
926 |
+
|
927 |
+
def forward(self, x: Tensor) -> Tensor:
|
928 |
+
"""
|
929 |
+
In the forward pass, this function just returns the input unmodified.
|
930 |
+
In the backward pass, it will modify the gradients to ensure that the
|
931 |
+
distribution in each group has close to (lambda times I) as the covariance
|
932 |
+
after mean subtraction, with the same lambda across groups.
|
933 |
+
For whitening_limit > 1, there will be more freedom to violate this
|
934 |
+
constraint.
|
935 |
+
|
936 |
+
Args:
|
937 |
+
x: the input of shape (*, num_channels)
|
938 |
+
|
939 |
+
Returns:
|
940 |
+
x, unmodified. You should make sure
|
941 |
+
you use the returned value, or the graph will be freed
|
942 |
+
and nothing will happen in backprop.
|
943 |
+
"""
|
944 |
+
if (
|
945 |
+
not x.requires_grad
|
946 |
+
or random.random() > self.prob
|
947 |
+
or self.grad_scale == 0
|
948 |
+
):
|
949 |
+
return _no_op(x)
|
950 |
+
else:
|
951 |
+
if hasattr(self, "min_prob") and random.random() < 0.25:
|
952 |
+
# occasionally switch between min_prob and max_prob, based on whether
|
953 |
+
# we are above or below the threshold.
|
954 |
+
if (
|
955 |
+
_whitening_metric(x.to(torch.float32), self.num_groups)
|
956 |
+
> self.whitening_limit
|
957 |
+
):
|
958 |
+
# there would be a change to the grad.
|
959 |
+
self.prob = self.max_prob
|
960 |
+
else:
|
961 |
+
self.prob = self.min_prob
|
962 |
+
|
963 |
+
return WhiteningPenaltyFunction.apply(
|
964 |
+
x, self.num_groups, self.whitening_limit, self.grad_scale
|
965 |
+
)
|
966 |
+
|
967 |
+
|
968 |
+
class WithLoss(torch.autograd.Function):
|
969 |
+
@staticmethod
|
970 |
+
def forward(ctx, x: Tensor, y: Tensor):
|
971 |
+
ctx.y_shape = y.shape
|
972 |
+
return x
|
973 |
+
|
974 |
+
@staticmethod
|
975 |
+
def backward(ctx, ans_grad: Tensor):
|
976 |
+
return ans_grad, torch.ones(
|
977 |
+
ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
|
978 |
+
)
|
979 |
+
|
980 |
+
|
981 |
+
def with_loss(x, y):
|
982 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
983 |
+
return x
|
984 |
+
# returns x but adds y.sum() to the loss function.
|
985 |
+
return WithLoss.apply(x, y)
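# (Annotation, not part of the uploaded file.) WithLoss.backward returns a tensor of ones
# as the gradient for `y`, so `y` receives exactly the gradient it would get if y.sum()
# had been added to the final loss, while `x` passes through unchanged. This is the
# mechanism penalize_abs_values_gt() uses to attach its penalty without returning a
# separate loss term.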
|
986 |
+
|
987 |
+
|
988 |
+
def _no_op(x: Tensor) -> Tensor:
|
989 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
990 |
+
return x
|
991 |
+
else:
|
992 |
+
# a no-op function that will have a node in the autograd graph,
|
993 |
+
# to avoid certain bugs relating to backward hooks
|
994 |
+
return x.chunk(1, dim=-1)[0]
|
995 |
+
|
996 |
+
|
997 |
+
class Identity(torch.nn.Module):
|
998 |
+
def __init__(self):
|
999 |
+
super(Identity, self).__init__()
|
1000 |
+
|
1001 |
+
def forward(self, x):
|
1002 |
+
return _no_op(x)
|
1003 |
+
|
1004 |
+
|
1005 |
+
class MaxEig(torch.nn.Module):
|
1006 |
+
"""
|
1007 |
+
Modifies the backpropped derivatives of a function to try to discourage
|
1008 |
+
that any given direction in activation space accounts for more than
|
1009 |
+
a specified proportion of the covariance (e.g. 0.2).
|
1010 |
+
|
1011 |
+
|
1012 |
+
Args:
|
1013 |
+
num_channels: the number of channels
|
1014 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
1015 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
1016 |
+
max_var_per_eig: the maximum proportion of the variance of the
|
1017 |
+
features/channels, after mean subtraction, that can come from
|
1018 |
+
any given eigenvalue.
|
1019 |
+
min_prob: the minimum probability with which we apply this during any invocation
|
1020 |
+
of forward(), assuming last time we applied the constraint it was
|
1021 |
+
not active; supplied for speed.
|
1022 |
+
scale: determines the scale with which we modify the gradients, relative
|
1023 |
+
to the existing / unmodified gradients
|
1024 |
+
"""
|
1025 |
+
|
1026 |
+
def __init__(
|
1027 |
+
self,
|
1028 |
+
num_channels: int,
|
1029 |
+
channel_dim: int,
|
1030 |
+
max_var_per_eig: float = 0.2,
|
1031 |
+
min_prob: float = 0.01,
|
1032 |
+
scale: float = 0.01,
|
1033 |
+
):
|
1034 |
+
super(MaxEig, self).__init__()
|
1035 |
+
self.num_channels = num_channels
|
1036 |
+
self.channel_dim = channel_dim
|
1037 |
+
self.scale = scale
|
1038 |
+
assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
|
1039 |
+
self.max_var_per_eig = max_var_per_eig
|
1040 |
+
|
1041 |
+
# we figure out the dominant direction using the power method: starting with
|
1042 |
+
# a random vector, keep multiplying by the covariance and renormalizing.
|
1043 |
+
with torch.no_grad():
|
1044 |
+
# arbitrary.. would use randn() but want to leave the rest of the model's
|
1045 |
+
# random parameters unchanged for comparison
|
1046 |
+
direction = torch.arange(num_channels).to(torch.float)
|
1047 |
+
direction = direction / direction.norm()
|
1048 |
+
self.register_buffer("max_eig_direction", direction)
|
1049 |
+
|
1050 |
+
self.min_prob = min_prob
|
1051 |
+
# cur_prob is the current probability we'll use to apply the ActivationBalancer.
|
1052 |
+
# We'll regress this towards prob, each time we try to apply it and it is not
|
1053 |
+
# active.
|
1054 |
+
self.cur_prob = 1.0
|
1055 |
+
|
1056 |
+
def forward(self, x: Tensor) -> Tensor:
|
1057 |
+
if (
|
1058 |
+
torch.jit.is_scripting()
|
1059 |
+
or self.max_var_per_eig <= 0
|
1060 |
+
or random.random() > self.cur_prob
|
1061 |
+
or torch.jit.is_tracing()
|
1062 |
+
):
|
1063 |
+
return _no_op(x)
|
1064 |
+
|
1065 |
+
with torch.cuda.amp.autocast(enabled=False):
|
1066 |
+
eps = 1.0e-20
|
1067 |
+
orig_x = x
|
1068 |
+
x = x.to(torch.float32)
|
1069 |
+
with torch.no_grad():
|
1070 |
+
x = x.transpose(self.channel_dim, -1).reshape(
|
1071 |
+
-1, self.num_channels
|
1072 |
+
)
|
1073 |
+
x = x - x.mean(dim=0)
|
1074 |
+
new_direction, coeffs = self._find_direction_coeffs(
|
1075 |
+
x, self.max_eig_direction
|
1076 |
+
)
|
1077 |
+
x_var = (x ** 2).mean()
|
1078 |
+
x_residual = x - coeffs * new_direction
|
1079 |
+
x_residual_var = (x_residual ** 2).mean()
|
1080 |
+
|
1081 |
+
# `variance_proportion` is the proportion of the variance accounted for
|
1082 |
+
# by the top eigen-direction.
|
1083 |
+
variance_proportion = (x_var - x_residual_var) / (
|
1084 |
+
x_var + 1.0e-20
|
1085 |
+
)
|
1086 |
+
|
1087 |
+
# ensure new direction is nonzero even if x == 0, by including `direction`.
|
1088 |
+
self._set_direction(
|
1089 |
+
0.1 * self.max_eig_direction + new_direction
|
1090 |
+
)
|
1091 |
+
|
1092 |
+
if random.random() < 0.01 or __name__ == "__main__":
|
1093 |
+
logging.info(
|
1094 |
+
f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
if variance_proportion >= self.max_var_per_eig:
|
1098 |
+
# The constraint is active. Note, we should quite rarely
|
1099 |
+
# reach here, only near the beginning of training if we are
|
1100 |
+
# starting to diverge, should this constraint be active.
|
1101 |
+
cur_prob = self.cur_prob
|
1102 |
+
self.cur_prob = (
|
1103 |
+
1.0 # next time, do the update with probability 1.0.
|
1104 |
+
)
|
1105 |
+
return MaxEigLimiterFunction.apply(
|
1106 |
+
orig_x, coeffs, new_direction, self.channel_dim, self.scale
|
1107 |
+
)
|
1108 |
+
else:
|
1109 |
+
# let self.cur_prob exponentially approach self.min_prob, as
|
1110 |
+
# long as the constraint is inactive.
|
1111 |
+
self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
|
1112 |
+
return orig_x
|
1113 |
+
|
1114 |
+
def _set_direction(self, direction: Tensor):
|
1115 |
+
"""
|
1116 |
+
Sets self.max_eig_direction to a normalized version of `direction`
|
1117 |
+
"""
|
1118 |
+
direction = direction.detach()
|
1119 |
+
direction = direction / direction.norm()
|
1120 |
+
direction_sum = direction.sum().item()
|
1121 |
+
if direction_sum - direction_sum == 0: # no inf/nan
|
1122 |
+
self.max_eig_direction[:] = direction
|
1123 |
+
else:
|
1124 |
+
logging.info(
|
1125 |
+
f"Warning: sum of direction in MaxEig is {direction_sum}, "
|
1126 |
+
"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
|
1127 |
+
)
|
1128 |
+
|
1129 |
+
def _find_direction_coeffs(
|
1130 |
+
self, x: Tensor, prev_direction: Tensor
|
1131 |
+
) -> Tuple[Tensor, Tensor]:
|
1132 |
+
"""
|
1133 |
+
Figure out (an approximation to) the proportion of the variance of a set of
|
1134 |
+
feature vectors that can be attributed to the top eigen-direction.
|
1135 |
+
Args:
|
1136 |
+
x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
|
1137 |
+
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
|
1138 |
+
of the top eigen-direction, or a random direction if this is the first
|
1139 |
+
iteration. Does not have to be normalized, but should be nonzero.
|
1140 |
+
|
1141 |
+
Returns: (cur_direction, coeffs), where:
|
1142 |
+
cur_direction: a Tensor of shape (num_channels,) that is the current
|
1143 |
+
estimate of the top eigen-direction.
|
1144 |
+
coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
|
1145 |
+
approximately minimizes, (x - coeffs * cur_direction).norm()
|
1146 |
+
"""
|
1147 |
+
(num_frames, num_channels) = x.shape
|
1148 |
+
assert num_channels > 1 and num_frames > 1
|
1149 |
+
assert prev_direction.shape == (num_channels,)
|
1150 |
+
# `coeffs` are the coefficients of `prev_direction` in x.
|
1151 |
+
# (strictly, these represent the coeffs only up to a constant positive factor).
|
1152 |
+
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
|
1153 |
+
cur_direction = (x * coeffs).sum(dim=0) / (
|
1154 |
+
(coeffs ** 2).sum() + 1.0e-20
|
1155 |
+
)
|
1156 |
+
return cur_direction, coeffs
|
1157 |
+
|
1158 |
+
|
1159 |
+
class DoubleSwishFunction(torch.autograd.Function):
|
1160 |
+
"""
|
1161 |
+
double_swish(x) = x * torch.sigmoid(x-1)
|
1162 |
+
This is a definition, originally motivated by its close numerical
|
1163 |
+
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
|
1164 |
+
|
1165 |
+
Memory-efficient derivative computation:
|
1166 |
+
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
1167 |
+
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
1168 |
+
Now, s'(x) = s(x) * (1-s(x)).
|
1169 |
+
double_swish'(x) = x * s'(x) + s(x).
|
1170 |
+
= x * s(x) * (1-s(x)) + s(x).
|
1171 |
+
= double_swish(x) * (1-s(x)) + s(x)
|
1172 |
+
... so we just need to remember s(x) but not x itself.
|
1173 |
+
"""
|
1174 |
+
|
1175 |
+
@staticmethod
|
1176 |
+
def forward(ctx, x: Tensor) -> Tensor:
|
1177 |
+
requires_grad = x.requires_grad
|
1178 |
+
x_dtype = x.dtype
|
1179 |
+
if x.dtype == torch.float16:
|
1180 |
+
x = x.to(torch.float32)
|
1181 |
+
|
1182 |
+
s = torch.sigmoid(x - 1.0)
|
1183 |
+
y = x * s
|
1184 |
+
|
1185 |
+
if requires_grad:
|
1186 |
+
deriv = y * (1 - s) + s
|
1187 |
+
# notes on derivative of x * sigmoid(x - 1):
|
1188 |
+
# https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
|
1189 |
+
# min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound.
|
1190 |
+
# max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
|
1191 |
+
# the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
|
1192 |
+
# floors), should be expectation-preserving.
|
1193 |
+
floor = -0.043637
|
1194 |
+
ceil = 1.2
|
1195 |
+
d_scaled = (deriv - floor) * (
|
1196 |
+
255.0 / (ceil - floor)
|
1197 |
+
) + torch.rand_like(deriv)
|
1198 |
+
if __name__ == "__main__":
|
1199 |
+
# for self-testing only.
|
1200 |
+
assert d_scaled.min() >= 0.0
|
1201 |
+
assert d_scaled.max() < 256.0
|
1202 |
+
d_int = d_scaled.to(torch.uint8)
|
1203 |
+
ctx.save_for_backward(d_int)
|
1204 |
+
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
1205 |
+
y = y.to(torch.float16)
|
1206 |
+
return y
|
1207 |
+
|
1208 |
+
@staticmethod
|
1209 |
+
def backward(ctx, y_grad: Tensor) -> Tensor:
|
1210 |
+
(d,) = ctx.saved_tensors
|
1211 |
+
# the same constants as used in forward pass.
|
1212 |
+
floor = -0.043637
|
1213 |
+
ceil = 1.2
|
1214 |
+
d = d * ((ceil - floor) / 255.0) + floor
|
1215 |
+
return y_grad * d
|
1216 |
+
|
1217 |
+
|
1218 |
+
class DoubleSwish(torch.nn.Module):
|
1219 |
+
def forward(self, x: Tensor) -> Tensor:
|
1220 |
+
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
1221 |
+
that we approximate closely with x * sigmoid(x-1).
|
1222 |
+
"""
|
1223 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
1224 |
+
return x * torch.sigmoid(x - 1.0)
|
1225 |
+
return DoubleSwishFunction.apply(x)
|
1226 |
+
|
1227 |
+
|
1228 |
+
def BalancedDoubleSwish(
|
1229 |
+
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
|
1230 |
+
) -> nn.Sequential:
|
1231 |
+
"""
|
1232 |
+
ActivationBalancer -> DoubleSwish
|
1233 |
+
"""
|
1234 |
+
balancer = ActivationBalancer(
|
1235 |
+
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
|
1236 |
+
)
|
1237 |
+
return nn.Sequential(
|
1238 |
+
balancer,
|
1239 |
+
DoubleSwish(),
|
1240 |
+
)
|
1241 |
+
|
1242 |
+
|
1243 |
+
def _test_max_eig():
|
1244 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1245 |
+
logging.info(f"proportion = {proportion}")
|
1246 |
+
x = torch.randn(100, 128)
|
1247 |
+
direction = torch.randn(128)
|
1248 |
+
coeffs = torch.randn(100, 1)
|
1249 |
+
x += proportion * direction * coeffs
|
1250 |
+
|
1251 |
+
x.requires_grad = True
|
1252 |
+
|
1253 |
+
num_channels = 128
|
1254 |
+
m = MaxEig(
|
1255 |
+
num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
|
1256 |
+
) # grad_scale
|
1257 |
+
|
1258 |
+
for _ in range(4):
|
1259 |
+
y = m(x)
|
1260 |
+
|
1261 |
+
y_grad = torch.randn_like(x)
|
1262 |
+
y.backward(gradient=y_grad)
|
1263 |
+
|
1264 |
+
if proportion < 0.2:
|
1265 |
+
assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
|
1266 |
+
elif proportion > 1.0:
|
1267 |
+
assert not torch.allclose(x.grad, y_grad)
|
1268 |
+
|
1269 |
+
|
1270 |
+
def _test_whiten():
|
1271 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1272 |
+
logging.info(f"_test_whiten(): proportion = {proportion}")
|
1273 |
+
x = torch.randn(100, 128)
|
1274 |
+
direction = torch.randn(128)
|
1275 |
+
coeffs = torch.randn(100, 1)
|
1276 |
+
x += proportion * direction * coeffs
|
1277 |
+
|
1278 |
+
x.requires_grad = True
|
1279 |
+
|
1280 |
+
num_channels = 128
|
1281 |
+
m = Whiten(
|
1282 |
+
1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
|
1283 |
+
) # grad_scale
|
1284 |
+
|
1285 |
+
for _ in range(4):
|
1286 |
+
y = m(x)
|
1287 |
+
|
1288 |
+
y_grad = torch.randn_like(x)
|
1289 |
+
y.backward(gradient=y_grad)
|
1290 |
+
|
1291 |
+
if proportion < 0.2:
|
1292 |
+
assert torch.allclose(x.grad, y_grad)
|
1293 |
+
elif proportion > 1.0:
|
1294 |
+
assert not torch.allclose(x.grad, y_grad)
|
1295 |
+
|
1296 |
+
|
1297 |
+
def _test_activation_balancer_sign():
|
1298 |
+
probs = torch.arange(0, 1, 0.01)
|
1299 |
+
N = 1000
|
1300 |
+
x = 1.0 * (
|
1301 |
+
(2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
|
1302 |
+
)
|
1303 |
+
x = x.detach()
|
1304 |
+
x.requires_grad = True
|
1305 |
+
m = ActivationBalancer(
|
1306 |
+
probs.numel(),
|
1307 |
+
channel_dim=0,
|
1308 |
+
min_positive=0.05,
|
1309 |
+
max_positive=0.95,
|
1310 |
+
max_factor=0.2,
|
1311 |
+
min_abs=0.0,
|
1312 |
+
)
|
1313 |
+
|
1314 |
+
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
1315 |
+
|
1316 |
+
y = m(x)
|
1317 |
+
y.backward(gradient=y_grad)
|
1318 |
+
print("_test_activation_balancer_sign: x = ", x)
|
1319 |
+
print("_test_activation_balancer_sign: y grad = ", y_grad)
|
1320 |
+
print("_test_activation_balancer_sign: x grad = ", x.grad)
|
1321 |
+
|
1322 |
+
|
1323 |
+
def _test_activation_balancer_magnitude():
|
1324 |
+
magnitudes = torch.arange(0, 1, 0.01)
|
1325 |
+
N = 1000
|
1326 |
+
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
|
1327 |
+
-1
|
1328 |
+
)
|
1329 |
+
x = x.detach()
|
1330 |
+
x.requires_grad = True
|
1331 |
+
m = ActivationBalancer(
|
1332 |
+
magnitudes.numel(),
|
1333 |
+
channel_dim=0,
|
1334 |
+
min_positive=0.0,
|
1335 |
+
max_positive=1.0,
|
1336 |
+
max_factor=0.2,
|
1337 |
+
min_abs=0.2,
|
1338 |
+
max_abs=0.8,
|
1339 |
+
min_prob=1.0,
|
1340 |
+
)
|
1341 |
+
|
1342 |
+
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
1343 |
+
|
1344 |
+
y = m(x)
|
1345 |
+
y.backward(gradient=y_grad)
|
1346 |
+
print("_test_activation_balancer_magnitude: x = ", x)
|
1347 |
+
print("_test_activation_balancer_magnitude: y grad = ", y_grad)
|
1348 |
+
print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
1349 |
+
|
1350 |
+
|
1351 |
+
def _test_basic_norm():
|
1352 |
+
num_channels = 128
|
1353 |
+
m = BasicNorm(num_channels=num_channels, channel_dim=1)
|
1354 |
+
|
1355 |
+
x = torch.randn(500, num_channels)
|
1356 |
+
|
1357 |
+
y = m(x)
|
1358 |
+
|
1359 |
+
assert y.shape == x.shape
|
1360 |
+
x_rms = (x ** 2).mean().sqrt()
|
1361 |
+
y_rms = (y ** 2).mean().sqrt()
|
1362 |
+
print("x rms = ", x_rms)
|
1363 |
+
print("y rms = ", y_rms)
|
1364 |
+
assert y_rms < x_rms
|
1365 |
+
assert y_rms > 0.5 * x_rms
|
1366 |
+
|
1367 |
+
|
1368 |
+
def _test_double_swish_deriv():
|
1369 |
+
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
1370 |
+
x.requires_grad = True
|
1371 |
+
m = DoubleSwish()
|
1372 |
+
|
1373 |
+
tol = (1.2 - (-0.043637)) / 255.0
|
1374 |
+
torch.autograd.gradcheck(m, x, atol=tol)
|
1375 |
+
|
1376 |
+
# for self-test.
|
1377 |
+
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
|
1378 |
+
x.requires_grad = True
|
1379 |
+
y = m(x)
|
1380 |
+
|
1381 |
+
|
1382 |
+
def _test_softmax():
|
1383 |
+
a = torch.randn(2, 10, dtype=torch.float64)
|
1384 |
+
b = a.clone()
|
1385 |
+
a.requires_grad = True
|
1386 |
+
b.requires_grad = True
|
1387 |
+
a.softmax(dim=1)[:, 0].sum().backward()
|
1388 |
+
print("a grad = ", a.grad)
|
1389 |
+
softmax(b, dim=1)[:, 0].sum().backward()
|
1390 |
+
print("b grad = ", b.grad)
|
1391 |
+
assert torch.allclose(a.grad, b.grad)
|
1392 |
+
|
1393 |
+
|
1394 |
+
if __name__ == "__main__":
|
1395 |
+
logging.getLogger().setLevel(logging.INFO)
|
1396 |
+
torch.set_num_threads(1)
|
1397 |
+
torch.set_num_interop_threads(1)
|
1398 |
+
_test_softmax()
|
1399 |
+
_test_whiten()
|
1400 |
+
_test_max_eig()
|
1401 |
+
_test_activation_balancer_sign()
|
1402 |
+
_test_activation_balancer_magnitude()
|
1403 |
+
_test_basic_norm()
|
1404 |
+
_test_double_swish_deriv()
|
slam_llm/models/vallex/transformers.py
ADDED
@@ -0,0 +1,613 @@
1 |
+
import copy
|
2 |
+
import numbers
|
3 |
+
from functools import partial
|
4 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import Tensor, nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
from .activation import MultiheadAttention
|
11 |
+
from .scaling import ActivationBalancer, BasicNorm as _BasicNorm
|
12 |
+
|
13 |
+
_shape_t = Union[int, List[int], torch.Size]
|
14 |
+
|
15 |
+
|
16 |
+
class LayerNorm(nn.Module):
|
17 |
+
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
|
18 |
+
normalized_shape: Tuple[int, ...]
|
19 |
+
eps: float
|
20 |
+
elementwise_affine: bool
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
normalized_shape: _shape_t,
|
25 |
+
eps: float = 1e-5,
|
26 |
+
elementwise_affine: bool = True,
|
27 |
+
device=None,
|
28 |
+
dtype=None,
|
29 |
+
) -> None:
|
30 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
31 |
+
super(LayerNorm, self).__init__()
|
32 |
+
if isinstance(normalized_shape, numbers.Integral):
|
33 |
+
# mypy error: incompatible types in assignment
|
34 |
+
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
35 |
+
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
36 |
+
self.eps = eps
|
37 |
+
self.elementwise_affine = elementwise_affine
|
38 |
+
if self.elementwise_affine:
|
39 |
+
self.weight = nn.Parameter(
|
40 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
41 |
+
)
|
42 |
+
self.bias = nn.Parameter(
|
43 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
44 |
+
)
|
45 |
+
else:
|
46 |
+
self.register_parameter("weight", None)
|
47 |
+
self.register_parameter("bias", None)
|
48 |
+
|
49 |
+
self.reset_parameters()
|
50 |
+
|
51 |
+
def reset_parameters(self) -> None:
|
52 |
+
if self.elementwise_affine:
|
53 |
+
nn.init.ones_(self.weight)
|
54 |
+
nn.init.zeros_(self.bias)
|
55 |
+
|
56 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
57 |
+
if isinstance(input, tuple):
|
58 |
+
input, embedding = input
|
59 |
+
return (
|
60 |
+
F.layer_norm(
|
61 |
+
input,
|
62 |
+
self.normalized_shape,
|
63 |
+
self.weight,
|
64 |
+
self.bias,
|
65 |
+
self.eps,
|
66 |
+
),
|
67 |
+
embedding,
|
68 |
+
)
|
69 |
+
|
70 |
+
assert embedding is None
|
71 |
+
return F.layer_norm(
|
72 |
+
input, self.normalized_shape, self.weight, self.bias, self.eps
|
73 |
+
)
|
74 |
+
|
75 |
+
def extra_repr(self) -> str:
|
76 |
+
return (
|
77 |
+
"{normalized_shape}, eps={eps}, "
|
78 |
+
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
class AdaptiveLayerNorm(nn.Module):
|
83 |
+
r"""Adaptive Layer Normalization"""
|
84 |
+
|
85 |
+
def __init__(self, d_model, norm) -> None:
|
86 |
+
super(AdaptiveLayerNorm, self).__init__()
|
87 |
+
self.project_layer = nn.Linear(d_model, 2 * d_model)
|
88 |
+
self.norm = norm
|
89 |
+
self.d_model = d_model
|
90 |
+
self.eps = self.norm.eps
|
91 |
+
|
92 |
+
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
|
93 |
+
if isinstance(input, tuple):
|
94 |
+
input, embedding = input
|
95 |
+
weight, bias = torch.split(
|
96 |
+
self.project_layer(embedding),
|
97 |
+
split_size_or_sections=self.d_model,
|
98 |
+
dim=-1,
|
99 |
+
)
|
100 |
+
return (weight * self.norm(input) + bias, embedding)
|
101 |
+
|
102 |
+
weight, bias = torch.split(
|
103 |
+
self.project_layer(embedding),
|
104 |
+
split_size_or_sections=self.d_model,
|
105 |
+
dim=-1,
|
106 |
+
)
|
107 |
+
return weight * self.norm(input) + bias
|
108 |
+
|
109 |
+
|
110 |
+
class BasicNorm(_BasicNorm):
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
d_model: int,
|
114 |
+
eps: float = 1e-5,
|
115 |
+
device=None,
|
116 |
+
dtype=None,
|
117 |
+
):
|
118 |
+
super(BasicNorm, self).__init__(d_model, eps=eps)
|
119 |
+
|
120 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
121 |
+
if isinstance(input, tuple):
|
122 |
+
input, embedding = input
|
123 |
+
return (
|
124 |
+
super(BasicNorm, self).forward(input),
|
125 |
+
embedding,
|
126 |
+
)
|
127 |
+
|
128 |
+
assert embedding is None
|
129 |
+
return super(BasicNorm, self).forward(input)
|
130 |
+
|
131 |
+
|
132 |
+
class BalancedBasicNorm(nn.Module):
|
133 |
+
def __init__(
|
134 |
+
self,
|
135 |
+
d_model: int,
|
136 |
+
eps: float = 1e-5,
|
137 |
+
device=None,
|
138 |
+
dtype=None,
|
139 |
+
):
|
140 |
+
super(BalancedBasicNorm, self).__init__()
|
141 |
+
self.balancer = ActivationBalancer(
|
142 |
+
d_model,
|
143 |
+
channel_dim=-1,
|
144 |
+
min_positive=0.45,
|
145 |
+
max_positive=0.55,
|
146 |
+
max_abs=6.0,
|
147 |
+
)
|
148 |
+
self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
|
149 |
+
|
150 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
151 |
+
if isinstance(input, tuple):
|
152 |
+
input, embedding = input
|
153 |
+
return self.norm((self.balancer(input), embedding))
|
154 |
+
|
155 |
+
assert embedding is None
|
156 |
+
return self.norm(self.balancer(input))
|
157 |
+
|
158 |
+
|
159 |
+
class IdentityNorm(nn.Module):
|
160 |
+
def __init__(
|
161 |
+
self,
|
162 |
+
d_model: int,
|
163 |
+
eps: float = 1e-5,
|
164 |
+
device=None,
|
165 |
+
dtype=None,
|
166 |
+
) -> None:
|
167 |
+
super(IdentityNorm, self).__init__()
|
168 |
+
|
169 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
170 |
+
if isinstance(input, tuple):
|
171 |
+
return input
|
172 |
+
|
173 |
+
assert embedding is None
|
174 |
+
return input
|
175 |
+
|
176 |
+
|
177 |
+
class TransformerEncoderLayer(nn.Module):
|
178 |
+
__constants__ = ["batch_first", "norm_first"]
|
179 |
+
|
180 |
+
def __init__(
|
181 |
+
self,
|
182 |
+
d_model: int,
|
183 |
+
nhead: int,
|
184 |
+
dim_feedforward: int = 2048,
|
185 |
+
dropout: float = 0.1,
|
186 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
187 |
+
batch_first: bool = False,
|
188 |
+
norm_first: bool = False,
|
189 |
+
device=None,
|
190 |
+
dtype=None,
|
191 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
192 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
193 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
194 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
195 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
196 |
+
layer_norm_eps: float = 1e-5,
|
197 |
+
adaptive_layer_norm=False,
|
198 |
+
) -> None:
|
199 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
200 |
+
super(TransformerEncoderLayer, self).__init__()
|
201 |
+
self.self_attn = MultiheadAttention(
|
202 |
+
d_model,
|
203 |
+
nhead,
|
204 |
+
dropout=dropout,
|
205 |
+
batch_first=batch_first,
|
206 |
+
linear1_cls=linear1_self_attention_cls,
|
207 |
+
linear2_cls=linear2_self_attention_cls,
|
208 |
+
**factory_kwargs,
|
209 |
+
)
|
210 |
+
|
211 |
+
# Implementation of Feedforward model
|
212 |
+
|
213 |
+
self.dropout = nn.Dropout(dropout)
|
214 |
+
|
215 |
+
self.norm_first = norm_first
|
216 |
+
self.dropout1 = nn.Dropout(dropout)
|
217 |
+
self.dropout2 = nn.Dropout(dropout)
|
218 |
+
|
219 |
+
# Legacy string support for activation function.
|
220 |
+
if isinstance(activation, str):
|
221 |
+
activation = _get_activation_fn(activation)
|
222 |
+
elif isinstance(activation, partial):
|
223 |
+
activation = activation(d_model)
|
224 |
+
|
225 |
+
self.activation = activation
|
226 |
+
|
227 |
+
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
228 |
+
if layer_norm_cls == IdentityNorm:
|
229 |
+
norm2 = BalancedBasicNorm(
|
230 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
231 |
+
)
|
232 |
+
else:
|
233 |
+
norm2 = layer_norm_cls(
|
234 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
235 |
+
)
|
236 |
+
|
237 |
+
self.norm1 = norm1
|
238 |
+
self.linear1 = linear1_feedforward_cls(
|
239 |
+
d_model, dim_feedforward, **factory_kwargs
|
240 |
+
)
|
241 |
+
self.linear2 = linear2_feedforward_cls(
|
242 |
+
dim_feedforward, d_model, **factory_kwargs
|
243 |
+
)
|
244 |
+
self.norm2 = norm2
|
245 |
+
|
246 |
+
|
247 |
+
def __setstate__(self, state):
|
248 |
+
super(TransformerEncoderLayer, self).__setstate__(state)
|
249 |
+
if not hasattr(self, "activation"):
|
250 |
+
self.activation = F.relu
|
251 |
+
|
252 |
+
def forward(
|
253 |
+
self,
|
254 |
+
src: Tensor,
|
255 |
+
src_mask: Optional[Tensor] = None,
|
256 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
257 |
+
) -> Tensor:
|
258 |
+
r"""Pass the input through the encoder layer.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
src: the sequence to the encoder layer (required).
|
262 |
+
src_mask: the mask for the src sequence (optional).
|
263 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
264 |
+
|
265 |
+
Shape:
|
266 |
+
see the docs in Transformer class.
|
267 |
+
"""
|
268 |
+
x, stage_embedding = src, None
|
269 |
+
is_src_tuple = False
|
270 |
+
if isinstance(src, tuple):
|
271 |
+
x, stage_embedding = src
|
272 |
+
is_src_tuple = True
|
273 |
+
|
274 |
+
if src_key_padding_mask is not None:
|
275 |
+
_skpm_dtype = src_key_padding_mask.dtype
|
276 |
+
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
277 |
+
src_key_padding_mask
|
278 |
+
):
|
279 |
+
raise AssertionError(
|
280 |
+
"only bool and floating types of key_padding_mask are supported"
|
281 |
+
)
|
282 |
+
|
283 |
+
if self.norm_first:
|
284 |
+
x = x + self._sa_block(
|
285 |
+
self.norm1(x, stage_embedding),
|
286 |
+
src_mask,
|
287 |
+
src_key_padding_mask,
|
288 |
+
)
|
289 |
+
|
290 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
291 |
+
else:
|
292 |
+
x = self.norm1(
|
293 |
+
x + self._sa_block(x, src_mask, src_key_padding_mask),
|
294 |
+
stage_embedding,
|
295 |
+
)
|
296 |
+
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
297 |
+
|
298 |
+
if is_src_tuple:
|
299 |
+
return (x, stage_embedding)
|
300 |
+
return x
|
301 |
+
|
302 |
+
def infer(
|
303 |
+
self,
|
304 |
+
src: Tensor,
|
305 |
+
src_mask: Optional[Tensor] = None,
|
306 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
307 |
+
past_kv: Optional[Tensor] = None,
|
308 |
+
use_cache: bool = False,
|
309 |
+
):
|
310 |
+
x, stage_embedding = src, None
|
311 |
+
is_src_tuple = False
|
312 |
+
if isinstance(src, tuple):
|
313 |
+
x, stage_embedding = src
|
314 |
+
is_src_tuple = True
|
315 |
+
|
316 |
+
if src_key_padding_mask is not None:
|
317 |
+
_skpm_dtype = src_key_padding_mask.dtype
|
318 |
+
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
319 |
+
src_key_padding_mask
|
320 |
+
):
|
321 |
+
raise AssertionError(
|
322 |
+
"only bool and floating types of key_padding_mask are supported"
|
323 |
+
)
|
324 |
+
|
325 |
+
if self.norm_first:
|
326 |
+
x_attn_out, kv = self.self_attn.infer(
|
327 |
+
self.norm1(x, stage_embedding),
|
328 |
+
attn_mask=src_mask,
|
329 |
+
key_padding_mask=src_key_padding_mask,
|
330 |
+
need_weights=False,
|
331 |
+
past_kv=past_kv,
|
332 |
+
use_cache=use_cache,
|
333 |
+
)
|
334 |
+
x = x + x_attn_out
|
335 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
336 |
+
|
337 |
+
if is_src_tuple:
|
338 |
+
return (x, stage_embedding)
|
339 |
+
return (x, kv)
|
340 |
+
|
341 |
+
# self-attention block
|
342 |
+
def _sa_block(
|
343 |
+
self,
|
344 |
+
x: Tensor,
|
345 |
+
attn_mask: Optional[Tensor],
|
346 |
+
key_padding_mask: Optional[Tensor],
|
347 |
+
) -> Tensor:
|
348 |
+
x = self.self_attn(
|
349 |
+
x,
|
350 |
+
x,
|
351 |
+
x,
|
352 |
+
attn_mask=attn_mask,
|
353 |
+
key_padding_mask=key_padding_mask,
|
354 |
+
need_weights=False,
|
355 |
+
)[0]
|
356 |
+
return self.dropout1(x)
|
357 |
+
|
358 |
+
# feed forward block
|
359 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
360 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
361 |
+
return self.dropout2(x)
|
362 |
+
|
363 |
+
|
364 |
+
class TransformerEncoder(nn.Module):
|
365 |
+
__constants__ = ["norm"]
|
366 |
+
|
367 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
368 |
+
super(TransformerEncoder, self).__init__()
|
369 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
370 |
+
self.num_layers = num_layers
|
371 |
+
self.norm = norm
|
372 |
+
|
373 |
+
def forward(
|
374 |
+
self,
|
375 |
+
src: Tensor,
|
376 |
+
mask: Optional[Tensor] = None,
|
377 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
378 |
+
return_layer_states: bool = False,
|
379 |
+
) -> Tensor:
|
380 |
+
output = src
|
381 |
+
for i, mod in enumerate(self.layers):
|
382 |
+
output = mod(
|
383 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
|
384 |
+
)
|
385 |
+
# print(i, output.mean())
|
386 |
+
if self.norm is not None:
|
387 |
+
output = self.norm(output)
|
388 |
+
|
389 |
+
return output
|
390 |
+
|
391 |
+
def infer(
|
392 |
+
self,
|
393 |
+
src: Tensor,
|
394 |
+
mask: Optional[Tensor] = None,
|
395 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
396 |
+
return_layer_states: bool = False,
|
397 |
+
past_kv: Optional[Tensor] = None,
|
398 |
+
use_cache: bool = False,
|
399 |
+
):
|
400 |
+
if past_kv is None:
|
401 |
+
past_length = 0
|
402 |
+
past_kv = tuple([None] * self.num_layers)
|
403 |
+
else:
|
404 |
+
past_length = past_kv[0][0].size(-2)
|
405 |
+
new_kv = () if use_cache else None
|
406 |
+
output = src
|
407 |
+
for i, (mod, past_layer_kv) in enumerate(zip(self.layers, past_kv)):
|
408 |
+
output, kv = mod.infer(
|
409 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache
|
410 |
+
)
|
411 |
+
# print(i, output.mean())
|
412 |
+
if use_cache:
|
413 |
+
new_kv = new_kv + (kv,)
|
414 |
+
|
415 |
+
if self.norm is not None:
|
416 |
+
output = self.norm(output)
|
417 |
+
|
418 |
+
return output, new_kv
|
419 |
+
|
420 |
+
|
421 |
+
class TransformerDecoderLayer(nn.Module):
|
422 |
+
__constants__ = ["batch_first", "norm_first"]
|
423 |
+
|
424 |
+
def __init__(
|
425 |
+
self,
|
426 |
+
d_model: int,
|
427 |
+
nhead: int,
|
428 |
+
dim_feedforward: int = 2048,
|
429 |
+
dropout: float = 0.1,
|
430 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
431 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
432 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
433 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
434 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
435 |
+
batch_first: bool = False,
|
436 |
+
norm_first: bool = False,
|
437 |
+
device=None,
|
438 |
+
dtype=None,
|
439 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
440 |
+
layer_norm_eps: float = 1e-5,
|
441 |
+
adaptive_layer_norm=False,
|
442 |
+
) -> None:
|
443 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
444 |
+
super(TransformerDecoderLayer, self).__init__()
|
445 |
+
self.self_attn = MultiheadAttention(
|
446 |
+
d_model,
|
447 |
+
nhead,
|
448 |
+
dropout=dropout,
|
449 |
+
batch_first=batch_first,
|
450 |
+
linear1_cls=linear1_self_attention_cls,
|
451 |
+
linear2_cls=linear2_self_attention_cls,
|
452 |
+
**factory_kwargs,
|
453 |
+
)
|
454 |
+
self.multihead_attn = MultiheadAttention(
|
455 |
+
d_model,
|
456 |
+
nhead,
|
457 |
+
dropout=dropout,
|
458 |
+
batch_first=batch_first,
|
459 |
+
linear1_cls=linear1_self_attention_cls,
|
460 |
+
linear2_cls=linear2_self_attention_cls,
|
461 |
+
**factory_kwargs,
|
462 |
+
)
|
463 |
+
# Implementation of Feedforward model
|
464 |
+
self.linear1 = linear1_feedforward_cls(
|
465 |
+
d_model, dim_feedforward, **factory_kwargs
|
466 |
+
)
|
467 |
+
self.dropout = nn.Dropout(dropout)
|
468 |
+
self.linear2 = linear2_feedforward_cls(
|
469 |
+
dim_feedforward, d_model, **factory_kwargs
|
470 |
+
)
|
471 |
+
|
472 |
+
self.norm_first = norm_first
|
473 |
+
self.dropout1 = nn.Dropout(dropout)
|
474 |
+
self.dropout2 = nn.Dropout(dropout)
|
475 |
+
self.dropout3 = nn.Dropout(dropout)
|
476 |
+
|
477 |
+
# Legacy string support for activation function.
|
478 |
+
if isinstance(activation, str):
|
479 |
+
self.activation = _get_activation_fn(activation)
|
480 |
+
elif isinstance(activation, partial):
|
481 |
+
self.activation = activation(d_model)
|
482 |
+
else:
|
483 |
+
self.activation = activation
|
484 |
+
|
485 |
+
if adaptive_layer_norm:
|
486 |
+
norm1 = layer_norm_cls(
|
487 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
488 |
+
)
|
489 |
+
norm2 = layer_norm_cls(
|
490 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
491 |
+
)
|
492 |
+
norm3 = layer_norm_cls(
|
493 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
494 |
+
)
|
495 |
+
|
496 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
497 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
498 |
+
self.norm3 = AdaptiveLayerNorm(d_model, norm3)
|
499 |
+
else:
|
500 |
+
self.norm1 = layer_norm_cls(
|
501 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
502 |
+
)
|
503 |
+
self.norm2 = layer_norm_cls(
|
504 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
505 |
+
)
|
506 |
+
if layer_norm_cls == IdentityNorm:
|
507 |
+
self.norm3 = BalancedBasicNorm(
|
508 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
509 |
+
)
|
510 |
+
else:
|
511 |
+
self.norm3 = layer_norm_cls(
|
512 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
513 |
+
)
|
514 |
+
|
515 |
+
def forward(
|
516 |
+
self,
|
517 |
+
tgt: Tensor,
|
518 |
+
memory: Tensor,
|
519 |
+
tgt_mask: Optional[Tensor] = None,
|
520 |
+
memory_mask: Optional[Tensor] = None,
|
521 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
522 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
523 |
+
) -> Tensor:
|
524 |
+
tgt_is_tuple = False
|
525 |
+
if isinstance(tgt, tuple):
|
526 |
+
x, stage_embedding = tgt
|
527 |
+
tgt_is_tuple = True
|
528 |
+
else:
|
529 |
+
x, stage_embedding = tgt, None
|
530 |
+
|
531 |
+
if self.norm_first:
|
532 |
+
x = x + self._sa_block(
|
533 |
+
self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
|
534 |
+
)
|
535 |
+
x = x + self._mha_block(
|
536 |
+
self.norm2(x, stage_embedding),
|
537 |
+
memory,
|
538 |
+
memory_mask,
|
539 |
+
memory_key_padding_mask,
|
540 |
+
)
|
541 |
+
x = x + self._ff_block(self.norm3(x, stage_embedding))
|
542 |
+
else:
|
543 |
+
x = self.norm1(
|
544 |
+
x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
|
545 |
+
stage_embedding,
|
546 |
+
)
|
547 |
+
x = self.norm2(
|
548 |
+
x
|
549 |
+
+ self._mha_block(
|
550 |
+
x, memory, memory_mask, memory_key_padding_mask
|
551 |
+
),
|
552 |
+
stage_embedding,
|
553 |
+
)
|
554 |
+
x = self.norm3(x + self._ff_block(x), stage_embedding)
|
555 |
+
|
556 |
+
if tgt_is_tuple:
|
557 |
+
return (x, stage_embedding)
|
558 |
+
return x
|
559 |
+
|
560 |
+
# self-attention block
|
561 |
+
def _sa_block(
|
562 |
+
self,
|
563 |
+
x: Tensor,
|
564 |
+
attn_mask: Optional[Tensor],
|
565 |
+
key_padding_mask: Optional[Tensor],
|
566 |
+
) -> Tensor:
|
567 |
+
x = self.self_attn(
|
568 |
+
x,
|
569 |
+
x,
|
570 |
+
x,
|
571 |
+
attn_mask=attn_mask,
|
572 |
+
key_padding_mask=key_padding_mask,
|
573 |
+
need_weights=False,
|
574 |
+
)[0]
|
575 |
+
return self.dropout1(x)
|
576 |
+
|
577 |
+
# multihead attention block
|
578 |
+
def _mha_block(
|
579 |
+
self,
|
580 |
+
x: Tensor,
|
581 |
+
mem: Tensor,
|
582 |
+
attn_mask: Optional[Tensor],
|
583 |
+
key_padding_mask: Optional[Tensor],
|
584 |
+
) -> Tensor:
|
585 |
+
x = self.multihead_attn(
|
586 |
+
x,
|
587 |
+
mem,
|
588 |
+
mem,
|
589 |
+
attn_mask=attn_mask,
|
590 |
+
key_padding_mask=key_padding_mask,
|
591 |
+
need_weights=False,
|
592 |
+
)[0]
|
593 |
+
return self.dropout2(x)
|
594 |
+
|
595 |
+
# feed forward block
|
596 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
597 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
598 |
+
return self.dropout3(x)
|
599 |
+
|
600 |
+
|
601 |
+
def _get_clones(module, N):
|
602 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
603 |
+
|
604 |
+
|
605 |
+
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
|
606 |
+
if activation == "relu":
|
607 |
+
return F.relu
|
608 |
+
elif activation == "gelu":
|
609 |
+
return F.gelu
|
610 |
+
|
611 |
+
raise RuntimeError(
|
612 |
+
"activation should be relu/gelu, not {}".format(activation)
|
613 |
+
)
|
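
The norm classes above all accept either a plain tensor or an (input, stage_embedding) tuple, which is how the NAR stack threads its stage conditioning through every layer. The snippet below is a minimal illustration of that tuple convention, not a file in this upload; the import path is assumed from the way vallex_model.py imports this module.

# Minimal sketch (illustration only): feeding a (tensor, stage_embedding) tuple
# through AdaptiveLayerNorm. Import path assumed to match vallex_model.py.
import torch
from src.slam_llm.models.vallex.transformers import AdaptiveLayerNorm, LayerNorm

d_model = 16
norm = AdaptiveLayerNorm(d_model, LayerNorm(d_model))

x = torch.randn(2, 10, d_model)         # (batch, time, dim)
stage_emb = torch.randn(2, 1, d_model)  # per-stage conditioning vector

out, emb = norm((x, stage_emb))         # tuple in -> (scaled/shifted norm, embedding) out
print(out.shape, emb.shape)             # torch.Size([2, 10, 16]) torch.Size([2, 1, 16])

The projection inside AdaptiveLayerNorm turns the stage embedding into a per-stage scale and bias, so a single shared decoder can serve all NAR quantizer stages.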
slam_llm/models/vallex/vallex_config.py
ADDED
@@ -0,0 +1,56 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from fairseq.data import Dictionary
from transformers import AutoConfig, AutoModel, AutoModelForImageClassification

logger = logging.get_logger(__name__)


class VallexConfig(PretrainedConfig):

    model_type = "vallex"

    def __init__(self,
                 n_layer=24,
                 n_head=16,
                 n_dim=1024,
                 prefix_mode=1,
                 num_quantizers=8,
                 sample_rate=24000,
                 ar_at_dict="",
                 ar_st_dict="",
                 nar_at_dict="",
                 nar_st_dict="",
                 nar_scale_factor=1.0,
                 prepend_bos=True,
                 norm_first=True,
                 eps=0.0,
                 only_ar=False,
                 only_nar=False,
                 **kwargs
                 ):
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_dim = n_dim
        self.prefix_mode = prefix_mode
        self.num_quantizers = num_quantizers
        self.sample_rate = sample_rate
        self.nar_scale_factor = nar_scale_factor
        self.prepend_bos = prepend_bos
        self.norm_first = norm_first

        self.ar_at_dict = ar_at_dict
        self.ar_st_dict = ar_st_dict
        self.nar_at_dict = nar_at_dict
        self.nar_st_dict = nar_st_dict
        self.eps = eps
        self.only_ar = only_ar
        self.only_nar = only_nar

        super().__init__(
            **kwargs
        )


AutoConfig.register("vallex", VallexConfig)
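
Because VallexConfig subclasses PretrainedConfig and is registered with AutoConfig, it can be saved and reloaded with the standard transformers API. The sketch below is illustrative only; the dictionary paths are placeholders, not files shipped here, and the import path is assumed.

# Illustrative sketch (not part of the upload): round-tripping the config.
from src.slam_llm.models.vallex.vallex_config import VallexConfig  # path assumed

cfg = VallexConfig(
    n_dim=1024,
    n_head=16,
    n_layer=24,
    ar_at_dict="path/to/ar_at_dict.txt",    # placeholder path
    ar_st_dict="path/to/ar_st_dict.txt",    # placeholder path
    nar_at_dict="path/to/nar_at_dict.txt",  # placeholder path
    nar_st_dict="path/to/nar_st_dict.txt",  # placeholder path
    only_ar=True,
)
cfg.save_pretrained("vallex_ckpt")                  # standard PretrainedConfig API
reloaded = VallexConfig.from_pretrained("vallex_ckpt")
assert reloaded.model_type == "vallex"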
slam_llm/models/vallex/vallex_model.py
ADDED
@@ -0,0 +1,772 @@
import random
from typing import Dict, Iterator, List, Tuple, Union
from fairseq import utils
import numpy as np
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import Dictionary
from src.slam_llm.models.vallex.transformers import (
    LayerNorm,
    TransformerEncoder,
    TransformerEncoderLayer,
)
from src.slam_llm.models.vallex.vallex_config import VallexConfig
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoConfig, AutoModel, AutoModelForImageClassification
from dataclasses import dataclass


@dataclass
class ModelOutput:
    logits: torch.Tensor
    loss: torch.Tensor
    acc: torch.Tensor


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True, scale=1, prob_mask=None):
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    if prob_mask is not None:
        lprobs = lprobs.masked_fill(prob_mask, 0.0)
        n_class = (1 - prob_mask.float()).sum()
    else:
        n_class = lprobs.size(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    # nll_loss = nll_loss * scale
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True) * scale
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
        pad_mask_float = (1 - pad_mask.to(torch.float)).sum()
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / (n_class - 1)
    loss = (1.0 - epsilon - eps_i) * nll_loss + \
        eps_i * smooth_loss
    return loss / pad_mask_float, nll_loss / pad_mask_float


class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx if padding_idx is not None else 0
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        )
        self.onnx_trace = False
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.max_positions = int(1e5)

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx=None
    ):
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, -1
        )
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(
        self,
        input,
        incremental_state=None,
        timestep=None,
        positions=None,
    ):
        bspair = torch.onnx.operators.shape_as_tensor(input)
        bsz, seq_len = bspair[0], bspair[1]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
                    .unsqueeze(1)
                    .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = utils.make_positions(
            input, self.padding_idx, onnx_trace=self.onnx_trace
        )
        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat(
                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
            )
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )


class Transpose(nn.Identity):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)


class VALLF(PreTrainedModel):
    config_class = VallexConfig

    def __init__(
        self,
        config: VallexConfig
    ):
        super().__init__(config)

        self.ar_at_dict = Dictionary.load(self.config.ar_at_dict)
        self.ar_st_dict = Dictionary.load(self.config.ar_st_dict)
        self.nar_at_dict = Dictionary.load(self.config.nar_at_dict)
        self.nar_st_dict = Dictionary.load(self.config.nar_st_dict)

        self.ar_at_dict.tts_flag = self.ar_at_dict.add_symbol("<TTS>")
        self.ar_st_dict.asr_flag = self.ar_st_dict.add_symbol("<ASR>")
        self.ar_st_dict.mt_flag = self.ar_st_dict.add_symbol("<MT>")

        self.padding_idx = self.ar_at_dict.pad()
        self.config = config
        d_model = self.config.n_dim
        nar_scale_factor = self.config.nar_scale_factor
        prepend_bos = self.config.prepend_bos

        norm_first = self.config.norm_first
        num_layers = self.config.n_layer
        self.NUM_AUDIO_TOKENS = self.ar_at_dict.eos()

        nar_d_model = int(d_model * nar_scale_factor)

        self.ar_text_embedding = nn.Embedding(len(self.ar_st_dict), d_model, self.ar_st_dict.pad())  # W_x
        if config.only_ar:
            pass
        else:
            self.nar_text_embedding = nn.Embedding(len(self.nar_st_dict), d_model, self.nar_st_dict.pad())

        # ID self.NUM_AUDIO_TOKENS -> PAD
        # ID self.NUM_AUDIO_TOKENS + 1 -> BOS
        self.ar_audio_prepend_bos = prepend_bos
        self.ar_audio_embedding = EncodecDecoderLstm(
            dictionary=self.ar_at_dict, emb_dim=d_model
        )

        self.ar_text_prenet = nn.Identity()
        self.ar_audio_prenet = nn.Identity()

        self.ar_text_position = SinusoidalPositionalEmbedding(
            d_model,
            padding_idx=self.ar_at_dict.pad(),
            init_size=1024 + self.ar_at_dict.pad() + 1
        )
        self.ar_audio_position = SinusoidalPositionalEmbedding(
            d_model,
            padding_idx=self.ar_at_dict.pad(),
            init_size=1024 + self.ar_at_dict.pad() + 1
        )

        self.ar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model,
                self.config.n_head,
                dim_feedforward=d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
            ),
            num_layers=num_layers,
            norm=LayerNorm(d_model) if norm_first else None,
        )
        self.ar_predict_layer = nn.Linear(
            d_model, len(self.ar_at_dict), bias=False
        )

        self.rng = random.Random(0)
        self.num_heads = self.config.n_head
        self.prefix_mode = self.config.prefix_mode
        self.num_quantizers = self.config.num_quantizers

        assert self.num_quantizers >= 1
        if config.only_ar:
            pass
        else:
            if self.num_quantizers > 1:
                self.nar_audio_embeddings = NATEncodecDecoderLstm(
                    codecs=[0, 1, 2, 3, 4, 5, 6, 7], dictionary=self.nar_at_dict, emb_dim=d_model
                )  # W_a

            self.nar_text_prenet = nn.Identity()
            self.nar_audio_prenet = nn.Identity()

            self.nar_text_position = SinusoidalPositionalEmbedding(
                d_model,
                padding_idx=self.nar_at_dict.pad(),
                init_size=1024 + self.nar_at_dict.pad() + 1
            )
            self.nar_audio_position = SinusoidalPositionalEmbedding(
                d_model,
                padding_idx=self.nar_at_dict.pad(),
                init_size=1024 + self.nar_at_dict.pad() + 1
            )

            self.nar_decoder = TransformerEncoder(
                TransformerEncoderLayer(
                    nar_d_model,
                    int(self.num_heads * nar_scale_factor),
                    dim_feedforward=nar_d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    adaptive_layer_norm=True,
                ),
                num_layers=int(num_layers * nar_scale_factor),
                norm=nn.LayerNorm(nar_d_model) if norm_first else None,
            )
            self.nar_predict_layers = nn.ModuleList(
                [
                    nn.Linear(nar_d_model, len(self.nar_at_dict), bias=False)
                    for i in range(self.num_quantizers)
                ]
            )
            self.nar_stage_embeddings = None

    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
        assert stage > 0
        if stage == 1:
            for name, param in self.named_parameters():
                if name.startswith("ar_"):
                    print(f" AR parameter: {name}")
                    yield param

        if stage == 2:
            for name, param in self.named_parameters():
                if name.startswith("nar_"):
                    print(f"NAR parameter: {name}")
                    yield param

    def stage_named_parameters(
        self, stage: int = 1
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        assert stage > 0
        if stage == 1:
            for pair in self.named_parameters():
                if pair[0].startswith("ar_"):
                    yield pair

        if stage == 2:
            for pair in self.named_parameters():
                if pair[0].startswith("nar_"):
                    yield pair

    def pad_y_eos(self, y, y_mask_int, eos_id):
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        # inputs, targets
        if self.ar_audio_prepend_bos:
            return (
                F.pad(targets[:, :-1], (1, 0), value=self.NUM_AUDIO_TOKENS + 1),
                targets,
            )

        return targets[:, :-1], targets[:, 1:]


class VALLE(VALLF):
    config_class = VallexConfig

    def __init__(
        self,
        config: VallexConfig,
        **kwargs,
    ):
        super(VALLE, self).__init__(
            config,
            **kwargs,
        )
        print(config)
        self.config = config
        d_model = self.config.n_dim
        self.eps = config.eps

        self.language_ID = {
            'en': 0,
            'zh': 1,
        }
        self.ar_language_embedding = nn.Embedding(3, d_model, padding_idx=2)
        self.nar_language_embedding = nn.Embedding(3, d_model, padding_idx=2)
        self.embed_scale = 32.0

    def forward(
        self,
        zh,
        en
    ):
        """
        "zh": {
            "st_tokens": zh_st,
            "at_tokens_wbos": zh_prev_at,
            "at_tokens_tgt": zh_tgt_at,
            "self_atten_mask": zh_self_atten_mask,
            "padding_mask": zh_padding_mask,
            "langid": zh_id.long()
        },
        "en": {
            "st_tokens": en_st,
            "at_tokens_wbos": en_prev_at,
            "at_tokens_tgt": en_tgt_at,
            "self_atten_mask": en_self_atten_mask,
            "padding_mask": en_padding_mask,
            "langid": en_id.long()
        }
        """
        flag = (np.random.randint(low=0, high=1000) % 2 == 0)  # zh or en
        if flag:
            data = zh
        else:
            data = en

        st_tokens = data["st_tokens"]
        at_tokens_wbos = data["at_tokens_wbos"]
        at_tokens_tgt = data["at_tokens_tgt"]
        self_atten_mask = data["self_atten_mask"]
        padding_mask = data["padding_mask"]
        langid = data["langid"]

        st_len = st_tokens.size(1)
        st_emb = self.embed_scale * self.ar_text_embedding(st_tokens)
        src_lang_emb = self.embed_scale * self.ar_language_embedding(langid)
        st_emb += src_lang_emb
        st_pos = self.ar_text_position(st_tokens)
        st_emb += st_pos

        at_emb, _ = self.ar_audio_embedding(at_tokens_wbos, None)
        at_emb = self.embed_scale * at_emb
        tgt_lang_emb = self.embed_scale * self.ar_language_embedding(langid)
        at_emb += tgt_lang_emb
        at_pos = self.ar_audio_position(at_tokens_wbos)
        at_emb += at_pos

        x = torch.concat([st_emb, at_emb], dim=1)

        x = self.ar_decoder(
            x,
            mask=self_atten_mask,
            src_key_padding_mask=padding_mask
        )
        x = self.ar_predict_layer(x)
        x = x[:, st_len:, :]
        loss, nll_loss, lprob, right_rate = self.calculate_loss(
            x, at_tokens_tgt
        )
        return ModelOutput(logits=lprob, loss=loss, acc=right_rate), right_rate

    def calculate_loss(self, encoder_out, target, reduce=True, scale=1.0, prob_mask=None, acc=True):
        lprob = self.get_normalized_probs(encoder_out, log_probs=True)
        with torch.no_grad():
            mask = target.ne(self.padding_idx)
            n_correct = torch.sum(
                lprob.argmax(-1).masked_select(mask).eq(target.masked_select(mask))
            )
            total = torch.sum(mask)
            right_rate = n_correct * 100.0 / total

        lprob, target = lprob.view(-1, lprob.size(-1)), target.view(-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprob,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
            scale=scale,
            prob_mask=prob_mask
        )

        return loss, nll_loss, lprob, right_rate

    def get_normalized_probs(self, encoder_out, log_probs, sample=None):
        if torch.is_tensor(encoder_out):
            logits = encoder_out.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)

    def inference_24L(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: torch.Tensor,
        top_k: int = -100,
        temperature: float = 1.0,
        prompt_language: str = None,
        text_language: str = None,
        best_of: int = 1,
        length_penalty: float = 1.0,
        return_worst: bool = False,
        at_eos: int = -1
    ) -> torch.Tensor:
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)
        self.NUM_AUDIO_TOKENS = at_eos
        text = x
        x = self.embed_scale * self.ar_text_embedding(text)
        prompt_language_id = prompt_language.to(x.device)
        text_language_id = text_language.to(x.device)
        src_lang_emb = self.embed_scale * self.ar_language_embedding(prompt_language_id)
        tgt_lang_emb = self.embed_scale * self.ar_language_embedding(text_language_id)
        x[:, :enroll_x_lens, :] += src_lang_emb
        x[:, enroll_x_lens:, :] += tgt_lang_emb
        x = self.ar_text_prenet(x)
        x_pos = self.ar_text_position(text)

        text_len = x_lens.max()
        prompts = y
        prefix_len = y.shape[1]

        # AR Decoder
        # TODO: Managing decoder steps avoid repetitive computation
        y = prompts[..., 0]
        if self.ar_audio_prepend_bos:
            y = F.pad(y, (1, 0), value=self.ar_at_dict.tts_flag)

        x_len = x_lens.max()
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

        kv_cache = None
        use_kv_caching = True

        sum_logprobs = torch.zeros(best_of, device=y.device)  # implement batch decoding here
        x = x.repeat(best_of, 1, 1)
        y = y.repeat(best_of, 1)
        lstm_h = None
        first_ar = True
        while True:
            if first_ar:
                y_emb, lstm_h = self.ar_audio_embedding(y, lstm_h)
                y_emb = y_emb * self.embed_scale
                y_emb = self.ar_audio_prenet(y_emb)
                y_pos = self.ar_audio_position(y)
                y_emb[:, :prefix_len] = y_emb[:, :prefix_len] + src_lang_emb
                y_emb[:, prefix_len:] = y_emb[:, prefix_len:] + tgt_lang_emb
                xy_pos_token = torch.concat([x_pos + x, y_pos + y_emb], dim=1)
                first_ar = False
            else:
                y_emb_cur, lstm_h = self.ar_audio_embedding(y[:, -1:], lstm_h)
                y_emb_cur = y_emb_cur * self.embed_scale
                y_emb_cur = self.ar_audio_prenet(y_emb_cur)
                y_pos_cur = self.ar_audio_position(y)[:, -1:]
                y_emb_cur = y_emb_cur + src_lang_emb
                y_emb_cur = y_emb_cur + tgt_lang_emb
                xy_pos_token = torch.concat([xy_pos_token, y_pos_cur + y_emb_cur], dim=1)
            # print(xy_pos_token.size())

            y_len = y.shape[1]
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(
                    torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
                ),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat(
                [x_attn_mask_pad, y_attn_mask], dim=0
            ).to(y.device)

            if use_kv_caching and kv_cache is not None:
                xy_pos = xy_pos_token[:, [-1]]
                xy_attn_mask = xy_attn_mask[:, [-1]]
            else:
                xy_pos = xy_pos_token

            xy_dec, kv_cache = self.ar_decoder.infer(
                xy_pos,
                mask=xy_attn_mask,
                past_kv=kv_cache,
                use_cache=use_kv_caching,
            )

            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples, current_logprobs = topk_sampling(
                logits, top_k=top_k, top_p=1, temperature=temperature
            )
            sum_logprobs += current_logprobs * (y[:, -1] != self.NUM_AUDIO_TOKENS)
            samples[y[:, -1] == self.NUM_AUDIO_TOKENS] = self.NUM_AUDIO_TOKENS
            completed = (samples[:, -1] == self.NUM_AUDIO_TOKENS).all()
            if (
                completed
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 32
            ):
                if prompts.shape[1] == y.shape[1]:
                    raise SyntaxError(
                        "well trained model shouldn't reach here."
                    )
                lengths = torch.sum(y != self.NUM_AUDIO_TOKENS, dim=1)
                avg_logprobs = sum_logprobs / lengths ** length_penalty
                # choose the best beam according to sum_logprobs
                best_beam = y[torch.argmax(avg_logprobs), :]
                worst_beam = y[torch.argmin(avg_logprobs), :]
                # strip all eos tokens
                best_beam = best_beam[best_beam != self.NUM_AUDIO_TOKENS]
                worst_beam = worst_beam[worst_beam != self.NUM_AUDIO_TOKENS]
                if return_worst:
                    y = worst_beam.unsqueeze(0)
                else:
                    y = best_beam.unsqueeze(0)
                print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
                break

            y = torch.concat([y, samples], dim=1)

        codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos):]]
        if self.num_quantizers == 1:
            return torch.stack(codes, dim=-1)

        if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
            enrolled_len = enroll_x_lens.max().item()
            # SOS + Synthesis Text + EOS
            text = torch.concat(
                [
                    text[:, :1],
                    text[:, enrolled_len - 1:],
                ],
                dim=1,
            )
            text_len = text_len - (enrolled_len - 2)
            assert text.shape[0] == 1

        x = self.embed_scale * self.nar_text_embedding(text)
        # Add language embedding
        prompt_language_id = prompt_language.to(x.device)
        text_language_id = text_language.to(x.device)
        src_lang_emb = self.embed_scale * self.nar_language_embedding(prompt_language_id)
        tgt_lang_emb = self.embed_scale * self.nar_language_embedding(text_language_id)
        x[:, :enroll_x_lens, :] += src_lang_emb
        x[:, enroll_x_lens:, :] += tgt_lang_emb
        x = self.nar_text_prenet(x)
        x_pos = self.nar_text_position(text)

        if self.prefix_mode == 0:
            for i, predict_layer in enumerate(
                self.nar_predict_layers
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len:])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, :prefix_len] += self.embed_scale * self.nar_audio_embeddings(
                        prompts[..., i + 1]
                    )[0]
                    y_emb[:, prefix_len:] += self.embed_scale * self.nar_audio_embeddings(samples)[0]
        else:
            y_pos = self.nar_audio_position(y[:, int(self.ar_audio_prepend_bos):])

            ref_at_emb = self.embed_scale * self.nar_audio_embeddings(prompts)[0] + src_lang_emb
            est_at = y[:, prefix_len + int(self.ar_audio_prepend_bos):].unsqueeze(-1)
            #
            for i in range(1, 8):
                y_emb, _ = self.nar_audio_embeddings(est_at)
                y_emb = self.embed_scale * y_emb + tgt_lang_emb

                y_emb = torch.concat([ref_at_emb, y_emb], dim=1)
                xy_pos = torch.concat([x + x_pos, y_emb + y_pos], dim=1)

                xy_dec = self.nar_decoder(
                    xy_pos
                )
                logits = self.nar_predict_layers[i - 1](xy_dec[:, text_len + prefix_len:])
                # print(logits.size(), xy_pos.size(), xy_dec.size())
                samples = torch.argmax(logits, dim=-1)
                est_at = torch.concat([est_at, samples.unsqueeze(-1)], dim=-1)
                codes.append(samples)

        assert len(codes) == self.num_quantizers
        return torch.stack(codes, dim=-1)


def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits


def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # Sample
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    logprobs = F.log_softmax(logits.float(), dim=-1)
    current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
    return token, current_logprobs


class SLSTM(nn.Module):
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional=False):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional)
        if bidirectional:
            self.out_fc = nn.Linear(dimension * 2, dimension)
        else:
            self.out_fc = None

    def forward(self, x, hidden=None):
        x = x.permute(2, 0, 1)
        y, hidden = self.lstm(x, hidden)
        if self.out_fc is not None:
            y = self.out_fc(y)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y, hidden


class EncodecDecoderLstm(nn.Module):
    def __init__(self, dictionary, emb_dim,
                 out_dim=None,
                 num_layers=3, lstm_skip=True, lstm_bidire=False,
                 activation_param={'alpha': 1.0}, **kwargs):
        super().__init__()

        # Identity()
        if out_dim is None:
            out_dim = emb_dim
        self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire)
        self.elu = nn.ELU(**activation_param)
        self.embedding_dim = emb_dim
        self.padding_idx = dictionary.pad()
        self.emb = nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index)

    def forward(self, x, hidden=None):
        """
        Args:
            x (_type_): B,T,D
        """
        # print(x.size())
        quantized_out = self.emb(x)
        out, hidden = self.slstm(quantized_out.permute(0, 2, 1), hidden)
        out = self.elu(out)
        return out.permute(0, 2, 1), hidden


class NATEncodecDecoderLstm(nn.Module):
    def __init__(self, codecs, dictionary, emb_dim,
                 out_dim=None,
                 num_layers=3, lstm_skip=True, lstm_bidire=False,
                 activation_param={'alpha': 1.0}, **kwargs):
        super().__init__()

        # Identity()
        if out_dim is None:
            out_dim = emb_dim
        self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire)
        self.elu = nn.ELU(**activation_param)
        self.codecs = codecs
        self.embedding_dim = emb_dim
        self.padding_idx = dictionary.pad()
        self.emb_list = nn.ModuleList(
            [nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index) for i in range(len(self.codecs))]
        )

    def forward(self, x, hidden=None):
        """
        Args:
            x (_type_): B,T,D
        """
        if len(x.size()) == 2:
            x = x.unsqueeze(-1)

        if x.size(2) != len(self.codecs) and x.size(1) == len(self.codecs):
            x = x.permute(0, 2, 1)

        quantized_out = 0
        for i in range(x.size(2)):
            quantized = self.emb_list[i](x[:, :, i])
            quantized_out = quantized_out + quantized
        # quantized_out = quantized_out / len(self.codecs)

        out, hidden = self.slstm(quantized_out.permute(0, 2, 1), hidden)
        out = self.elu(out)
        return out.permute(0, 2, 1), hidden


AutoModel.register(VallexConfig, VALLE)
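
The AR decoding loop above draws each audio token with topk_sampling, which first applies top-k/top-p filtering and then samples from the renormalized distribution while tracking per-beam log-probabilities. The snippet below is a standalone illustration of those helpers, not part of the upload; the import path is assumed, and importing this module requires fairseq and transformers since they are imported at module level.

# Illustrative sketch of the sampling helpers above (assumed import path).
import torch
from src.slam_llm.models.vallex.vallex_model import topk_sampling

torch.manual_seed(0)
logits = torch.randn(2, 1024)   # (batch, vocab) AR logits for a single decoding step
# top_k_top_p_filtering edits logits in place, so pass a clone if the originals are needed later
tokens, logprobs = topk_sampling(logits.clone(), top_k=10, top_p=1.0, temperature=1.0)
print(tokens.shape, logprobs.shape)  # torch.Size([2, 1]) torch.Size([2])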
slam_llm/models/wavlm/WavLM.py
ADDED
@@ -0,0 +1,743 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
4 |
+
# Copyright (c) 2021 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
import logging
|
12 |
+
from typing import List, Optional, Tuple
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch.nn import LayerNorm
|
20 |
+
from .modules import (
|
21 |
+
Fp32GroupNorm,
|
22 |
+
Fp32LayerNorm,
|
23 |
+
GradMultiply,
|
24 |
+
MultiheadAttention,
|
25 |
+
SamePad,
|
26 |
+
init_bert_params,
|
27 |
+
get_activation_fn,
|
28 |
+
TransposeLast,
|
29 |
+
GLU_Linear,
|
30 |
+
)
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
def compute_mask_indices(
|
36 |
+
shape: Tuple[int, int],
|
37 |
+
padding_mask: Optional[torch.Tensor],
|
38 |
+
mask_prob: float,
|
39 |
+
mask_length: int,
|
40 |
+
mask_type: str = "static",
|
41 |
+
mask_other: float = 0.0,
|
42 |
+
min_masks: int = 0,
|
43 |
+
no_overlap: bool = False,
|
44 |
+
min_space: int = 0,
|
45 |
+
) -> np.ndarray:
|
46 |
+
"""
|
47 |
+
Computes random mask spans for a given shape
|
48 |
+
|
49 |
+
Args:
|
50 |
+
shape: the the shape for which to compute masks.
|
51 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
52 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
53 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
54 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
55 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
56 |
+
mask_type: how to compute mask lengths
|
57 |
+
static = fixed size
|
58 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
59 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
60 |
+
poisson = sample from possion distribution with lambda = mask length
|
61 |
+
min_masks: minimum number of masked spans
|
62 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
63 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
64 |
+
"""
|
65 |
+
|
66 |
+
bsz, all_sz = shape
|
67 |
+
mask = np.full((bsz, all_sz), False)
|
68 |
+
|
69 |
+
all_num_mask = int(
|
70 |
+
# add a random number for probabilistic rounding
|
71 |
+
mask_prob * all_sz / float(mask_length)
|
72 |
+
+ np.random.rand()
|
73 |
+
)
|
74 |
+
|
75 |
+
all_num_mask = max(min_masks, all_num_mask)
|
76 |
+
|
77 |
+
mask_idcs = []
|
78 |
+
for i in range(bsz):
|
79 |
+
if padding_mask is not None:
|
80 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
81 |
+
num_mask = int(
|
82 |
+
# add a random number for probabilistic rounding
|
83 |
+
mask_prob * sz / float(mask_length)
|
84 |
+
+ np.random.rand()
|
85 |
+
)
|
86 |
+
num_mask = max(min_masks, num_mask)
|
87 |
+
else:
|
88 |
+
sz = all_sz
|
89 |
+
num_mask = all_num_mask
|
90 |
+
|
91 |
+
if mask_type == "static":
|
92 |
+
lengths = np.full(num_mask, mask_length)
|
93 |
+
elif mask_type == "uniform":
|
94 |
+
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
95 |
+
elif mask_type == "normal":
|
96 |
+
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
97 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
98 |
+
elif mask_type == "poisson":
|
99 |
+
lengths = np.random.poisson(mask_length, size=num_mask)
|
100 |
+
lengths = [int(round(x)) for x in lengths]
|
101 |
+
else:
|
102 |
+
raise Exception("unknown mask selection " + mask_type)
|
103 |
+
|
104 |
+
if sum(lengths) == 0:
|
105 |
+
lengths[0] = min(mask_length, sz - 1)
|
106 |
+
|
107 |
+
if no_overlap:
|
108 |
+
mask_idc = []
|
109 |
+
|
110 |
+
def arrange(s, e, length, keep_length):
|
111 |
+
span_start = np.random.randint(s, e - length)
|
112 |
+
mask_idc.extend(span_start + i for i in range(length))
|
113 |
+
|
114 |
+
new_parts = []
|
115 |
+
if span_start - s - min_space >= keep_length:
|
116 |
+
new_parts.append((s, span_start - min_space + 1))
|
117 |
+
if e - span_start - keep_length - min_space > keep_length:
|
118 |
+
new_parts.append((span_start + length + min_space, e))
|
119 |
+
return new_parts
|
120 |
+
|
121 |
+
parts = [(0, sz)]
|
122 |
+
min_length = min(lengths)
|
123 |
+
for length in sorted(lengths, reverse=True):
|
124 |
+
lens = np.fromiter(
|
125 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
126 |
+
int,  # builtin int; the np.int alias was removed in NumPy 1.24
|
127 |
+
)
|
128 |
+
l_sum = np.sum(lens)
|
129 |
+
if l_sum == 0:
|
130 |
+
break
|
131 |
+
probs = lens / np.sum(lens)
|
132 |
+
c = np.random.choice(len(parts), p=probs)
|
133 |
+
s, e = parts.pop(c)
|
134 |
+
parts.extend(arrange(s, e, length, min_length))
|
135 |
+
mask_idc = np.asarray(mask_idc)
|
136 |
+
else:
|
137 |
+
min_len = min(lengths)
|
138 |
+
if sz - min_len <= num_mask:
|
139 |
+
min_len = sz - num_mask - 1
|
140 |
+
|
141 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
142 |
+
|
143 |
+
mask_idc = np.asarray(
|
144 |
+
[
|
145 |
+
mask_idc[j] + offset
|
146 |
+
for j in range(len(mask_idc))
|
147 |
+
for offset in range(lengths[j])
|
148 |
+
]
|
149 |
+
)
|
150 |
+
|
151 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
152 |
+
|
153 |
+
min_len = min([len(m) for m in mask_idcs])
|
154 |
+
for i, mask_idc in enumerate(mask_idcs):
|
155 |
+
if len(mask_idc) > min_len:
|
156 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
157 |
+
mask[i, mask_idc] = True
|
158 |
+
|
159 |
+
return mask
|
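# --- Editor's note: illustrative sketch, not part of the original WavLM source.
# It shows how compute_mask_indices is typically called; the batch size, frame
# count and probabilities below are assumed example values, and the _demo_*
# helper itself is hypothetical.
def _demo_compute_mask_indices():
    B, T = 4, 200  # assumed: 4 utterances, 200 encoder frames each
    mask = compute_mask_indices(
        shape=(B, T),
        padding_mask=None,
        mask_prob=0.65,      # default in WavLMConfig
        mask_length=10,      # default in WavLMConfig
        mask_type="static",
        min_masks=2,
    )
    # mask is a (B, T) boolean numpy array marking the masked time steps
    return torch.from_numpy(mask)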
160 |
+
|
161 |
+
|
162 |
+
class WavLMConfig:
|
163 |
+
def __init__(self, cfg=None):
|
164 |
+
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
|
165 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
166 |
+
|
167 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
168 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
169 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
170 |
+
self.activation_fn: str = "gelu" # activation function to use
|
171 |
+
|
172 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
173 |
+
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
|
174 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
175 |
+
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
|
176 |
+
|
177 |
+
self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
|
178 |
+
|
179 |
+
# dropouts
|
180 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
181 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
182 |
+
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
183 |
+
self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
|
184 |
+
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
185 |
+
self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
|
186 |
+
|
187 |
+
# masking
|
188 |
+
self.mask_length: int = 10 # mask length
|
189 |
+
self.mask_prob: float = 0.65 # probability of replacing a token with mask
|
190 |
+
self.mask_selection: str = "static" # how to choose mask length
|
191 |
+
self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
192 |
+
self.no_mask_overlap: bool = False # whether to allow masks to overlap
|
193 |
+
self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
194 |
+
|
195 |
+
# channel masking
|
196 |
+
self.mask_channel_length: int = 10 # length of the mask for features (channels)
|
197 |
+
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
|
198 |
+
self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
|
199 |
+
self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
200 |
+
self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
|
201 |
+
self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
202 |
+
|
203 |
+
# positional embeddings
|
204 |
+
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
205 |
+
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
206 |
+
|
207 |
+
# relative position embedding
|
208 |
+
self.relative_position_embedding: bool = False # apply relative position embedding
|
209 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
210 |
+
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
211 |
+
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
212 |
+
|
213 |
+
if cfg is not None:
|
214 |
+
self.update(cfg)
|
215 |
+
|
216 |
+
def update(self, cfg: dict):
|
217 |
+
self.__dict__.update(cfg)
|
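# --- Editor's note: illustrative sketch, not part of the original WavLM source.
# It shows how a config is built and then overridden from a dict (e.g. the cfg
# stored inside a checkpoint). The override values are example numbers roughly
# matching a Large-sized encoder, not required settings.
def _demo_build_config():
    cfg = WavLMConfig()                  # defaults correspond to a Base-sized model
    cfg.update({
        "encoder_layers": 24,
        "encoder_embed_dim": 1024,
        "encoder_ffn_embed_dim": 4096,
        "encoder_attention_heads": 16,
        "normalize": True,
    })
    return cfg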
218 |
+
|
219 |
+
|
220 |
+
class WavLM(nn.Module):
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
cfg: WavLMConfig,
|
224 |
+
) -> None:
|
225 |
+
super().__init__()
|
226 |
+
logger.info(f"WavLM Config: {cfg.__dict__}")
|
227 |
+
|
228 |
+
self.cfg = cfg
|
229 |
+
feature_enc_layers = eval(cfg.conv_feature_layers)
|
230 |
+
self.embed = feature_enc_layers[-1][0]
|
231 |
+
|
232 |
+
self.feature_extractor = ConvFeatureExtractionModel(
|
233 |
+
conv_layers=feature_enc_layers,
|
234 |
+
dropout=0.0,
|
235 |
+
mode=cfg.extractor_mode,
|
236 |
+
conv_bias=cfg.conv_bias,
|
237 |
+
)
|
238 |
+
|
239 |
+
self.post_extract_proj = (
|
240 |
+
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
241 |
+
if self.embed != cfg.encoder_embed_dim
|
242 |
+
else None
|
243 |
+
)
|
244 |
+
|
245 |
+
self.mask_prob = cfg.mask_prob
|
246 |
+
self.mask_selection = cfg.mask_selection
|
247 |
+
self.mask_other = cfg.mask_other
|
248 |
+
self.mask_length = cfg.mask_length
|
249 |
+
self.no_mask_overlap = cfg.no_mask_overlap
|
250 |
+
self.mask_min_space = cfg.mask_min_space
|
251 |
+
|
252 |
+
self.mask_channel_prob = cfg.mask_channel_prob
|
253 |
+
self.mask_channel_selection = cfg.mask_channel_selection
|
254 |
+
self.mask_channel_other = cfg.mask_channel_other
|
255 |
+
self.mask_channel_length = cfg.mask_channel_length
|
256 |
+
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
257 |
+
self.mask_channel_min_space = cfg.mask_channel_min_space
|
258 |
+
|
259 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
260 |
+
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
261 |
+
|
262 |
+
self.feature_grad_mult = cfg.feature_grad_mult
|
263 |
+
|
264 |
+
self.mask_emb = nn.Parameter(
|
265 |
+
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
|
266 |
+
)
|
267 |
+
|
268 |
+
self.encoder = TransformerEncoder(cfg)
|
269 |
+
self.layer_norm = LayerNorm(self.embed)
|
270 |
+
|
271 |
+
def apply_mask(self, x, padding_mask):
|
272 |
+
B, T, C = x.shape
|
273 |
+
if self.mask_prob > 0:
|
274 |
+
mask_indices = compute_mask_indices(
|
275 |
+
(B, T),
|
276 |
+
padding_mask,
|
277 |
+
self.mask_prob,
|
278 |
+
self.mask_length,
|
279 |
+
self.mask_selection,
|
280 |
+
self.mask_other,
|
281 |
+
min_masks=2,
|
282 |
+
no_overlap=self.no_mask_overlap,
|
283 |
+
min_space=self.mask_min_space,
|
284 |
+
)
|
285 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
286 |
+
x[mask_indices] = self.mask_emb
|
287 |
+
else:
|
288 |
+
mask_indices = None
|
289 |
+
|
290 |
+
if self.mask_channel_prob > 0:
|
291 |
+
mask_channel_indices = compute_mask_indices(
|
292 |
+
(B, C),
|
293 |
+
None,
|
294 |
+
self.mask_channel_prob,
|
295 |
+
self.mask_channel_length,
|
296 |
+
self.mask_channel_selection,
|
297 |
+
self.mask_channel_other,
|
298 |
+
no_overlap=self.no_mask_channel_overlap,
|
299 |
+
min_space=self.mask_channel_min_space,
|
300 |
+
)
|
301 |
+
mask_channel_indices = (
|
302 |
+
torch.from_numpy(mask_channel_indices)
|
303 |
+
.to(x.device)
|
304 |
+
.unsqueeze(1)
|
305 |
+
.expand(-1, T, -1)
|
306 |
+
)
|
307 |
+
x[mask_channel_indices] = 0
|
308 |
+
|
309 |
+
return x, mask_indices
|
310 |
+
|
311 |
+
def forward_padding_mask(
|
312 |
+
self, features: torch.Tensor, padding_mask: torch.Tensor,
|
313 |
+
) -> torch.Tensor:
|
314 |
+
extra = padding_mask.size(1) % features.size(1)
|
315 |
+
if extra > 0:
|
316 |
+
padding_mask = padding_mask[:, :-extra]
|
317 |
+
padding_mask = padding_mask.view(
|
318 |
+
padding_mask.size(0), features.size(1), -1
|
319 |
+
)
|
320 |
+
padding_mask = padding_mask.all(-1)
|
321 |
+
return padding_mask
|
322 |
+
|
323 |
+
def extract_features(
|
324 |
+
self,
|
325 |
+
source: torch.Tensor,
|
326 |
+
padding_mask: Optional[torch.Tensor] = None,
|
327 |
+
mask: bool = False,
|
328 |
+
ret_conv: bool = False,
|
329 |
+
output_layer: Optional[int] = None,
|
330 |
+
ret_layer_results: bool = False,
|
331 |
+
):
|
332 |
+
|
333 |
+
if self.feature_grad_mult > 0:
|
334 |
+
features = self.feature_extractor(source)
|
335 |
+
if self.feature_grad_mult != 1.0:
|
336 |
+
features = GradMultiply.apply(features, self.feature_grad_mult)
|
337 |
+
else:
|
338 |
+
with torch.no_grad():
|
339 |
+
features = self.feature_extractor(source)
|
340 |
+
|
341 |
+
features = features.transpose(1, 2)
|
342 |
+
features = self.layer_norm(features)
|
343 |
+
|
344 |
+
if padding_mask is not None:
|
345 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
346 |
+
|
347 |
+
if self.post_extract_proj is not None:
|
348 |
+
features = self.post_extract_proj(features)
|
349 |
+
|
350 |
+
features = self.dropout_input(features)
|
351 |
+
|
352 |
+
if mask:
|
353 |
+
x, mask_indices = self.apply_mask(
|
354 |
+
features, padding_mask
|
355 |
+
)
|
356 |
+
else:
|
357 |
+
x = features
|
358 |
+
|
359 |
+
# feature: (B, T, D), float
|
360 |
+
# target: (B, T), long
|
361 |
+
# x: (B, T, D), float
|
362 |
+
# padding_mask: (B, T), bool
|
363 |
+
# mask_indices: (B, T), bool
|
364 |
+
x, layer_results = self.encoder(
|
365 |
+
x,
|
366 |
+
padding_mask=padding_mask,
|
367 |
+
layer=None if output_layer is None else output_layer - 1
|
368 |
+
)
|
369 |
+
|
370 |
+
res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
|
371 |
+
|
372 |
+
feature = res["features"] if ret_conv else res["x"]
|
373 |
+
if ret_layer_results:
|
374 |
+
feature = (feature, res["layer_results"])
|
375 |
+
return feature, res["padding_mask"]
|
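# --- Editor's note: illustrative sketch, not part of the original WavLM source.
# It shows the usual load-and-extract loop. The checkpoint path and its
# "cfg"/"model" keys follow the layout of the released WavLM checkpoints but
# should be treated as assumptions for other files.
def _demo_extract_features(checkpoint_path="WavLM-Base.pt"):
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    model = WavLM(WavLMConfig(ckpt["cfg"]))
    model.load_state_dict(ckpt["model"])
    model.eval()

    wav = torch.randn(1, 16000)          # dummy 1 s of 16 kHz audio
    if model.cfg.normalize:              # e.g. Large checkpoints expect normalized input
        wav = torch.nn.functional.layer_norm(wav, wav.shape)
    with torch.no_grad():
        feats, _ = model.extract_features(wav)
    return feats                         # (1, ~49, encoder_embed_dim)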
376 |
+
|
377 |
+
|
378 |
+
class ConvFeatureExtractionModel(nn.Module):
|
379 |
+
def __init__(
|
380 |
+
self,
|
381 |
+
conv_layers: List[Tuple[int, int, int]],
|
382 |
+
dropout: float = 0.0,
|
383 |
+
mode: str = "default",
|
384 |
+
conv_bias: bool = False,
|
385 |
+
conv_type: str = "default"
|
386 |
+
):
|
387 |
+
super().__init__()
|
388 |
+
|
389 |
+
assert mode in {"default", "layer_norm"}
|
390 |
+
|
391 |
+
def block(
|
392 |
+
n_in,
|
393 |
+
n_out,
|
394 |
+
k,
|
395 |
+
stride,
|
396 |
+
is_layer_norm=False,
|
397 |
+
is_group_norm=False,
|
398 |
+
conv_bias=False,
|
399 |
+
):
|
400 |
+
def make_conv():
|
401 |
+
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
402 |
+
nn.init.kaiming_normal_(conv.weight)
|
403 |
+
return conv
|
404 |
+
|
405 |
+
assert (
|
406 |
+
is_layer_norm and is_group_norm
|
407 |
+
) == False, "layer norm and group norm are exclusive"
|
408 |
+
|
409 |
+
if is_layer_norm:
|
410 |
+
return nn.Sequential(
|
411 |
+
make_conv(),
|
412 |
+
nn.Dropout(p=dropout),
|
413 |
+
nn.Sequential(
|
414 |
+
TransposeLast(),
|
415 |
+
Fp32LayerNorm(dim, elementwise_affine=True),
|
416 |
+
TransposeLast(),
|
417 |
+
),
|
418 |
+
nn.GELU(),
|
419 |
+
)
|
420 |
+
elif is_group_norm:
|
421 |
+
return nn.Sequential(
|
422 |
+
make_conv(),
|
423 |
+
nn.Dropout(p=dropout),
|
424 |
+
Fp32GroupNorm(dim, dim, affine=True),
|
425 |
+
nn.GELU(),
|
426 |
+
)
|
427 |
+
else:
|
428 |
+
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
429 |
+
|
430 |
+
self.conv_type = conv_type
|
431 |
+
if self.conv_type == "default":
|
432 |
+
in_d = 1
|
433 |
+
self.conv_layers = nn.ModuleList()
|
434 |
+
for i, cl in enumerate(conv_layers):
|
435 |
+
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
436 |
+
(dim, k, stride) = cl
|
437 |
+
|
438 |
+
self.conv_layers.append(
|
439 |
+
block(
|
440 |
+
in_d,
|
441 |
+
dim,
|
442 |
+
k,
|
443 |
+
stride,
|
444 |
+
is_layer_norm=mode == "layer_norm",
|
445 |
+
is_group_norm=mode == "default" and i == 0,
|
446 |
+
conv_bias=conv_bias,
|
447 |
+
)
|
448 |
+
)
|
449 |
+
in_d = dim
|
450 |
+
elif self.conv_type == "conv2d":
|
451 |
+
in_d = 1
|
452 |
+
self.conv_layers = nn.ModuleList()
|
453 |
+
for i, cl in enumerate(conv_layers):
|
454 |
+
assert len(cl) == 3
|
455 |
+
(dim, k, stride) = cl
|
456 |
+
|
457 |
+
self.conv_layers.append(
|
458 |
+
torch.nn.Conv2d(in_d, dim, k, stride)
|
459 |
+
)
|
460 |
+
self.conv_layers.append(torch.nn.ReLU())
|
461 |
+
in_d = dim
|
462 |
+
elif self.conv_type == "custom":
|
463 |
+
in_d = 1
|
464 |
+
idim = 80
|
465 |
+
self.conv_layers = nn.ModuleList()
|
466 |
+
for i, cl in enumerate(conv_layers):
|
467 |
+
assert len(cl) == 3
|
468 |
+
(dim, k, stride) = cl
|
469 |
+
self.conv_layers.append(
|
470 |
+
torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
|
471 |
+
)
|
472 |
+
self.conv_layers.append(
|
473 |
+
torch.nn.LayerNorm([dim, idim])
|
474 |
+
)
|
475 |
+
self.conv_layers.append(torch.nn.ReLU())
|
476 |
+
in_d = dim
|
477 |
+
if (i + 1) % 2 == 0:
|
478 |
+
self.conv_layers.append(
|
479 |
+
torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
480 |
+
)
|
481 |
+
idim = int(math.ceil(idim / 2))
|
482 |
+
else:
|
483 |
+
pass
|
484 |
+
|
485 |
+
def forward(self, x, mask=None):
|
486 |
+
|
487 |
+
# BxT -> BxCxT
|
488 |
+
x = x.unsqueeze(1)
|
489 |
+
if self.conv_type == "custom":
|
490 |
+
for conv in self.conv_layers:
|
491 |
+
if isinstance(conv, nn.LayerNorm):
|
492 |
+
x = x.transpose(1, 2)
|
493 |
+
x = conv(x).transpose(1, 2)
|
494 |
+
else:
|
495 |
+
x = conv(x)
|
496 |
+
x = x.transpose(2, 3).contiguous()
|
497 |
+
x = x.view(x.size(0), -1, x.size(-1))
|
498 |
+
else:
|
499 |
+
for conv in self.conv_layers:
|
500 |
+
x = conv(x)
|
501 |
+
if self.conv_type == "conv2d":
|
502 |
+
b, c, t, f = x.size()
|
503 |
+
x = x.transpose(2, 3).contiguous().view(b, c * f, t)
|
504 |
+
return x
|
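# --- Editor's note: illustrative sketch, not part of the original WavLM source.
# The default spec "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" stacks
# strides 5*2*2*2*2*2*2 = 320, i.e. one feature frame per 20 ms of 16 kHz audio.
# The helper below (hypothetical) derives the output length from such a spec.
def _demo_conv_output_length(
    num_samples=16000,
    layers="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
):
    t = num_samples
    for _, k, s in eval(layers):         # same spec format as cfg.conv_feature_layers
        t = (t - k) // s + 1             # unpadded 1-D convolution
    return t                             # 49 frames for one second of audio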
505 |
+
|
506 |
+
|
507 |
+
class TransformerEncoder(nn.Module):
|
508 |
+
def __init__(self, args):
|
509 |
+
super().__init__()
|
510 |
+
|
511 |
+
self.dropout = args.dropout
|
512 |
+
self.embedding_dim = args.encoder_embed_dim
|
513 |
+
|
514 |
+
self.pos_conv = nn.Conv1d(
|
515 |
+
self.embedding_dim,
|
516 |
+
self.embedding_dim,
|
517 |
+
kernel_size=args.conv_pos,
|
518 |
+
padding=args.conv_pos // 2,
|
519 |
+
groups=args.conv_pos_groups,
|
520 |
+
)
|
521 |
+
dropout = 0
|
522 |
+
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
523 |
+
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
524 |
+
nn.init.constant_(self.pos_conv.bias, 0)
|
525 |
+
|
526 |
+
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
527 |
+
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
528 |
+
|
529 |
+
if hasattr(args, "relative_position_embedding"):
|
530 |
+
self.relative_position_embedding = args.relative_position_embedding
|
531 |
+
self.num_buckets = args.num_buckets
|
532 |
+
self.max_distance = args.max_distance
|
533 |
+
else:
|
534 |
+
self.relative_position_embedding = False
|
535 |
+
self.num_buckets = 0
|
536 |
+
self.max_distance = 0
|
537 |
+
|
538 |
+
self.layers = nn.ModuleList(
|
539 |
+
[
|
540 |
+
TransformerSentenceEncoderLayer(
|
541 |
+
embedding_dim=self.embedding_dim,
|
542 |
+
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
543 |
+
num_attention_heads=args.encoder_attention_heads,
|
544 |
+
dropout=self.dropout,
|
545 |
+
attention_dropout=args.attention_dropout,
|
546 |
+
activation_dropout=args.activation_dropout,
|
547 |
+
activation_fn=args.activation_fn,
|
548 |
+
layer_norm_first=args.layer_norm_first,
|
549 |
+
has_relative_attention_bias=(self.relative_position_embedding and i == 0),
|
550 |
+
num_buckets=self.num_buckets,
|
551 |
+
max_distance=self.max_distance,
|
552 |
+
gru_rel_pos=args.gru_rel_pos,
|
553 |
+
)
|
554 |
+
for i in range(args.encoder_layers)
|
555 |
+
]
|
556 |
+
)
|
557 |
+
|
558 |
+
self.layer_norm_first = args.layer_norm_first
|
559 |
+
self.layer_norm = LayerNorm(self.embedding_dim)
|
560 |
+
self.layerdrop = args.encoder_layerdrop
|
561 |
+
|
562 |
+
self.apply(init_bert_params)
|
563 |
+
|
564 |
+
def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
|
565 |
+
x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
|
566 |
+
|
567 |
+
if self.layer_norm_first and layer is None:
|
568 |
+
x = self.layer_norm(x)
|
569 |
+
|
570 |
+
return x, layer_results
|
571 |
+
|
572 |
+
def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
|
573 |
+
|
574 |
+
if padding_mask is not None:
|
575 |
+
x[padding_mask] = 0
|
576 |
+
|
577 |
+
x_conv = self.pos_conv(x.transpose(1, 2))
|
578 |
+
x_conv = x_conv.transpose(1, 2)
|
579 |
+
x = x + x_conv
|
580 |
+
|
581 |
+
if not self.layer_norm_first:
|
582 |
+
x = self.layer_norm(x)
|
583 |
+
|
584 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
585 |
+
|
586 |
+
# B x T x C -> T x B x C
|
587 |
+
x = x.transpose(0, 1)
|
588 |
+
|
589 |
+
layer_results = []
|
590 |
+
z = None
|
591 |
+
if tgt_layer is not None:
|
592 |
+
layer_results.append((x, z))
|
593 |
+
r = None
|
594 |
+
pos_bias = None
|
595 |
+
for i, layer in enumerate(self.layers):
|
596 |
+
dropout_probability = np.random.random()
|
597 |
+
if not self.training or (dropout_probability > self.layerdrop):
|
598 |
+
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
|
599 |
+
self_attn_mask=streaming_mask, pos_bias=pos_bias)
|
600 |
+
if tgt_layer is not None:
|
601 |
+
layer_results.append((x, z))
|
602 |
+
if i == tgt_layer:
|
603 |
+
r = x
|
604 |
+
break
|
605 |
+
|
606 |
+
if r is not None:
|
607 |
+
x = r
|
608 |
+
|
609 |
+
# T x B x C -> B x T x C
|
610 |
+
x = x.transpose(0, 1)
|
611 |
+
|
612 |
+
return x, layer_results
|
613 |
+
|
614 |
+
|
615 |
+
class TransformerSentenceEncoderLayer(nn.Module):
|
616 |
+
"""
|
617 |
+
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
|
618 |
+
models.
|
619 |
+
"""
|
620 |
+
|
621 |
+
def __init__(
|
622 |
+
self,
|
623 |
+
embedding_dim: float = 768,
|
624 |
+
ffn_embedding_dim: float = 3072,
|
625 |
+
num_attention_heads: float = 8,
|
626 |
+
dropout: float = 0.1,
|
627 |
+
attention_dropout: float = 0.1,
|
628 |
+
activation_dropout: float = 0.1,
|
629 |
+
activation_fn: str = "relu",
|
630 |
+
layer_norm_first: bool = False,
|
631 |
+
has_relative_attention_bias: bool = False,
|
632 |
+
num_buckets: int = 0,
|
633 |
+
max_distance: int = 0,
|
634 |
+
rescale_init: bool = False,
|
635 |
+
gru_rel_pos: bool = False,
|
636 |
+
) -> None:
|
637 |
+
|
638 |
+
super().__init__()
|
639 |
+
# Initialize parameters
|
640 |
+
self.embedding_dim = embedding_dim
|
641 |
+
self.dropout = dropout
|
642 |
+
self.activation_dropout = activation_dropout
|
643 |
+
|
644 |
+
# Initialize blocks
|
645 |
+
self.activation_name = activation_fn
|
646 |
+
self.activation_fn = get_activation_fn(activation_fn)
|
647 |
+
self.self_attn = MultiheadAttention(
|
648 |
+
self.embedding_dim,
|
649 |
+
num_attention_heads,
|
650 |
+
dropout=attention_dropout,
|
651 |
+
self_attention=True,
|
652 |
+
has_relative_attention_bias=has_relative_attention_bias,
|
653 |
+
num_buckets=num_buckets,
|
654 |
+
max_distance=max_distance,
|
655 |
+
rescale_init=rescale_init,
|
656 |
+
gru_rel_pos=gru_rel_pos,
|
657 |
+
)
|
658 |
+
|
659 |
+
self.dropout1 = nn.Dropout(dropout)
|
660 |
+
self.dropout2 = nn.Dropout(self.activation_dropout)
|
661 |
+
self.dropout3 = nn.Dropout(dropout)
|
662 |
+
|
663 |
+
self.layer_norm_first = layer_norm_first
|
664 |
+
|
665 |
+
# layer norm associated with the self attention layer
|
666 |
+
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
667 |
+
|
668 |
+
if self.activation_name == "glu":
|
669 |
+
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
670 |
+
else:
|
671 |
+
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
672 |
+
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
673 |
+
|
674 |
+
# layer norm associated with the position wise feed-forward NN
|
675 |
+
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
676 |
+
|
677 |
+
def forward(
|
678 |
+
self,
|
679 |
+
x: torch.Tensor,
|
680 |
+
self_attn_mask: torch.Tensor = None,
|
681 |
+
self_attn_padding_mask: torch.Tensor = None,
|
682 |
+
need_weights: bool = False,
|
683 |
+
pos_bias=None
|
684 |
+
):
|
685 |
+
"""
|
686 |
+
LayerNorm is applied either before or after the self-attention/ffn
|
687 |
+
modules similar to the original Transformer implementation.
|
688 |
+
"""
|
689 |
+
residual = x
|
690 |
+
|
691 |
+
if self.layer_norm_first:
|
692 |
+
x = self.self_attn_layer_norm(x)
|
693 |
+
x, attn, pos_bias = self.self_attn(
|
694 |
+
query=x,
|
695 |
+
key=x,
|
696 |
+
value=x,
|
697 |
+
key_padding_mask=self_attn_padding_mask,
|
698 |
+
need_weights=False,
|
699 |
+
attn_mask=self_attn_mask,
|
700 |
+
position_bias=pos_bias
|
701 |
+
)
|
702 |
+
x = self.dropout1(x)
|
703 |
+
x = residual + x
|
704 |
+
|
705 |
+
residual = x
|
706 |
+
x = self.final_layer_norm(x)
|
707 |
+
if self.activation_name == "glu":
|
708 |
+
x = self.fc1(x)
|
709 |
+
else:
|
710 |
+
x = self.activation_fn(self.fc1(x))
|
711 |
+
x = self.dropout2(x)
|
712 |
+
x = self.fc2(x)
|
713 |
+
x = self.dropout3(x)
|
714 |
+
x = residual + x
|
715 |
+
else:
|
716 |
+
x, attn, pos_bias = self.self_attn(
|
717 |
+
query=x,
|
718 |
+
key=x,
|
719 |
+
value=x,
|
720 |
+
key_padding_mask=self_attn_padding_mask,
|
721 |
+
need_weights=need_weights,
|
722 |
+
attn_mask=self_attn_mask,
|
723 |
+
position_bias=pos_bias
|
724 |
+
)
|
725 |
+
|
726 |
+
x = self.dropout1(x)
|
727 |
+
x = residual + x
|
728 |
+
|
729 |
+
x = self.self_attn_layer_norm(x)
|
730 |
+
|
731 |
+
residual = x
|
732 |
+
if self.activation_name == "glu":
|
733 |
+
x = self.fc1(x)
|
734 |
+
else:
|
735 |
+
x = self.activation_fn(self.fc1(x))
|
736 |
+
x = self.dropout2(x)
|
737 |
+
x = self.fc2(x)
|
738 |
+
x = self.dropout3(x)
|
739 |
+
x = residual + x
|
740 |
+
x = self.final_layer_norm(x)
|
741 |
+
|
742 |
+
return x, attn, pos_bias
|
743 |
+
|
slam_llm/models/wavlm/modules.py
ADDED
@@ -0,0 +1,827 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
4 |
+
# Copyright (c) 2021 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
import warnings
|
12 |
+
from typing import Dict, Optional, Tuple
|
13 |
+
import torch
|
14 |
+
from torch import Tensor, nn
|
15 |
+
from torch.nn import Parameter
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
|
19 |
+
class TransposeLast(nn.Module):
|
20 |
+
def __init__(self, deconstruct_idx=None):
|
21 |
+
super().__init__()
|
22 |
+
self.deconstruct_idx = deconstruct_idx
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
if self.deconstruct_idx is not None:
|
26 |
+
x = x[self.deconstruct_idx]
|
27 |
+
return x.transpose(-2, -1)
|
28 |
+
|
29 |
+
|
30 |
+
class Fp32LayerNorm(nn.LayerNorm):
|
31 |
+
def __init__(self, *args, **kwargs):
|
32 |
+
super().__init__(*args, **kwargs)
|
33 |
+
|
34 |
+
def forward(self, input):
|
35 |
+
output = F.layer_norm(
|
36 |
+
input.float(),
|
37 |
+
self.normalized_shape,
|
38 |
+
self.weight.float() if self.weight is not None else None,
|
39 |
+
self.bias.float() if self.bias is not None else None,
|
40 |
+
self.eps,
|
41 |
+
)
|
42 |
+
return output.type_as(input)
|
43 |
+
|
44 |
+
|
45 |
+
class Fp32GroupNorm(nn.GroupNorm):
|
46 |
+
def __init__(self, *args, **kwargs):
|
47 |
+
super().__init__(*args, **kwargs)
|
48 |
+
|
49 |
+
def forward(self, input):
|
50 |
+
output = F.group_norm(
|
51 |
+
input.float(),
|
52 |
+
self.num_groups,
|
53 |
+
self.weight.float() if self.weight is not None else None,
|
54 |
+
self.bias.float() if self.bias is not None else None,
|
55 |
+
self.eps,
|
56 |
+
)
|
57 |
+
return output.type_as(input)
|
58 |
+
|
59 |
+
|
60 |
+
class GradMultiply(torch.autograd.Function):
|
61 |
+
@staticmethod
|
62 |
+
def forward(ctx, x, scale):
|
63 |
+
ctx.scale = scale
|
64 |
+
res = x.new(x)
|
65 |
+
return res
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def backward(ctx, grad):
|
69 |
+
return grad * ctx.scale, None
|
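# --- Editor's note: illustrative sketch, not part of the original source. It
# shows that GradMultiply is an identity in the forward pass but scales the
# gradient, which is how feature_grad_mult < 1.0 damps the conv extractor.
def _demo_grad_multiply():
    x = torch.ones(3, requires_grad=True)
    y = GradMultiply.apply(x, 0.1)       # forward: y equals x
    y.sum().backward()
    return x.grad                        # tensor([0.1000, 0.1000, 0.1000])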
70 |
+
|
71 |
+
|
72 |
+
class SamePad(nn.Module):
|
73 |
+
def __init__(self, kernel_size, causal=False):
|
74 |
+
super().__init__()
|
75 |
+
if causal:
|
76 |
+
self.remove = kernel_size - 1
|
77 |
+
else:
|
78 |
+
self.remove = 1 if kernel_size % 2 == 0 else 0
|
79 |
+
|
80 |
+
def forward(self, x):
|
81 |
+
if self.remove > 0:
|
82 |
+
x = x[:, :, : -self.remove]
|
83 |
+
return x
|
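# --- Editor's note: illustrative sketch, not part of the original source. With
# an even kernel (e.g. conv_pos=128, padding=64) a 1-D convolution returns one
# extra frame; SamePad trims it so the positional features align with the input.
# The channel count and length below are assumed example values.
def _demo_same_pad():
    conv = nn.Conv1d(8, 8, kernel_size=128, padding=64, groups=8)
    x = torch.randn(1, 8, 50)            # (batch, channels, time)
    y = SamePad(kernel_size=128)(conv(x))
    assert y.shape[-1] == x.shape[-1]    # trimmed back to 50 frames
    return y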
84 |
+
|
85 |
+
|
86 |
+
class Swish(nn.Module):
|
87 |
+
"""Swish function
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self):
|
91 |
+
"""Construct an MultiHeadedAttention object."""
|
92 |
+
super(Swish, self).__init__()
|
93 |
+
self.act = torch.nn.Sigmoid()
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
return x * self.act(x)
|
97 |
+
|
98 |
+
|
99 |
+
class GLU_Linear(nn.Module):
|
100 |
+
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
101 |
+
super(GLU_Linear, self).__init__()
|
102 |
+
|
103 |
+
self.glu_type = glu_type
|
104 |
+
self.output_dim = output_dim
|
105 |
+
|
106 |
+
if glu_type == "sigmoid":
|
107 |
+
self.glu_act = torch.nn.Sigmoid()
|
108 |
+
elif glu_type == "swish":
|
109 |
+
self.glu_act = Swish()
|
110 |
+
elif glu_type == "relu":
|
111 |
+
self.glu_act = torch.nn.ReLU()
|
112 |
+
elif glu_type == "gelu":
|
113 |
+
self.glu_act = torch.nn.GELU()
|
114 |
+
|
115 |
+
if bias_in_glu:
|
116 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
117 |
+
else:
|
118 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
# we assume the input always has the channel (dim) in the last dimension of the tensor, so the dimensions would need to be switched first for the 1D-Conv case
|
122 |
+
x = self.linear(x)
|
123 |
+
|
124 |
+
if self.glu_type == "bilinear":
|
125 |
+
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
126 |
+
else:
|
127 |
+
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
128 |
+
|
129 |
+
return x
|
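# --- Editor's note: illustrative sketch, not part of the original source.
# GLU_Linear projects to 2 * output_dim and gates one half with the activated
# other half, out = a * act(b). The sizes below are assumed example values.
def _demo_glu_linear():
    glu = GLU_Linear(input_dim=768, output_dim=3072, glu_type="swish")
    x = torch.randn(2, 50, 768)          # (batch, time, channels)
    y = glu(x)
    assert y.shape == (2, 50, 3072)
    return y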
130 |
+
|
131 |
+
|
132 |
+
def gelu_accurate(x):
|
133 |
+
if not hasattr(gelu_accurate, "_a"):
|
134 |
+
gelu_accurate._a = math.sqrt(2 / math.pi)
|
135 |
+
return (
|
136 |
+
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
137 |
+
)
|
138 |
+
|
139 |
+
|
140 |
+
def gelu(x: torch.Tensor) -> torch.Tensor:
|
141 |
+
return torch.nn.functional.gelu(x.float()).type_as(x)
|
142 |
+
|
143 |
+
|
144 |
+
def get_activation_fn(activation: str):
|
145 |
+
"""Returns the activation function corresponding to `activation`"""
|
146 |
+
|
147 |
+
if activation == "relu":
|
148 |
+
return F.relu
|
149 |
+
elif activation == "gelu":
|
150 |
+
return gelu
|
151 |
+
elif activation == "gelu_fast":
|
152 |
+
warnings.warn(
|
153 |
+
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
154 |
+
)
|
155 |
+
return gelu_accurate
|
156 |
+
elif activation == "gelu_accurate":
|
157 |
+
return gelu_accurate
|
158 |
+
elif activation == "tanh":
|
159 |
+
return torch.tanh
|
160 |
+
elif activation == "linear":
|
161 |
+
return lambda x: x
|
162 |
+
elif activation == "glu":
|
163 |
+
return lambda x: x
|
164 |
+
else:
|
165 |
+
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
166 |
+
|
167 |
+
|
168 |
+
def init_bert_params(module):
|
169 |
+
"""
|
170 |
+
Initialize the weights specific to the BERT Model.
|
171 |
+
This overrides the default initializations depending on the specified arguments.
|
172 |
+
1. If normal_init_linear_weights is set then weights of linear
|
173 |
+
layer will be initialized using the normal distribution and
|
174 |
+
bias will be set to the specified value.
|
175 |
+
2. If normal_init_embed_weights is set then weights of embedding
|
176 |
+
layer will be initialized using the normal distribution.
|
177 |
+
3. If normal_init_proj_weights is set then weights of
|
178 |
+
in_project_weight for MultiHeadAttention initialized using
|
179 |
+
the normal distribution (to be validated).
|
180 |
+
"""
|
181 |
+
|
182 |
+
def normal_(data):
|
183 |
+
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
184 |
+
# so that the RNG is consistent with and without FSDP
|
185 |
+
data.copy_(
|
186 |
+
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
|
187 |
+
)
|
188 |
+
|
189 |
+
if isinstance(module, nn.Linear):
|
190 |
+
normal_(module.weight.data)
|
191 |
+
if module.bias is not None:
|
192 |
+
module.bias.data.zero_()
|
193 |
+
if isinstance(module, nn.Embedding):
|
194 |
+
normal_(module.weight.data)
|
195 |
+
if module.padding_idx is not None:
|
196 |
+
module.weight.data[module.padding_idx].zero_()
|
197 |
+
if isinstance(module, MultiheadAttention):
|
198 |
+
normal_(module.q_proj.weight.data)
|
199 |
+
normal_(module.k_proj.weight.data)
|
200 |
+
normal_(module.v_proj.weight.data)
|
201 |
+
|
202 |
+
|
203 |
+
def quant_noise(module, p, block_size):
|
204 |
+
"""
|
205 |
+
Wraps modules and applies quantization noise to the weights for
|
206 |
+
subsequent quantization with Iterative Product Quantization as
|
207 |
+
described in "Training with Quantization Noise for Extreme Model Compression"
|
208 |
+
|
209 |
+
Args:
|
210 |
+
- module: nn.Module
|
211 |
+
- p: amount of Quantization Noise
|
212 |
+
- block_size: size of the blocks for subsequent quantization with iPQ
|
213 |
+
|
214 |
+
Remarks:
|
215 |
+
- Module weights must have the right sizes wrt the block size
|
216 |
+
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
217 |
+
- For more detail on how to quantize by blocks with convolutional weights,
|
218 |
+
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
219 |
+
- We implement the simplest form of noise here as stated in the paper
|
220 |
+
which consists in randomly dropping blocks
|
221 |
+
"""
|
222 |
+
|
223 |
+
# if no quantization noise, don't register hook
|
224 |
+
if p <= 0:
|
225 |
+
return module
|
226 |
+
|
227 |
+
# supported modules
|
228 |
+
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
229 |
+
|
230 |
+
# test whether module.weight has the right sizes wrt block_size
|
231 |
+
is_conv = module.weight.ndim == 4
|
232 |
+
|
233 |
+
# 2D matrix
|
234 |
+
if not is_conv:
|
235 |
+
assert (
|
236 |
+
module.weight.size(1) % block_size == 0
|
237 |
+
), "Input features must be a multiple of block sizes"
|
238 |
+
|
239 |
+
# 4D matrix
|
240 |
+
else:
|
241 |
+
# 1x1 convolutions
|
242 |
+
if module.kernel_size == (1, 1):
|
243 |
+
assert (
|
244 |
+
module.in_channels % block_size == 0
|
245 |
+
), "Input channels must be a multiple of block sizes"
|
246 |
+
# regular convolutions
|
247 |
+
else:
|
248 |
+
k = module.kernel_size[0] * module.kernel_size[1]
|
249 |
+
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
250 |
+
|
251 |
+
def _forward_pre_hook(mod, input):
|
252 |
+
# no noise for evaluation
|
253 |
+
if mod.training:
|
254 |
+
if not is_conv:
|
255 |
+
# gather weight and sizes
|
256 |
+
weight = mod.weight
|
257 |
+
in_features = weight.size(1)
|
258 |
+
out_features = weight.size(0)
|
259 |
+
|
260 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
261 |
+
mask = torch.zeros(
|
262 |
+
in_features // block_size * out_features, device=weight.device
|
263 |
+
)
|
264 |
+
mask.bernoulli_(p)
|
265 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
266 |
+
|
267 |
+
else:
|
268 |
+
# gather weight and sizes
|
269 |
+
weight = mod.weight
|
270 |
+
in_channels = mod.in_channels
|
271 |
+
out_channels = mod.out_channels
|
272 |
+
|
273 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
274 |
+
if mod.kernel_size == (1, 1):
|
275 |
+
mask = torch.zeros(
|
276 |
+
int(in_channels // block_size * out_channels),
|
277 |
+
device=weight.device,
|
278 |
+
)
|
279 |
+
mask.bernoulli_(p)
|
280 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
281 |
+
else:
|
282 |
+
mask = torch.zeros(
|
283 |
+
weight.size(0), weight.size(1), device=weight.device
|
284 |
+
)
|
285 |
+
mask.bernoulli_(p)
|
286 |
+
mask = (
|
287 |
+
mask.unsqueeze(2)
|
288 |
+
.unsqueeze(3)
|
289 |
+
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
290 |
+
)
|
291 |
+
|
292 |
+
# scale weights and apply mask
|
293 |
+
mask = mask.to(
|
294 |
+
torch.bool
|
295 |
+
) # x.bool() is not currently supported in TorchScript
|
296 |
+
s = 1 / (1 - p)
|
297 |
+
mod.weight.data = s * weight.masked_fill(mask, 0)
|
298 |
+
|
299 |
+
module.register_forward_pre_hook(_forward_pre_hook)
|
300 |
+
return module
|
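# --- Editor's note: illustrative sketch, not part of the original quant_noise code.
# In training mode the registered hook zeroes random weight blocks and rescales
# the rest by 1 / (1 - p); in eval mode the layer behaves like a plain nn.Linear.
# The p, block_size and layer sizes below are assumed example values.
def _demo_quant_noise():
    layer = quant_noise(nn.Linear(768, 768), p=0.1, block_size=8)
    layer.train()                        # noise is only applied while training
    x = torch.randn(4, 768)
    return layer(x)                      # same shape as a plain Linear: (4, 768)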
301 |
+
|
302 |
+
|
303 |
+
class MultiheadAttention(nn.Module):
|
304 |
+
"""Multi-headed attention.
|
305 |
+
|
306 |
+
See "Attention Is All You Need" for more details.
|
307 |
+
"""
|
308 |
+
|
309 |
+
def __init__(
|
310 |
+
self,
|
311 |
+
embed_dim,
|
312 |
+
num_heads,
|
313 |
+
kdim=None,
|
314 |
+
vdim=None,
|
315 |
+
dropout=0.0,
|
316 |
+
bias=True,
|
317 |
+
add_bias_kv=False,
|
318 |
+
add_zero_attn=False,
|
319 |
+
self_attention=False,
|
320 |
+
encoder_decoder_attention=False,
|
321 |
+
q_noise=0.0,
|
322 |
+
qn_block_size=8,
|
323 |
+
has_relative_attention_bias=False,
|
324 |
+
num_buckets=32,
|
325 |
+
max_distance=128,
|
326 |
+
gru_rel_pos=False,
|
327 |
+
rescale_init=False,
|
328 |
+
):
|
329 |
+
super().__init__()
|
330 |
+
self.embed_dim = embed_dim
|
331 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
332 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
333 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
334 |
+
|
335 |
+
self.num_heads = num_heads
|
336 |
+
self.dropout_module = nn.Dropout(dropout)
|
337 |
+
|
338 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
339 |
+
self.num_buckets = num_buckets
|
340 |
+
self.max_distance = max_distance
|
341 |
+
if self.has_relative_attention_bias:
|
342 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
343 |
+
|
344 |
+
self.head_dim = embed_dim // num_heads
|
345 |
+
self.q_head_dim = self.head_dim
|
346 |
+
self.k_head_dim = self.head_dim
|
347 |
+
assert (
|
348 |
+
self.head_dim * num_heads == self.embed_dim
|
349 |
+
), "embed_dim must be divisible by num_heads"
|
350 |
+
self.scaling = self.head_dim ** -0.5
|
351 |
+
|
352 |
+
self.self_attention = self_attention
|
353 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
354 |
+
|
355 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
356 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
357 |
+
)
|
358 |
+
|
359 |
+
k_bias = True
|
360 |
+
if rescale_init:
|
361 |
+
k_bias = False
|
362 |
+
|
363 |
+
k_embed_dim = embed_dim
|
364 |
+
q_embed_dim = embed_dim
|
365 |
+
|
366 |
+
self.k_proj = quant_noise(
|
367 |
+
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
|
368 |
+
)
|
369 |
+
self.v_proj = quant_noise(
|
370 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
371 |
+
)
|
372 |
+
self.q_proj = quant_noise(
|
373 |
+
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
|
374 |
+
)
|
375 |
+
|
376 |
+
self.out_proj = quant_noise(
|
377 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
378 |
+
)
|
379 |
+
|
380 |
+
if add_bias_kv:
|
381 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
382 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
383 |
+
else:
|
384 |
+
self.bias_k = self.bias_v = None
|
385 |
+
|
386 |
+
self.add_zero_attn = add_zero_attn
|
387 |
+
|
388 |
+
self.gru_rel_pos = gru_rel_pos
|
389 |
+
if self.gru_rel_pos:
|
390 |
+
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
391 |
+
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
392 |
+
|
393 |
+
self.reset_parameters()
|
394 |
+
|
395 |
+
def reset_parameters(self):
|
396 |
+
if self.qkv_same_dim:
|
397 |
+
# Empirically observed the convergence to be much better with
|
398 |
+
# the scaled initialization
|
399 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
400 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
401 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
402 |
+
else:
|
403 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
404 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
405 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
406 |
+
|
407 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
408 |
+
if self.out_proj.bias is not None:
|
409 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
410 |
+
if self.bias_k is not None:
|
411 |
+
nn.init.xavier_normal_(self.bias_k)
|
412 |
+
if self.bias_v is not None:
|
413 |
+
nn.init.xavier_normal_(self.bias_v)
|
414 |
+
if self.has_relative_attention_bias:
|
415 |
+
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
416 |
+
|
417 |
+
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
418 |
+
num_buckets = self.num_buckets
|
419 |
+
max_distance = self.max_distance
|
420 |
+
relative_buckets = 0
|
421 |
+
|
422 |
+
if bidirectional:
|
423 |
+
num_buckets = num_buckets // 2
|
424 |
+
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
425 |
+
relative_positions = torch.abs(relative_positions)
|
426 |
+
else:
|
427 |
+
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
428 |
+
|
429 |
+
max_exact = num_buckets // 2
|
430 |
+
is_small = relative_positions < max_exact
|
431 |
+
|
432 |
+
relative_postion_if_large = max_exact + (
|
433 |
+
torch.log(relative_positions.float() / max_exact)
|
434 |
+
/ math.log(max_distance / max_exact)
|
435 |
+
* (num_buckets - max_exact)
|
436 |
+
).to(torch.long)
|
437 |
+
relative_postion_if_large = torch.min(
|
438 |
+
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
439 |
+
)
|
440 |
+
|
441 |
+
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
442 |
+
return relative_buckets
|
443 |
+
|
444 |
+
def compute_bias(self, query_length, key_length):
|
445 |
+
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
446 |
+
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
447 |
+
relative_position = memory_position - context_position
|
448 |
+
relative_position_bucket = self._relative_positions_bucket(
|
449 |
+
relative_position,
|
450 |
+
bidirectional=True
|
451 |
+
)
|
452 |
+
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
453 |
+
values = self.relative_attention_bias(relative_position_bucket)
|
454 |
+
values = values.permute([2, 0, 1])
|
455 |
+
return values
|
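    # --- Editor's note: illustrative sketch (kept as a comment so the class body
    # is unchanged), not part of the original source. compute_bias maps signed
    # (key - query) offsets to one of num_buckets learned bias values per head,
    # T5-style: exact buckets for small offsets, log-spaced out to max_distance.
    # With assumed sizes:
    #
    #     attn = MultiheadAttention(embed_dim=768, num_heads=12,
    #                               has_relative_attention_bias=True,
    #                               num_buckets=320, max_distance=1280)
    #     bias = attn.compute_bias(query_length=100, key_length=100)
    #     bias.shape  # (12, 100, 100) == (num_heads, tgt_len, src_len)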
456 |
+
|
457 |
+
def forward(
|
458 |
+
self,
|
459 |
+
query,
|
460 |
+
key: Optional[Tensor],
|
461 |
+
value: Optional[Tensor],
|
462 |
+
key_padding_mask: Optional[Tensor] = None,
|
463 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
464 |
+
need_weights: bool = True,
|
465 |
+
static_kv: bool = False,
|
466 |
+
attn_mask: Optional[Tensor] = None,
|
467 |
+
before_softmax: bool = False,
|
468 |
+
need_head_weights: bool = False,
|
469 |
+
position_bias: Optional[Tensor] = None
|
470 |
+
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
471 |
+
"""Input shape: Time x Batch x Channel
|
472 |
+
|
473 |
+
Args:
|
474 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
475 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
476 |
+
padding elements are indicated by 1s.
|
477 |
+
need_weights (bool, optional): return the attention weights,
|
478 |
+
averaged over heads (default: False).
|
479 |
+
attn_mask (ByteTensor, optional): typically used to
|
480 |
+
implement causal attention, where the mask prevents the
|
481 |
+
attention from looking forward in time (default: None).
|
482 |
+
before_softmax (bool, optional): return the raw attention
|
483 |
+
weights and values before the attention softmax.
|
484 |
+
need_head_weights (bool, optional): return the attention
|
485 |
+
weights for each head. Implies *need_weights*. Default:
|
486 |
+
return the average attention weights over all heads.
|
487 |
+
"""
|
488 |
+
if need_head_weights:
|
489 |
+
need_weights = True
|
490 |
+
|
491 |
+
is_tpu = query.device.type == "xla"
|
492 |
+
|
493 |
+
tgt_len, bsz, embed_dim = query.size()
|
494 |
+
src_len = tgt_len
|
495 |
+
assert embed_dim == self.embed_dim
|
496 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
497 |
+
if key is not None:
|
498 |
+
src_len, key_bsz, _ = key.size()
|
499 |
+
if not torch.jit.is_scripting():
|
500 |
+
assert key_bsz == bsz
|
501 |
+
assert value is not None
|
502 |
+
assert (src_len, bsz) == value.shape[:2]  # compare as a tuple; a bare `assert a, b == c` only checks a
|
503 |
+
|
504 |
+
if self.has_relative_attention_bias and position_bias is None:
|
505 |
+
position_bias = self.compute_bias(tgt_len, src_len)
|
506 |
+
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
507 |
+
|
508 |
+
if (
|
509 |
+
not is_tpu # don't use PyTorch version on TPUs
|
510 |
+
and incremental_state is None
|
511 |
+
and not static_kv
|
512 |
+
# A workaround for quantization to work. Otherwise JIT compilation
|
513 |
+
# treats bias in linear module as method.
|
514 |
+
and not torch.jit.is_scripting()
|
515 |
+
and self.q_head_dim == self.head_dim
|
516 |
+
):
|
517 |
+
assert key is not None and value is not None
|
518 |
+
assert attn_mask is None
|
519 |
+
|
520 |
+
attn_mask_rel_pos = None
|
521 |
+
if position_bias is not None:
|
522 |
+
attn_mask_rel_pos = position_bias
|
523 |
+
if self.gru_rel_pos:
|
524 |
+
query_layer = query.transpose(0, 1)
|
525 |
+
new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
|
526 |
+
query_layer = query_layer.view(*new_x_shape)
|
527 |
+
query_layer = query_layer.permute(0, 2, 1, 3)
|
528 |
+
_B, _H, _L, __ = query_layer.size()
|
529 |
+
|
530 |
+
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
531 |
+
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
532 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
533 |
+
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
534 |
+
|
535 |
+
attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
|
536 |
+
k_proj_bias = self.k_proj.bias
|
537 |
+
if k_proj_bias is None:
|
538 |
+
k_proj_bias = torch.zeros_like(self.q_proj.bias)
|
539 |
+
|
540 |
+
x, attn = F.multi_head_attention_forward(
|
541 |
+
query,
|
542 |
+
key,
|
543 |
+
value,
|
544 |
+
self.embed_dim,
|
545 |
+
self.num_heads,
|
546 |
+
torch.empty([0]),
|
547 |
+
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
548 |
+
self.bias_k,
|
549 |
+
self.bias_v,
|
550 |
+
self.add_zero_attn,
|
551 |
+
self.dropout_module.p,
|
552 |
+
self.out_proj.weight,
|
553 |
+
self.out_proj.bias,
|
554 |
+
self.training,
|
555 |
+
# self.training or self.dropout_module.apply_during_inference,
|
556 |
+
key_padding_mask,
|
557 |
+
need_weights,
|
558 |
+
attn_mask_rel_pos,
|
559 |
+
use_separate_proj_weight=True,
|
560 |
+
q_proj_weight=self.q_proj.weight,
|
561 |
+
k_proj_weight=self.k_proj.weight,
|
562 |
+
v_proj_weight=self.v_proj.weight,
|
563 |
+
)
|
564 |
+
return x, attn, position_bias
|
565 |
+
|
566 |
+
if incremental_state is not None:
|
567 |
+
saved_state = self._get_input_buffer(incremental_state)
|
568 |
+
if saved_state is not None and "prev_key" in saved_state:
|
569 |
+
# previous time steps are cached - no need to recompute
|
570 |
+
# key and value if they are static
|
571 |
+
if static_kv:
|
572 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
573 |
+
key = value = None
|
574 |
+
else:
|
575 |
+
saved_state = None
|
576 |
+
|
577 |
+
if self.self_attention:
|
578 |
+
q = self.q_proj(query)
|
579 |
+
k = self.k_proj(query)
|
580 |
+
v = self.v_proj(query)
|
581 |
+
elif self.encoder_decoder_attention:
|
582 |
+
# encoder-decoder attention
|
583 |
+
q = self.q_proj(query)
|
584 |
+
if key is None:
|
585 |
+
assert value is None
|
586 |
+
k = v = None
|
587 |
+
else:
|
588 |
+
k = self.k_proj(key)
|
589 |
+
v = self.v_proj(key)
|
590 |
+
|
591 |
+
else:
|
592 |
+
assert key is not None and value is not None
|
593 |
+
q = self.q_proj(query)
|
594 |
+
k = self.k_proj(key)
|
595 |
+
v = self.v_proj(value)
|
596 |
+
q *= self.scaling
|
597 |
+
|
598 |
+
if self.bias_k is not None:
|
599 |
+
assert self.bias_v is not None
|
600 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
601 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
602 |
+
if attn_mask is not None:
|
603 |
+
attn_mask = torch.cat(
|
604 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
605 |
+
)
|
606 |
+
if key_padding_mask is not None:
|
607 |
+
key_padding_mask = torch.cat(
|
608 |
+
[
|
609 |
+
key_padding_mask,
|
610 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
611 |
+
],
|
612 |
+
dim=1,
|
613 |
+
)
|
614 |
+
|
615 |
+
q = (
|
616 |
+
q.contiguous()
|
617 |
+
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
618 |
+
.transpose(0, 1)
|
619 |
+
)
|
620 |
+
if k is not None:
|
621 |
+
k = (
|
622 |
+
k.contiguous()
|
623 |
+
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
624 |
+
.transpose(0, 1)
|
625 |
+
)
|
626 |
+
if v is not None:
|
627 |
+
v = (
|
628 |
+
v.contiguous()
|
629 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
630 |
+
.transpose(0, 1)
|
631 |
+
)
|
632 |
+
|
633 |
+
if saved_state is not None:
|
634 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
635 |
+
if "prev_key" in saved_state:
|
636 |
+
_prev_key = saved_state["prev_key"]
|
637 |
+
assert _prev_key is not None
|
638 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
639 |
+
if static_kv:
|
640 |
+
k = prev_key
|
641 |
+
else:
|
642 |
+
assert k is not None
|
643 |
+
k = torch.cat([prev_key, k], dim=1)
|
644 |
+
src_len = k.size(1)
|
645 |
+
if "prev_value" in saved_state:
|
646 |
+
_prev_value = saved_state["prev_value"]
|
647 |
+
assert _prev_value is not None
|
648 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
649 |
+
if static_kv:
|
650 |
+
v = prev_value
|
651 |
+
else:
|
652 |
+
assert v is not None
|
653 |
+
v = torch.cat([prev_value, v], dim=1)
|
654 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
655 |
+
if "prev_key_padding_mask" in saved_state:
|
656 |
+
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]

            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v, position_bias

        if position_bias is not None:
            if self.gru_rel_pos == 1:
                query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
                _B, _H, _L, __ = query_layer.size()
                gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
                    _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias

            position_bias = position_bias.view(attn_weights.size())

            attn_weights = attn_weights + position_bias

        attn_weights_float = F.softmax(
            attn_weights, dim=-1
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights, position_bias

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(), filler.float()], dim=1
                )
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1
                )
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights
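The masking step above assigns -inf to padded key positions before the softmax so that they receive zero attention probability. A minimal standalone sketch of that behaviour, assuming illustrative shapes and tensors that are not taken from this file:

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len = 2, 4, 3, 5
attn_weights = torch.randn(bsz * num_heads, tgt_len, src_len)
# True marks padded key positions that must not be attended to.
key_padding_mask = torch.tensor([[False, False, False, True, True],
                                 [False, False, False, False, True]])

attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len).masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
).view(bsz * num_heads, tgt_len, src_len)
attn_probs = F.softmax(attn_weights, dim=-1)
# Padded positions now carry exactly zero probability mass.
assert attn_probs.view(bsz, num_heads, tgt_len, src_len)[0, :, :, 3:].sum() == 0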
slam_llm/policies/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from slam_llm.policies.mixed_precision import *
from slam_llm.policies.wrapping import *
from slam_llm.policies.activation_checkpointing_functions import apply_fsdp_checkpointing
from slam_llm.policies.anyprecision_optimizer import AnyPrecisionAdamW
slam_llm/policies/activation_checkpointing_functions.py
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)


def apply_fsdp_checkpointing(model):
    """apply activation checkpointing to model
    returns None as model is updated directly
    """
    print(f"--> applying fsdp activation checkpointing...")

    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )
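A minimal usage sketch of apply_fsdp_checkpointing; the checkpoint path below is a placeholder assumption, and the call only affects submodules that are LlamaDecoderLayer instances:

from transformers import LlamaForCausalLM
from slam_llm.policies import apply_fsdp_checkpointing

model = LlamaForCausalLM.from_pretrained("path/to/llama-checkpoint")  # placeholder path
apply_fsdp_checkpointing(model)  # wraps each LlamaDecoderLayer with non-reentrant checkpointing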
slam_llm/policies/anyprecision_optimizer.py
ADDED
@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# AnyPrecisionAdamW: a flexible precision AdamW optimizer
# with optional Kahan summation for high precision weight updates.
# Allows direct control over momentum, variance and auxiliary compensation
# buffer dtypes.
# Optional Kahan summation is used to offset precision reduction for
# the weight updates. This allows full training in BFloat16 (equal or
# better than FP32 results in many cases) due to high precision weight updates.

import torch
from torch.optim.optimizer import Optimizer


class AnyPrecisionAdamW(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        use_kahan_summation=False,
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        compensation_buffer_dtype=torch.bfloat16,
    ):
        """
        Args:
            params (iterable): iterable of parameters to optimize or dicts defining
                parameter groups
            lr (float, optional): learning rate (default: 1e-3)
            betas (Tuple[float, float], optional): coefficients used for computing
                running averages of gradient and its square (default: (0.9, 0.999))
            eps (float, optional): term added to the denominator to improve
                numerical stability (default: 1e-8)
            weight_decay (float, optional): weight decay coefficient (default: 0.0)

            # Any Precision specific
            use_kahan_summation = creates auxiliary buffer to ensure high precision
                model param updates (default: False)
            momentum_dtype = dtype for momentum (default: BFloat16)
            variance_dtype = dtype for uncentered variance (default: BFloat16)
            compensation_buffer_dtype = dtype for Kahan summation
                buffer (default: BFloat16)

        # Usage
            This optimizer keeps its optimizer states, and optional Kahan summation
            for high precision updates, all in user-controlled dtypes.
            Defaults are variance in BF16 and momentum in BF16.
            It can be run in FSDP mixed precision, amp, or full precision,
            depending on which training pipeline you wish to work with.

            Setting use_kahan_summation = False and changing the momentum and
            variance dtypes to FP32 reverts this to a standard AdamW optimizer.

        """
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            use_kahan_summation=use_kahan_summation,
            momentum_dtype=momentum_dtype,
            variance_dtype=variance_dtype,
            compensation_buffer_dtype=compensation_buffer_dtype,
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """

        if closure is not None:
            with torch.enable_grad():
                # to fix linter, we do not keep the returned loss for use atm.
                closure()

        for group in self.param_groups:

            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            use_kahan_summation = group["use_kahan_summation"]

            momentum_dtype = group["momentum_dtype"]
            variance_dtype = group["variance_dtype"]
            compensation_buffer_dtype = group["compensation_buffer_dtype"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError(
                        "AnyPrecisionAdamW does not support sparse gradients"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:

                    state["step"] = torch.tensor(0.0)

                    # momentum - EMA of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p,
                        dtype=momentum_dtype,
                    )

                    # variance uncentered - EMA of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p,
                        dtype=variance_dtype,
                    )

                    # optional Kahan summation - accumulated error tracker
                    if use_kahan_summation:
                        state["compensation"] = torch.zeros_like(
                            p,
                            dtype=compensation_buffer_dtype,
                        )

                # main processing -------------------------

                # update the steps for each param group update
                state["step"] += 1
                step = state["step"]

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                grad = p.grad

                # weight decay, AdamW style
                if weight_decay:
                    p.data.mul_(1 - lr * weight_decay)

                # update momentum
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # update uncentered variance
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # adjust using bias1
                bias_correction1 = 1 - beta1**step

                step_size = lr / bias_correction1

                # adjust using bias2
                denom_correction = (1 - beta2**step) ** 0.5  # avoids math import

                centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(
                    eps, alpha=1
                )

                # lr update to compensation
                if use_kahan_summation:
                    compensation = state["compensation"]

                    compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)

                    # update weights with compensation (Kahan summation)
                    # save error back to compensation for next iteration
                    temp_buffer = p.detach().clone()
                    p.data.add_(compensation)
                    compensation.add_(temp_buffer.sub_(p.data))

                else:
                    # usual AdamW updates
                    p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)
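A minimal sketch of driving AnyPrecisionAdamW on a toy module; the model, loss, and hyperparameters below are illustrative assumptions, not taken from this repository's configs:

import torch
from slam_llm.policies import AnyPrecisionAdamW

model = torch.nn.Linear(16, 4)
optimizer = AnyPrecisionAdamW(
    model.parameters(),
    lr=1e-4,
    use_kahan_summation=True,       # keep a BF16 compensation buffer for the update error
    momentum_dtype=torch.bfloat16,
    variance_dtype=torch.bfloat16,
)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
optimizer.step()       # BF16 optimizer states, Kahan-compensated weight update
optimizer.zero_grad()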
slam_llm/policies/mixed_precision.py
ADDED
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch

from torch.distributed.fsdp import (
    MixedPrecision,
)

# requires grad scaler in main loop
fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
    cast_forward_inputs=True,
)

bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
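A minimal sketch of handing one of these policies to FSDP; the model and the torch.distributed process-group initialization are assumed to have been done elsewhere:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from slam_llm.policies.mixed_precision import bfSixteen

sharded_model = FSDP(model, mixed_precision=bfSixteen)  # params, grads and buffers handled in BF16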
slam_llm/policies/wrapping.py
ADDED
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import functools

from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
)


def get_size_policy(min_params=1e8):
    num_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=min_params
    )
    return num_wrap_policy


def get_llama_wrapper():
    """we register our main layer class and use the fsdp transformer wrapping policy
    ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers
    """
    # ==== use new transformer wrapper

    llama_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )

    return llama_auto_wrap_policy
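A minimal sketch combining this wrapping policy with FSDP; as above, model construction and distributed initialization are assumed to happen elsewhere:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from slam_llm.policies.wrapping import get_llama_wrapper

sharded_model = FSDP(model, auto_wrap_policy=get_llama_wrapper())  # one FSDP unit per LlamaDecoderLayer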
slam_llm/utils/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from slam_llm.utils.memory_utils import MemoryTrace
from slam_llm.utils.dataset_utils import *
from slam_llm.utils.fsdp_utils import fsdp_auto_wrap_policy
from slam_llm.utils.train_utils import *