xcczach committed
Commit 35c1cfd · verified · 1 Parent(s): 7c82bbb

Upload 73 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. model/slam_model_s2s.py +444 -0
  2. s2s.py +178 -0
  3. s2s_config.py +272 -0
  4. slam_llm/__init__.py +0 -0
  5. slam_llm/data/__init__.py +2 -0
  6. slam_llm/data/concatenator.py +34 -0
  7. slam_llm/data/sampler.py +57 -0
  8. slam_llm/models/BEATs/BEATs.py +181 -0
  9. slam_llm/models/BEATs/Tokenizers.py +173 -0
  10. slam_llm/models/BEATs/backbone.py +783 -0
  11. slam_llm/models/BEATs/modules.py +219 -0
  12. slam_llm/models/BEATs/quantizer.py +215 -0
  13. slam_llm/models/EAT/EAT.py +32 -0
  14. slam_llm/models/SpatialAST/SpatialAST.py +122 -0
  15. slam_llm/models/SpatialAST/vision_transformer.py +239 -0
  16. slam_llm/models/avhubert/__init__.py +10 -0
  17. slam_llm/models/avhubert/decoder.py +243 -0
  18. slam_llm/models/avhubert/hubert.py +792 -0
  19. slam_llm/models/avhubert/hubert_asr.py +523 -0
  20. slam_llm/models/avhubert/hubert_criterion.py +169 -0
  21. slam_llm/models/avhubert/hubert_dataset.py +529 -0
  22. slam_llm/models/avhubert/hubert_pretraining.py +401 -0
  23. slam_llm/models/avhubert/infer_s2s.py +318 -0
  24. slam_llm/models/avhubert/resnet.py +169 -0
  25. slam_llm/models/avhubert/sequence_generator.py +985 -0
  26. slam_llm/models/avhubert/utils.py +298 -0
  27. slam_llm/models/encoder.py +158 -0
  28. slam_llm/models/musicfm/model/__init__.py +2 -0
  29. slam_llm/models/musicfm/model/musicfm_25hz.py +253 -0
  30. slam_llm/models/musicfm/modules/__init__.py +2 -0
  31. slam_llm/models/musicfm/modules/conv.py +82 -0
  32. slam_llm/models/musicfm/modules/features.py +45 -0
  33. slam_llm/models/musicfm/modules/flash_conformer.py +2114 -0
  34. slam_llm/models/musicfm/modules/random_quantizer.py +83 -0
  35. slam_llm/models/projector.py +81 -0
  36. slam_llm/models/slam_model.py +443 -0
  37. slam_llm/models/vallex/__init__.py +0 -0
  38. slam_llm/models/vallex/activation.py +179 -0
  39. slam_llm/models/vallex/scaling.py +1404 -0
  40. slam_llm/models/vallex/transformers.py +613 -0
  41. slam_llm/models/vallex/vallex_config.py +56 -0
  42. slam_llm/models/vallex/vallex_model.py +772 -0
  43. slam_llm/models/wavlm/WavLM.py +743 -0
  44. slam_llm/models/wavlm/modules.py +827 -0
  45. slam_llm/policies/__init__.py +7 -0
  46. slam_llm/policies/activation_checkpointing_functions.py +29 -0
  47. slam_llm/policies/anyprecision_optimizer.py +179 -0
  48. slam_llm/policies/mixed_precision.py +38 -0
  49. slam_llm/policies/wrapping.py +33 -0
  50. slam_llm/utils/__init__.py +7 -0
model/slam_model_s2s.py ADDED
@@ -0,0 +1,444 @@
1
+ import torch
2
+ import os
3
+ import logging
4
+ import torch.nn.functional as F
5
+ from slam_llm.models.slam_model import (
6
+ slam_model,
7
+ setup_tokenizer,
8
+ setup_encoder,
9
+ setup_encoder_projector,
10
+ setup_llm,
11
+ )
12
+ from slam_llm.utils.train_utils import print_model_size
13
+ from typing import List, Optional
14
+ from slam_llm.utils.metric import compute_accuracy
15
+ from transformers import T5ForConditionalGeneration
16
+ from tqdm import tqdm
17
+ from utils.tts_adapter_utils import setup_tts_adapter
18
+ from utils.codec_utils import setup_codec
19
+ from utils.trick_utils import partial_freeze_weights, train_embedding_layer_only
20
+ from utils.snac_utils import layershift
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def model_factory(train_config, model_config, ckpt_path, **kwargs):
26
+ # return necessary components for training
27
+ tokenizer = setup_tokenizer(train_config, model_config, **kwargs)
28
+
29
+ if train_config.task_type == "s2s" or train_config.task_type == "asr":
30
+ encoder = setup_encoder(train_config, model_config, **kwargs)
31
+ elif train_config.task_type == "tts":
32
+ encoder = None
33
+ else:
34
+ raise NotImplementedError
35
+
36
+ # llm
37
+ llm = setup_llm(train_config, model_config, **kwargs)
38
+
39
+ # projector
40
+ if encoder is not None:
41
+ encoder_projector = setup_encoder_projector(
42
+ train_config, model_config, **kwargs
43
+ )
44
+ else:
45
+ encoder_projector = None
46
+
47
+ codec_decoder = None
48
+ if model_config.codec_decode:
49
+ codec_decoder = setup_codec(train_config, model_config, **kwargs)
50
+
51
+ tts_adapter = None
52
+ if model_config.tts_adapter:
53
+ adapter_config = model_config.tts_adapter_config
54
+ tts_adapter = setup_tts_adapter(adapter_config, model_config, **kwargs)
55
+
56
+ model = slam_model_s2s(
57
+ encoder,
58
+ llm,
59
+ encoder_projector,
60
+ tokenizer,
61
+ tts_adapter,
62
+ codec_decoder,
63
+ train_config,
64
+ model_config,
65
+ **kwargs,
66
+ )
67
+
68
+ if ckpt_path is not None:
69
+ logger.info("loading other parts from: {}".format(ckpt_path))
70
+ ckpt_dict = torch.load(ckpt_path, map_location="cpu")
71
+ model.load_state_dict(ckpt_dict, strict=False)
72
+
73
+ if train_config.train_audio_embed_only:
74
+ partial_freeze_weights(model, model_config.vocab_config.padded_text_vocabsize, model_config.vocab_config.total_vocabsize)
75
+
76
+ if train_config.train_embed_only:
77
+ train_embedding_layer_only(model)
78
+
79
+ print_model_size(
80
+ model,
81
+ train_config,
82
+ (
83
+ int(os.environ["RANK"])
84
+ if train_config.enable_fsdp or train_config.enable_ddp
85
+ else 0
86
+ ),
87
+ )
88
+ return model, tokenizer
89
+
90
+
91
+ class slam_model_s2s(slam_model):
92
+ def __init__(
93
+ self,
94
+ encoder,
95
+ llm,
96
+ encoder_projector,
97
+ tokenizer,
98
+ tts_adapter,
99
+ codec_decoder,
100
+ train_config,
101
+ model_config,
102
+ **kwargs,
103
+ ):
104
+ super().__init__(
105
+ encoder,
106
+ llm,
107
+ encoder_projector,
108
+ tokenizer,
109
+ train_config,
110
+ model_config,
111
+ **kwargs,
112
+ )
113
+
114
+ # resize llm embedding layer
115
+ self.original_vocabsize = self.llm.lm_head.weight.size(0)
116
+ if self.model_config.vocab_config.total_vocabsize != self.original_vocabsize:
117
+ self.llm.resize_token_embeddings(self.model_config.vocab_config.total_vocabsize)
118
+
119
+ if int(os.environ.get("RANK", "0")) == 0:
120
+ logger.info("Resize llm embedding layer's vocab size to {}".format(self.model_config.vocab_config.total_vocabsize))
121
+
122
+ self.codec_decoder = codec_decoder
123
+ self.tts_adapter = tts_adapter
124
+ self.code_layer = self.model_config.vocab_config.code_layer
125
+
126
+
127
+ def forward(self,
128
+ input_ids: torch.LongTensor = None,
129
+ attention_mask: Optional[torch.Tensor] = None,
130
+ position_ids: Optional[torch.LongTensor] = None,
131
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
132
+ inputs_embeds: Optional[torch.FloatTensor] = None,
133
+ labels: Optional[torch.LongTensor] = None,
134
+ use_cache: Optional[bool] = None,
135
+ output_attentions: Optional[bool] = None,
136
+ output_hidden_states: Optional[bool] = None,
137
+ return_dict: Optional[bool] = None,
138
+ **kwargs,
139
+ ):
140
+ audio_mel = kwargs.get("audio_mel", None)
141
+ audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper
142
+
143
+ audio = kwargs.get("audio", None)
144
+ audio_mask = kwargs.get("audio_mask", None)
145
+
146
+ modality_mask = kwargs.get("modality_mask", None)
147
+
148
+ encoder_outs = None
149
+ if audio_mel is not None or audio is not None:
150
+ if self.train_config.freeze_encoder: # freeze encoder
151
+ self.encoder.eval()
152
+
153
+ if self.model_config.encoder_name == "whisper":
154
+ encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim
155
+ if self.model_config.encoder_name == "wavlm":
156
+ encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask
157
+ if self.model_config.encoder_name == "hubert":
158
+ results = self.encoder(source = audio, padding_mask = 1-audio_mask)
159
+ if self.model_config.encoder_type == "pretrain":
160
+ encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"]
161
+ if self.model_config.encoder_type == "finetune":
162
+ encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
163
+ encoder_outs = encoder_outs.transpose(0, 1)
164
+ if self.encoder is None:
165
+ encoder_outs = audio_mel if audio_mel is not None else audio
166
+
167
+ if self.model_config.encoder_projector == "q-former":
168
+ encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
169
+ if self.model_config.encoder_projector == "linear":
170
+ encoder_outs = self.encoder_projector(encoder_outs)
171
+ if self.model_config.encoder_projector == "cov1d-linear":
172
+ encoder_outs = self.encoder_projector(encoder_outs)
173
+
174
+ if input_ids is not None:
175
+ input_ids[input_ids == -1] = 0 # [btz, 8, seq_length]
176
+
177
+ if isinstance(self.llm, T5ForConditionalGeneration):
178
+ inputs_embeds = self.llm.shared(input_ids)
179
+ else:
180
+ if hasattr(self.llm.model, "embed_tokens"):
181
+ inputs_embeds = self.llm.model.embed_tokens(input_ids) # [btz, 8, seq_length, emb_dim]
182
+ elif hasattr(self.llm.model.model, "embed_tokens"):
183
+ inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
184
+ else:
185
+ inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
186
+
187
+ if modality_mask is not None and encoder_outs is not None:
188
+ modality_mask = modality_mask.unsqueeze(1).repeat(1, self.code_layer, 1) # [btz, 8, seq_length]
189
+ modality_mask_start_indices = (modality_mask == True).float().argmax(dim=2)
190
+ modality_lengths = torch.clamp(modality_mask.sum(dim=2), max=encoder_outs.shape[1]).tolist()
191
+
192
+ encoder_outs_pad = torch.zeros_like(inputs_embeds)
193
+ for i in range(encoder_outs.shape[0]):
194
+ for j in range(self.code_layer):
195
+ start_idx = modality_mask_start_indices[i, j].item()
196
+ length = modality_lengths[i][j]
197
+ encoder_outs_pad[i, j, start_idx:start_idx+length] = encoder_outs[i, :length]
198
+
199
+ inputs_embeds[:, :self.code_layer, :, :] = encoder_outs_pad[:, :self.code_layer, :, :] + inputs_embeds[:, :self.code_layer, :, :] * (~modality_mask[:, :, :, None])
200
+
201
+ inputs_embeds = torch.mean(inputs_embeds, dim=1) # [btz, seq_length, emb_dim], average over the 8 layers
202
+
203
+ if kwargs.get("inference_mode", False):
204
+ return inputs_embeds, attention_mask
205
+
206
+ text_labels = labels[:,self.code_layer] if labels is not None else None
207
+ audio_labels = labels[:, :self.code_layer] if labels is not None else None
208
+ model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=text_labels) # here we use the text token layer as the target label
209
+
210
+ # parallel generation
211
+ # TODO: add tts adapter forward
212
+ x_ori = model_outputs.logits
213
+ text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
214
+ audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
215
+ xt = x_ori[..., :text_vocab_size]
216
+ xa = []
217
+ for i in range(self.code_layer):
218
+ xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])
219
+
220
+ loss_recorder = []
221
+ total_loss, loss_recorder = self.compute_parallel_loss(xt, text_labels, xa, audio_labels)
222
+ model_outputs.loss = total_loss
223
+
224
+ text_acc = -1
225
+ audio_acc = [-1 for _ in range(self.code_layer)]
226
+ if self.metric:
227
+ with torch.no_grad():
228
+ preds = torch.argmax(xt, -1)
229
+ text_acc = compute_accuracy(preds.detach()[:, :-1], text_labels.detach()[:, 1:], ignore_label=-100)
230
+
231
+ preds_audio = [torch.argmax(xa[i], -1) for i in range(self.code_layer)]
232
+ audio_acc = [compute_accuracy(preds_audio[i].detach()[:, :-1], audio_labels[:, i, 1:], ignore_label=-100) for i in range(self.code_layer)]
233
+
234
+ # metrics = {"text_acc": text_acc, "audio_acc": audio_acc, "layer_loss": loss_recorder}
235
+ return model_outputs, text_acc, audio_acc, loss_recorder
236
+
237
+
238
+
239
+ def compute_parallel_loss(self, xt, text_labels, xa, audio_labels):
240
+ """
241
+ Compute the parallel loss for text and audio layers.
242
+ """
243
+ text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
244
+ audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
245
+ layer_loss = [0 for _ in range(self.code_layer+1) ]
246
+
247
+ if text_labels is not None:
248
+ # text_loss = F.cross_entropy(xt.reshape(-1, text_vocab_size), text_labels.reshape(-1), ignore_index=-100)
249
+ text_loss = F.cross_entropy(xt[:, :-1, :].reshape(-1, text_vocab_size), text_labels[:, 1:].reshape(-1), ignore_index=-100)
250
+ layer_loss[self.code_layer] = text_loss
251
+ else:
252
+ text_loss = 0
253
+
254
+ total_audio_loss = 0
255
+ single_audio_loss = 0
256
+ for i in range(self.code_layer):
257
+ if audio_labels[:,i] is not None:
258
+ # audio_loss += F.cross_entropy(xa[i].reshape(-1, audio_vocab_size), audio_labels[:,i].reshape(-1), ignore_index=-100)
259
+ single_audio_loss = F.cross_entropy(xa[i][:, :-1, :].reshape(-1, audio_vocab_size), audio_labels[:, i, 1:].reshape(-1), ignore_index=-100)
260
+ layer_loss[i] = single_audio_loss
261
+ total_audio_loss += single_audio_loss
262
+
263
+ total_loss = (text_loss + total_audio_loss) / (self.code_layer+1)
264
+ return total_loss, layer_loss
265
+
266
+
267
+ @torch.no_grad()
268
+ def generate(self,
269
+ input_ids: torch.LongTensor = None,
270
+ attention_mask: Optional[torch.Tensor] = None,
271
+ position_ids: Optional[torch.LongTensor] = None,
272
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
273
+ inputs_embeds: Optional[torch.FloatTensor] = None,
274
+ labels: Optional[torch.LongTensor] = None,
275
+ use_cache: Optional[bool] = None,
276
+ output_attentions: Optional[bool] = None,
277
+ output_hidden_states: Optional[bool] = None,
278
+ return_dict: Optional[bool] = None,
279
+ **kwargs,
280
+ ):
281
+ kwargs["inference_mode"] = True
282
+
283
+ inputs_embeds, attention_mask = self.forward(
284
+ input_ids=input_ids,
285
+ attention_mask=attention_mask,
286
+ position_ids=position_ids,
287
+ past_key_values=past_key_values,
288
+ inputs_embeds=inputs_embeds,
289
+ labels=labels,
290
+ use_cache=use_cache,
291
+ output_attentions=output_attentions,
292
+ output_hidden_states=output_hidden_states,
293
+ return_dict=return_dict,
294
+ **kwargs,
295
+ )
296
+
297
+ generated_ids = [[] for _ in range((self.code_layer+1))]
298
+ current_input_text = None
299
+ current_audio_tokens = [None for _ in range(self.code_layer)]
300
+ # input_pos = torch.arange(input_ids.size(-1), device=input_ids.device).unsqueeze(0)
301
+ past_key_values = None
302
+
303
+ text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
304
+ audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
305
+
306
+ max_new_tokens = kwargs.get("max_new_tokens", 360)
307
+ repetition_penalty = kwargs.get("repetition_penalty", 1.0)
308
+ decode_text_only = kwargs.get("decode_text_only", False)
309
+
310
+ pad_t = self.model_config.vocab_config.pad_t
311
+ pad_a = self.model_config.vocab_config.pad_a
312
+ eot = self.model_config.vocab_config.eot
313
+ eoa = self.model_config.vocab_config.eoa
314
+
315
+ text_end = False # Track whether text generation has ended
316
+ audio_end = False # Track whether audio generation has ended
317
+
318
+ # NOTE: currently, we only support greedy decoding and sampling for parallel generation, no beam search
319
+ for step in tqdm(range(max_new_tokens), desc="Generating"):
320
+ if current_input_text is not None:
321
+ audio_tokens = torch.cat([layershift(current_audio_tokens[i], i).unsqueeze(1) for i in range(self.code_layer)], dim=1)
322
+ combined_input_ids = torch.cat([audio_tokens, current_input_text.unsqueeze(1)], dim=1)
323
+ inputs_embeds = self.llm.model.embed_tokens(combined_input_ids)
324
+ inputs_embeds = torch.mean(inputs_embeds, dim=1).unsqueeze(1)
325
+
326
+ outputs = self.llm(
327
+ inputs_embeds=inputs_embeds, # [btz, seq_len / 1, emb_dim]
328
+ attention_mask=attention_mask, # single sample, no need for attention mask
329
+ past_key_values=past_key_values,
330
+ # position_ids=input_pos,
331
+ use_cache=True,
332
+ )
333
+
334
+ logits = outputs.logits
335
+ past_key_values = outputs.past_key_values # Update past_key_values for the next step
336
+
337
+ # Split logits into text and audio layers based on vocab size
338
+ xt_logits = logits[..., :text_vocab_size]
339
+ xa_logits = [logits[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)] for i in range(self.code_layer)]
340
+
341
+ # Apply repetition penalty to the logits
342
+ if repetition_penalty != 1.0:
343
+ xt_logits = self.repetition_penalty(xt_logits, generated_ids[self.code_layer], repetition_penalty)
344
+ for i in range(self.code_layer):
345
+ xa_logits[i] = self.repetition_penalty(xa_logits[i], generated_ids[i], repetition_penalty)
346
+
347
+ if not text_end:
348
+ next_token_text = self.sample_next_token(xt_logits[:, -1, :], **kwargs)
349
+ else:
350
+ next_token_text = torch.tensor([pad_t], device=input_ids.device)
351
+
352
+ next_tokens_audio = []
353
+ for i in range(self.code_layer):
354
+ if not audio_end and not decode_text_only:
355
+ next_token_audio = self.sample_next_token(xa_logits[i][:, -1, :], **kwargs)
356
+ else:
357
+ next_token_audio = torch.full((input_ids.size(0),), pad_a, device=input_ids.device)
358
+ next_tokens_audio.append(next_token_audio)
359
+
360
+ if next_tokens_audio[-1] == eoa or decode_text_only:
361
+ audio_end = True
362
+ if next_token_text == eot:
363
+ text_end = True
364
+
365
+ # Update input_ids for the next step
366
+ current_input_text = next_token_text
367
+ for i in range(self.code_layer):
368
+ current_audio_tokens[i] = next_tokens_audio[i]
369
+
370
+ # if input_pos.size(-1) > 1:
371
+ # input_pos = torch.tensor(input_pos.size(-1), device=input_ids.device).unsqueeze(0)
372
+ # else:
373
+ # input_pos = input_pos.add_(1)
374
+ attention_mask = torch.cat([attention_mask, torch.ones((input_ids.size(0), 1), device=input_ids.device)], dim=1)
375
+
376
+ if audio_end and text_end:
377
+ break
378
+
379
+ # Append generated tokens to the list
380
+ for i in range(self.code_layer):
381
+ generated_ids[i].append(next_tokens_audio[i].clone().tolist()[0]) # Audio layers
382
+ generated_ids[self.code_layer].append(next_token_text.clone().tolist()[0]) # Text layer
383
+
384
+ # Concatenate the generated tokens to form the complete sequence
385
+ text_tokens = generated_ids[-1]
386
+ generated_ids[-1] = text_tokens[: text_tokens.index(eot)] if eot in text_tokens else text_tokens
387
+ generated_ids = [torch.tensor(layer) for layer in generated_ids]
388
+ return generated_ids
389
+
390
+
391
+ @torch.no_grad()
392
+ def sample_next_token(self, logits, **kwargs):
393
+ """
394
+ Generate the next token based on the model output logits.
395
+ Supports greedy decoding, top-k sampling, and top-p (nucleus) sampling.
396
+ """
397
+ do_sample = kwargs.get("do_sample", False)
398
+ temperature = kwargs.get("temperature", 1.0)
399
+ top_k = kwargs.get("top_k", 50)
400
+ top_p = kwargs.get("top_p", 1.0)
401
+ num_samples = kwargs.get("num_samples", 1)
402
+
403
+ # Adjust logits with temperature
404
+ logits = logits.squeeze(0)
405
+ logits = logits / temperature
406
+
407
+ # Top-k filtering
408
+ if top_k > 0:
409
+ top_k = min(top_k, logits.size(-1)) # Make sure top_k is within the vocab size
410
+ values, indices = torch.topk(logits, top_k)
411
+ logits[logits < values[..., [-1]]] = -float('Inf') # Filter tokens not in top_k
412
+
413
+ # Top-p filtering (nucleus sampling)
414
+ if top_p < 1.0:
415
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
416
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
417
+
418
+ # Remove tokens with cumulative probability above the threshold
419
+ sorted_indices_to_remove = cumulative_probs > top_p
420
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
421
+ sorted_indices_to_remove[..., 0] = 0
422
+
423
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
424
+ logits[indices_to_remove] = -float('Inf')
425
+
426
+ if do_sample:
427
+ # Perform sampling
428
+ return torch.multinomial(F.softmax(logits, dim=-1), num_samples=num_samples)
429
+ else:
430
+ # Greedy decoding (argmax)
431
+ return torch.argmax(logits, dim=-1, keepdim=True)
432
+
433
+
434
+ def repetition_penalty(self, logits, generated_ids, repetition_penalty):
435
+ """
436
+ Apply repetition penalty to the logits.
437
+ """
438
+ for token_id in set(generated_ids):
439
+ if logits[0, -1, token_id] < 0:
440
+ logits[0, -1, token_id] *= repetition_penalty
441
+ else:
442
+ logits[0, -1, token_id] /= repetition_penalty
443
+
444
+ return logits
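The forward pass above treats the resized LLM head as one flat vocabulary: the first padded_text_vocabsize logits belong to the text layer, and each following padded_audio_vocabsize block belongs to one audio code layer. A minimal sketch of that split (not part of the commit), using the padded sizes implied by s2s_config (152000 text logits, 4160 audio logits per layer, 7 layers) and a random tensor standing in for model output:

import torch

text_vocab, audio_vocab, code_layer = 152000, 4160, 7               # padded sizes from s2s_config
logits = torch.randn(1, 10, text_vocab + audio_vocab * code_layer)  # [batch, seq, 181120]
xt = logits[..., :text_vocab]                                       # text-layer logits
xa = [logits[..., text_vocab + audio_vocab * i : text_vocab + audio_vocab * (i + 1)]
      for i in range(code_layer)]                                   # one slice per audio code layer
assert xt.shape[-1] == text_vocab and all(a.shape[-1] == audio_vocab for a in xa)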
s2s.py ADDED
@@ -0,0 +1,178 @@
1
+ import random
2
+ import torch
3
+ from slam_llm.utils.model_utils import get_custom_model_factory
4
+ from utils.snac_utils import reconscruct_snac, reconstruct_tensors, layershift
5
+ import whisper
6
+ import numpy as np
7
+ from s2s_config import InferenceConfig, CKPT_PATH, CKPT_REPO, CKPT_LOCAL_DIR, CKPT_NAME
8
+ import os
9
+ from omegaconf import OmegaConf
10
+ from huggingface_hub import hf_hub_download
11
+ from typing import Callable
12
+
13
+
14
+ def update_progress(progress_callback: Callable[[str], None] | None, message: str):
15
+ if progress_callback:
16
+ progress_callback(message)
17
+
18
+
19
+ def pull_model_ckpt():
20
+ if not os.path.exists(CKPT_LOCAL_DIR):
21
+ os.makedirs(CKPT_LOCAL_DIR)
22
+ if os.path.exists(CKPT_PATH):
23
+ return
24
+ hf_hub_download(
25
+ repo_id=CKPT_REPO,
26
+ filename=CKPT_NAME,
27
+ local_dir=CKPT_LOCAL_DIR,
28
+ token=os.getenv("HF_TOKEN"),
29
+ )
30
+
31
+
32
+ pull_model_ckpt()
33
+
34
+
35
+ def extract_audio_feature(audio_path, mel_size):
36
+ print("Extracting audio features from", audio_path)
37
+ audio_raw = whisper.load_audio(audio_path)
38
+ audio_raw = whisper.pad_or_trim(audio_raw)
39
+ audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size).permute(1, 0)
40
+ audio_length = (audio_mel.shape[0] + 1) // 2
41
+ audio_length = audio_length // 5
42
+ audio_res = audio_mel
43
+
44
+ return audio_res, audio_length
45
+
46
+
47
+ def get_input_ids(length, special_token_a, special_token_t, vocab_config):
48
+ input_ids = []
49
+ for i in range(vocab_config.code_layer):
50
+ input_ids_item = []
51
+ input_ids_item.append(layershift(vocab_config.input_a, i))
52
+ input_ids_item += [layershift(vocab_config.pad_a, i)] * length
53
+ input_ids_item += [
54
+ (layershift(vocab_config.eoa, i)),
55
+ layershift(special_token_a, i),
56
+ ]
57
+ input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
58
+ input_id_T = torch.tensor(
59
+ [vocab_config.input_t]
60
+ + [vocab_config.pad_t] * length
61
+ + [vocab_config.eot, special_token_t]
62
+ )
63
+ input_ids.append(input_id_T.unsqueeze(0))
64
+ return input_ids
65
+
66
+
67
+ def generate_from_wav(
68
+ wav_path, model, codec_decoder, dataset_config, decode_config, device
69
+ ):
70
+ mel_size = dataset_config.mel_size
71
+ prompt = dataset_config.prompt
72
+ prompt_template = "USER: {}\n ASSISTANT: "
73
+ vocab_config = dataset_config.vocab_config
74
+ special_token_a = vocab_config.answer_a
75
+ special_token_t = vocab_config.answer_t
76
+ code_layer = vocab_config.code_layer
77
+ task_type = dataset_config.task_type
78
+
79
+ audio_mel, audio_length = extract_audio_feature(wav_path, mel_size)
80
+
81
+ prompt = prompt_template.format(prompt)
82
+ prompt_ids = model.tokenizer.encode(prompt)
83
+ prompt_length = len(prompt_ids)
84
+ prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
85
+
86
+ example_ids = get_input_ids(
87
+ audio_length + prompt_length, special_token_a, special_token_t, vocab_config
88
+ )
89
+ text_layer = example_ids[code_layer]
90
+ text_layer = torch.cat(
91
+ (
92
+ text_layer[:, : audio_length + 1],
93
+ prompt_ids.unsqueeze(0),
94
+ text_layer[:, -2:],
95
+ ),
96
+ dim=1,
97
+ ) # <bos> <audio> <prompt> <eos> <task>
98
+ example_ids[code_layer] = text_layer
99
+
100
+ input_length = audio_length
101
+ example_mask = example_ids[0][0].ge(-1)
102
+ example_ids = torch.stack(example_ids).squeeze()
103
+
104
+ input_ids = example_ids.unsqueeze(0).to(device)
105
+ attention_mask = example_mask.unsqueeze(0).to(device)
106
+ audio_mel = audio_mel.unsqueeze(0).to(device)
107
+ input_length = torch.tensor([input_length]).to(device)
108
+ audio_length = torch.tensor([audio_length]).to(device)
109
+ task_type = [task_type]
110
+
111
+ modality_mask = torch.zeros_like(attention_mask)
112
+ padding_left = 1 # +1 for <bos>
113
+ modality_mask[0, padding_left : padding_left + audio_length] = True
114
+
115
+ batch = {
116
+ "input_ids": input_ids,
117
+ "attention_mask": attention_mask,
118
+ "audio_mel": audio_mel,
119
+ "input_length": input_length,
120
+ "audio_length": audio_length,
121
+ "modality_mask": modality_mask,
122
+ "task_types": task_type,
123
+ }
124
+
125
+ model_outputs = model.generate(**batch, **decode_config)
126
+ text_outputs = model_outputs[7]
127
+ audio_outputs = model_outputs[:7]
128
+ output_text = model.tokenizer.decode(
129
+ text_outputs, add_special_tokens=False, skip_special_tokens=True
130
+ )
131
+
132
+ if decode_config.decode_text_only:
133
+ return None, output_text
134
+
135
+ audio_tokens = [audio_outputs[layer] for layer in range(7)]
136
+ audiolist = reconscruct_snac(audio_tokens)
137
+ audio = reconstruct_tensors(audiolist)
138
+ with torch.inference_mode():
139
+ audio_hat = codec_decoder.decode(audio)
140
+
141
+ return audio_hat, output_text
142
+
143
+
144
+ def generate(
145
+ wav_path: str, progress_callback: Callable[[str], None] | None = None
146
+ ) -> tuple[np.ndarray, int | float]:
147
+ config = OmegaConf.structured(InferenceConfig())
148
+ train_config, model_config, dataset_config, decode_config = (
149
+ config.train_config,
150
+ config.model_config,
151
+ config.dataset_config,
152
+ config.decode_config,
153
+ )
154
+
155
+ torch.cuda.manual_seed(train_config.seed)
156
+ torch.manual_seed(train_config.seed)
157
+ random.seed(train_config.seed)
158
+
159
+ update_progress(progress_callback, "Loading model")
160
+
161
+ model_factory = get_custom_model_factory(model_config)
162
+ model, _ = model_factory(train_config, model_config, CKPT_PATH)
163
+ codec_decoder = model.codec_decoder
164
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
165
+ model.to(device)
166
+ model.eval()
167
+
168
+ update_progress(progress_callback, "Generating")
169
+ output_wav, output_text = generate_from_wav(
170
+ wav_path, model, codec_decoder, dataset_config, decode_config, device
171
+ )
172
+
173
+ return output_wav.squeeze().cpu().numpy(), 24000
174
+
175
+
176
+ if __name__ == "__main__":
177
+ wav_path = "sample.wav"
178
+ generate(wav_path)
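For reference, generate() above returns the synthesized waveform as a NumPy array together with its sample rate (24 kHz from the SNAC decoder). A hedged usage sketch; the input file name is hypothetical and soundfile is an assumed extra dependency, not one declared by this commit:

import soundfile as sf
from s2s import generate

audio, sample_rate = generate("question.wav", progress_callback=print)  # hypothetical input wav
sf.write("answer.wav", audio, sample_rate)                              # write the 24 kHz reply audio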
s2s_config.py ADDED
@@ -0,0 +1,272 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional, List
3
+ import os
4
+
5
+ CKPT_NAME = "model.pt"
6
+ CKPT_LOCAL_DIR = "model_ckpts"
7
+ CKPT_PATH = os.path.join(CKPT_LOCAL_DIR, CKPT_NAME)
8
+ CKPT_REPO = "xcczach/mini-omni"
9
+
10
+
11
+ @dataclass
12
+ class VocabConfig:
13
+ text_vocabsize: int = 151936
14
+ text_specialtokens: int = 64
15
+ audio_vocabsize: int = 4096
16
+ audio_specialtokens: int = 64
17
+ total_vocabsize: int = 181120
18
+ code_layer: int = 7
19
+
20
+ padded_text_vocabsize: int = field(init=False)
21
+ padded_audio_vocabsize: int = field(init=False)
22
+ total_audio_vocabsize: int = field(init=False)
23
+
24
+ eot: int = field(init=False) # end of text token
25
+ pad_t: int = field(init=False) # padding text token
26
+ input_t: int = field(init=False) # input text token
27
+ answer_t: int = field(init=False) # answer text token
28
+ asr: int = field(init=False) # ASR token
29
+
30
+ eoa: int = field(init=False) # end of audio token
31
+ pad_a: int = field(init=False) # padding audio token
32
+ input_a: int = field(init=False) # input audio token
33
+ answer_a: int = field(init=False) # answer audio token
34
+ split: int = field(init=False) # split token
35
+
36
+ def __post_init__(self):
37
+ self.padded_text_vocabsize = self.text_vocabsize + self.text_specialtokens
38
+ self.padded_audio_vocabsize = self.audio_vocabsize + self.audio_specialtokens
39
+ self.total_audio_vocabsize = self.padded_audio_vocabsize * self.code_layer
40
+
41
+ self.eot = self.text_vocabsize
42
+ self.pad_t = self.text_vocabsize + 1
43
+ self.input_t = self.text_vocabsize + 2
44
+ self.answer_t = self.text_vocabsize + 3
45
+ self.asr = self.text_vocabsize + 4
46
+
47
+ self.eoa = self.audio_vocabsize
48
+ self.pad_a = self.audio_vocabsize + 1
49
+ self.input_a = self.audio_vocabsize + 2
50
+ self.answer_a = self.audio_vocabsize + 3
51
+ self.split = self.audio_vocabsize + 4
52
+
53
+
54
+ @dataclass
55
+ class TTSAdapterConfig:
56
+ add_qkv_bias: Optional[bool] = True
57
+ bias: bool = False
58
+ gelu_approximate: Optional[str] = None
59
+ head_size: Optional[int] = 64
60
+ intermediate_size: Optional[int] = 4864
61
+ lm_head_bias: bool = False
62
+ mlp_class_name: str = "GptNeoxMLP"
63
+ n_layer: int = 6
64
+ n_head: int = 14
65
+ n_embd: int = 896
66
+ n_query_groups: Optional[int] = 2
67
+ norm_class_name: str = "RMSNorm"
68
+ norm_eps: float = 1e-6
69
+ parallel_residual: bool = False
70
+ rotary_percentage: float = 1
71
+ shared_attention_norm: bool = False
72
+
73
+ def __post_init__(self):
74
+ self.rope_n_elem = int(self.rotary_percentage * self.head_size)
75
+
76
+
77
+ @dataclass
78
+ class ModelConfig:
79
+ file: str = "model/slam_model_s2s.py:model_factory"
80
+ llm_name: str = "qwen2-0.5b"
81
+ llm_path: str = "Qwen/Qwen2-0.5B"
82
+ llm_type: str = "decoder_only"
83
+ llm_dim: int = 896
84
+ encoder_name: Optional[str] = "whisper"
85
+ encoder_ds_rate: int = 2
86
+ encoder_path: Optional[str] = "small"
87
+ encoder_dim: int = 768
88
+ encoder_projector: str = "linear"
89
+ encoder_projector_ds_rate: int = 5
90
+ modal: str = "audio"
91
+ normalize: Optional[bool] = field(
92
+ default=False,
93
+ metadata={"help": "whether input is normalized, used for models such as wavlm"},
94
+ )
95
+ encoder_type: str = field(
96
+ default="finetune",
97
+ metadata={
98
+ "help": "whether model is only pretrained or finetuned, used for models such as hubert"
99
+ },
100
+ )
101
+ vocab_config: VocabConfig = field(default_factory=VocabConfig)
102
+ codec_decode: bool = True
103
+ codec_decoder_type: str = "SNAC"
104
+ codec_decoder_path: Optional[str] = "hubertsiuzdak/snac_24khz"
105
+ tts_adapter: bool = False
106
+ tts_adapter_config: TTSAdapterConfig = field(default_factory=TTSAdapterConfig)
107
+
108
+
109
+ @dataclass
110
+ class PeftConfig:
111
+ peft_method: str = "lora" # None , llama_adapter, prefix
112
+ r: int = 8
113
+ lora_alpha: int = 32
114
+ target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"])
115
+ bias: str = "none"
116
+ task_type: str = "CAUSAL_LM"
117
+ lora_dropout: float = 0.05
118
+ inference_mode: bool = False
119
+
120
+
121
+ @dataclass
122
+ class TrainConfig:
123
+ model_name: str = "s2s"
124
+ enable_ddp: bool = False
125
+ enable_deepspeed: bool = False
126
+ enable_fsdp: bool = False
127
+ low_cpu_fsdp: bool = False
128
+ run_validation: bool = True
129
+ batch_size_training: int = 4
130
+ batching_strategy: str = field(
131
+ default="custom", metadata={"help": "alternative: padding"}
132
+ ) #
133
+ context_length: int = 4096
134
+ gradient_accumulation_steps: int = 1
135
+ num_epochs: int = 1
136
+ num_workers_dataloader: int = 2
137
+ warmup_steps: int = 1000
138
+ total_steps: int = 100000
139
+ validation_interval: int = 1000
140
+ lr: float = 1e-4
141
+ weight_decay: float = 0.0
142
+ gamma: float = 0.85
143
+ seed: int = 42
144
+ use_fp16: bool = False
145
+ mixed_precision: bool = True
146
+ val_batch_size: int = 1
147
+
148
+ use_peft: bool = False
149
+ peft_config: PeftConfig = field(default_factory=PeftConfig)
150
+ output_dir: str = "PATH/to/save/PEFT/model"
151
+ freeze_layers: bool = False
152
+ num_freeze_layers: int = 1
153
+ quantization: bool = False
154
+ one_gpu: bool = False
155
+ save_model: bool = True
156
+ dist_checkpoint_root_folder: str = (
157
+ "PATH/to/save/FSDP/model" # will be used if using FSDP
158
+ )
159
+ dist_checkpoint_folder: str = "fine-tuned" # will be used if using FSDP
160
+ save_optimizer: bool = False # will be used if using FSDP
161
+ use_fast_kernels: bool = (
162
+ False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and xFormers memory-efficient kernels
163
+ )
164
+ run_test_during_validation: bool = False
165
+ run_test_during_validation_file: str = "test.wav"
166
+ run_test_during_validation_prompt: str = "<|S2S|>"
167
+ freeze_llm: bool = field(
168
+ default=True,
169
+ metadata={
170
+ "help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
171
+ },
172
+ )
173
+ freeze_encoder: bool = True
174
+ train_embed_only: bool = False
175
+ train_audio_embed_only: bool = False
176
+ task_type: str = "s2s"
177
+
178
+
179
+ @dataclass
180
+ class DataConfig:
181
+ dataset: str = "speech_dataset_s2s"
182
+ file: str = "examples/s2s/speech_dataset_s2s.py:get_speech_dataset"
183
+ train_data_path: Optional[str] = None
184
+ val_data_path: Optional[str] = None
185
+ train_split: str = "train"
186
+ test_split: str = "validation"
187
+ prompt: Optional[str] = None
188
+ data_path: Optional[str] = None
189
+ max_words: Optional[int] = None
190
+ max_mel: Optional[float] = None
191
+ fix_length_audio: int = -1
192
+ inference_mode: bool = True
193
+ input_type: str = field(
194
+ default="mel",
195
+ metadata={"help": "Use raw when input is wav, mel when for whisper"},
196
+ )
197
+ mel_size: int = field(
198
+ default=80, metadata={"help": "80 for whisper large v1 and v2, 128 for v3"}
199
+ )
200
+ normalize: Optional[bool] = field(
201
+ default=False,
202
+ metadata={"help": "whether input is normalized, used for models such as wavlm"},
203
+ )
204
+ seed: int = 42
205
+ manifest_format: str = field(
206
+ default="datasets", metadata={"help": "alternative: jsonl"}
207
+ )
208
+ split_size: float = 0.1
209
+
210
+ vocab_config: VocabConfig = field(default_factory=VocabConfig)
211
+ load_from_cache_file: bool = False
212
+ task_type: str = "s2s"
213
+
214
+
215
+ @dataclass
216
+ class DecodeConfig:
217
+ do_sample: bool = False
218
+ max_new_tokens: int = 300
219
+ min_length: int = 10
220
+ temperature: float = 1.0
221
+ top_k: int = 50
222
+ top_p: float = 0.9
223
+ num_beams: int = 1
224
+ num_return_sequences: int = 1
225
+ num_samples: int = 1
226
+ max_time: float = 0.0
227
+ repetition_penalty: float = 1.0
228
+ length_penalty: float = 1.0
229
+ early_stopping: bool = False
230
+ no_repeat_ngram_size: int = 0
231
+ bad_words_ids: List = field(default_factory=list)
232
+ num_beam_groups: int = 1
233
+ diversity_penalty: float = 0.0
234
+ task_type: str = "s2s"
235
+ decode_text_only: bool = False
236
+
237
+
238
+ @dataclass
239
+ class FSDPConfig:
240
+ mixed_precision: bool = True
241
+ use_fp16: bool = False
242
+ # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
243
+ sharding_strategy: str = (
244
+ "NO_SHARD" # ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
245
+ )
246
+ checkpoint_type: str = (
247
+ "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
248
+ )
249
+ fsdp_activation_checkpointing: bool = True
250
+ fsdp_cpu_offload: bool = False
251
+ pure_bf16: bool = False
252
+ optimizer: str = "AdamW"
253
+
254
+
255
+ @dataclass
256
+ class LogConfig:
257
+ use_wandb: bool = False
258
+ wandb_dir: str = "/valleblob/v-wenxichen/exp/wandb_log"
259
+ wandb_entity_name: str = "project_name"
260
+ wandb_project_name: str = "project_name"
261
+ wandb_exp_name: str = "exp_name"
262
+ log_file: str = "/valleblob/v-wenxichen/exp/log/test.log"
263
+ log_interval: int = 10
264
+ online_output_dir: Optional[str] = None
265
+
266
+
267
+ @dataclass
268
+ class InferenceConfig:
269
+ dataset_config: DataConfig = field(default_factory=DataConfig)
270
+ model_config: ModelConfig = field(default_factory=ModelConfig)
271
+ train_config: TrainConfig = field(default_factory=TrainConfig)
272
+ decode_config: DecodeConfig = field(default_factory=DecodeConfig)
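The derived fields in VocabConfig.__post_init__ account for the hard-coded total_vocabsize: 151936 + 64 = 152000 padded text tokens, 4096 + 64 = 4160 padded audio tokens per code layer, and 152000 + 7 * 4160 = 181120 in total. A small sanity check, instantiating the dataclass directly rather than through OmegaConf:

from s2s_config import VocabConfig

vc = VocabConfig()
assert vc.padded_text_vocabsize == 152000
assert vc.padded_audio_vocabsize == 4160
assert vc.padded_text_vocabsize + vc.code_layer * vc.padded_audio_vocabsize == vc.total_vocabsize  # 181120
print(vc.eot, vc.eoa)  # 151936 and 4096, the first ids after each base vocabulary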
slam_llm/__init__.py ADDED
File without changes
slam_llm/data/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
slam_llm/data/concatenator.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ from tqdm import tqdm
5
+ from itertools import chain
6
+
7
+ from torch.utils.data import Dataset
8
+
9
+
10
+ class ConcatDataset(Dataset):
11
+ def __init__(self, dataset, chunk_size=4096):
12
+ self.dataset = dataset
13
+ self.chunk_size = chunk_size
14
+
15
+ self.samples = []
16
+
17
+ buffer = {
18
+ "input_ids": [],
19
+ "attention_mask": [],
20
+ "labels": [],
21
+ }
22
+
23
+ for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
24
+ buffer = {k: v + sample[k] for k,v in buffer.items()}
25
+
26
+ while len(next(iter(buffer.values()))) > self.chunk_size:
27
+ self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
28
+ buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
29
+
30
+ def __getitem__(self, idx):
31
+ return self.samples[idx]
32
+
33
+ def __len__(self):
34
+ return len(self.samples)
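ConcatDataset packs tokenized samples into fixed-length chunks for training. A minimal sketch with a toy list of dicts standing in for a real tokenized dataset (the keys match what the class's buffer expects):

from slam_llm.data.concatenator import ConcatDataset

toy = [{"input_ids": [1] * 10, "attention_mask": [1] * 10, "labels": [1] * 10} for _ in range(100)]
packed = ConcatDataset(toy, chunk_size=64)
print(len(packed), len(packed[0]["input_ids"]))  # 15 chunks of exactly 64 tokens; the trailing 40 tokens are dropped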
slam_llm/data/sampler.py ADDED
@@ -0,0 +1,57 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import random
5
+ from itertools import islice
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
12
+ def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None:
13
+ if isinstance(next(iter(data_source)), dict):
14
+ first_key = next(iter(next(iter(data_source)).keys()))
15
+ self.lengths = [len(d[first_key]) for d in data_source]
16
+ else:
17
+ self.lengths = [len(d) for d in data_source]
18
+ self.batch_size = batch_size
19
+ self.drop_last = drop_last
20
+ self.shuffle = shuffle
21
+
22
+ def __iter__(self):
23
+ ids = np.argsort(self.lengths)
24
+ if self.drop_last:
25
+ ids = ids[:len(ids) // self.batch_size * self.batch_size]
26
+
27
+ batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
28
+
29
+ if self.shuffle:
30
+ random.shuffle(batches)
31
+
32
+ for b in batches:
33
+ yield b
34
+
35
+ def __len__(self):
36
+ if self.drop_last:
37
+ return len(self.lengths) // self.batch_size
38
+ else:
39
+ return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
40
+
41
+
42
+ class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
43
+ def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
44
+ random.seed(seed)
45
+ self.batch_sampler = LengthBasedBatchSampler(
46
+ data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
47
+ )
48
+ self.num_replicas = num_replicas
49
+ self.rank = rank
50
+
51
+ def __iter__(self):
52
+ max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
53
+ return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
54
+
55
+ def __len__(self):
56
+ return len(self.batch_sampler) // self.num_replicas
57
+
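LengthBasedBatchSampler batches samples of similar length to cut padding. A minimal sketch with toy variable-length samples; in real training the sampler is handed to a DataLoader through batch_sampler:

from torch.utils.data import DataLoader
from slam_llm.data.sampler import LengthBasedBatchSampler

data = [[0] * n for n in (5, 3, 9, 2, 7, 4)]                     # toy variable-length samples
sampler = LengthBasedBatchSampler(data, batch_size=2, drop_last=False, shuffle=False)
loader = DataLoader(data, batch_sampler=sampler, collate_fn=lambda batch: batch)
for batch in loader:
    print([len(x) for x in batch])                               # neighbouring lengths end up in the same batch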
slam_llm/models/BEATs/BEATs.py ADDED
@@ -0,0 +1,181 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import LayerNorm
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+
16
+ from .backbone import (
17
+ TransformerEncoder,
18
+ )
19
+
20
+ import logging
21
+ from typing import Optional
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class BEATsConfig:
27
+ def __init__(self, cfg=None):
28
+ self.input_patch_size: int = -1 # patch size of patch embedding
29
+ self.embed_dim: int = 512 # patch embedding dimension
30
+ self.conv_bias: bool = False # include bias in conv encoder
31
+
32
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
33
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
34
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
35
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
36
+ self.activation_fn: str = "gelu" # activation function to use
37
+
38
+ self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
39
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
40
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
41
+
42
+ # dropouts
43
+ self.dropout: float = 0.1 # dropout probability for the transformer
44
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
45
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
46
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
47
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
48
+
49
+ # positional embeddings
50
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
51
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
52
+
53
+ # relative position embedding
54
+ self.relative_position_embedding: bool = False # apply relative position embedding
55
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
56
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
57
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
58
+
59
+ # label predictor
60
+ self.finetuned_model: bool = False # whether the model is a fine-tuned model.
61
+ self.predictor_dropout: float = 0.1 # dropout probability for the predictor
62
+ self.predictor_class: int = 527 # target class number for the predictor
63
+
64
+ if cfg is not None:
65
+ self.update(cfg)
66
+
67
+ def update(self, cfg: dict):
68
+ self.__dict__.update(cfg)
69
+
70
+
71
+ class BEATs(nn.Module):
72
+ def __init__(
73
+ self,
74
+ cfg: BEATsConfig,
75
+ ) -> None:
76
+ super().__init__()
77
+ logger.info(f"BEATs Config: {cfg.__dict__}")
78
+
79
+ self.cfg = cfg
80
+
81
+ self.embed = cfg.embed_dim
82
+ self.post_extract_proj = (
83
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
84
+ if self.embed != cfg.encoder_embed_dim
85
+ else None
86
+ )
87
+
88
+ self.input_patch_size = cfg.input_patch_size
89
+ self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
90
+ bias=cfg.conv_bias)
91
+
92
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
93
+
94
+ assert not cfg.deep_norm or not cfg.layer_norm_first
95
+ self.encoder = TransformerEncoder(cfg)
96
+ self.layer_norm = LayerNorm(self.embed)
97
+
98
+ if cfg.finetuned_model:
99
+ self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
100
+ self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
101
+ else:
102
+ self.predictor = None
103
+
104
+ def forward_padding_mask(
105
+ self,
106
+ features: torch.Tensor,
107
+ padding_mask: torch.Tensor,
108
+ ) -> torch.Tensor:
109
+ extra = padding_mask.size(1) % features.size(1)
110
+ if extra > 0:
111
+ padding_mask = padding_mask[:, :-extra]
112
+ padding_mask = padding_mask.view(
113
+ padding_mask.size(0), features.size(1), -1
114
+ )
115
+ padding_mask = padding_mask.all(-1)
116
+ return padding_mask
117
+
118
+ @classmethod
119
+ def preprocess(
120
+ cls,
121
+ source: torch.Tensor,
122
+ fbank_mean: float = 15.41663,
123
+ fbank_std: float = 6.55582,
124
+ ) -> torch.Tensor:
125
+ if len(source.shape) > 1: # batch
126
+ fbanks = []
127
+ for waveform in source:
128
+ waveform = waveform.unsqueeze(0) * 2 ** 15
129
+ fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
130
+ fbanks.append(fbank)
131
+ fbank = torch.stack(fbanks, dim=0)
132
+ else: # single
133
+ waveform = source.unsqueeze(0) * 2 ** 15
134
+ fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
135
+
136
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
137
+ return fbank
138
+
139
+ def extract_features(
140
+ self,
141
+ fbank: torch.Tensor,
142
+ padding_mask: Optional[torch.Tensor] = None,
143
+ ):
144
+ if padding_mask is not None:
145
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
146
+
147
+ fbank = fbank.unsqueeze(1)
148
+ features = self.patch_embedding(fbank)
149
+ features = features.reshape(features.shape[0], features.shape[1], -1)
150
+ features = features.transpose(1, 2)
151
+ features = self.layer_norm(features)
152
+
153
+ if padding_mask is not None:
154
+ padding_mask = self.forward_padding_mask(features, padding_mask)
155
+
156
+ if self.post_extract_proj is not None:
157
+ features = self.post_extract_proj(features)
158
+
159
+ x = self.dropout_input(features)
160
+
161
+ x, layer_results = self.encoder(
162
+ x,
163
+ padding_mask=padding_mask,
164
+ )
165
+
166
+ if self.predictor is not None:
167
+ x = self.predictor_dropout(x)
168
+ logits = self.predictor(x)
169
+
170
+ if padding_mask is not None and padding_mask.any():
171
+ logits[padding_mask] = 0
172
+ logits = logits.sum(dim=1)
173
+ logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
174
+ else:
175
+ logits = logits.mean(dim=1)
176
+
177
+ lprobs = torch.sigmoid(logits)
178
+
179
+ return lprobs, padding_mask
180
+ else:
181
+ return x, padding_mask
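A hedged sketch of running the BEATs encoder on raw 16 kHz waveforms. The config values are the class defaults plus an assumed input_patch_size of 16; real use loads a pretrained checkpoint (whose cfg carries the actual patch size), which is not shown here:

import torch
from slam_llm.models.BEATs.BEATs import BEATs, BEATsConfig

cfg = BEATsConfig({"input_patch_size": 16})       # assumed patch size, normally taken from the checkpoint
model = BEATs(cfg).eval()
wav = torch.randn(2, 16000)                       # a batch of 1-second 16 kHz waveforms
fbank = BEATs.preprocess(wav)                     # [batch, frames, 128] normalized log-mel filterbanks
with torch.no_grad():
    features, _ = model.extract_features(fbank)   # [batch, num_patches, encoder_embed_dim]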
slam_llm/models/BEATs/Tokenizers.py ADDED
@@ -0,0 +1,173 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import LayerNorm
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+
16
+ from .backbone import (
17
+ TransformerEncoder,
18
+ )
19
+ from .quantizer import (
20
+ NormEMAVectorQuantizer,
21
+ )
22
+
23
+ import logging
24
+ from typing import Optional
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ class TokenizersConfig:
30
+ def __init__(self, cfg=None):
31
+ self.input_patch_size: int = -1 # patch size of patch embedding
32
+ self.embed_dim: int = 512 # patch embedding dimension
33
+ self.conv_bias: bool = False # include bias in conv encoder
34
+
35
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
36
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
37
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
38
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
39
+ self.activation_fn: str = "gelu" # activation function to use
40
+
41
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
42
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
43
+
44
+ # dropouts
45
+ self.dropout: float = 0.1 # dropout probability for the transformer
46
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
47
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
48
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
49
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
50
+
51
+ # positional embeddings
52
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
53
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
54
+
55
+ # relative position embedding
56
+ self.relative_position_embedding: bool = False # apply relative position embedding
57
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
58
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
59
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
60
+
61
+ # quantizer
62
+ self.quant_n: int = 1024 # codebook number in quantizer
63
+ self.quant_dim: int = 256 # codebook dimension in quantizer
64
+
65
+ if cfg is not None:
66
+ self.update(cfg)
67
+
68
+ def update(self, cfg: dict):
69
+ self.__dict__.update(cfg)
70
+
71
+
72
+ class Tokenizers(nn.Module):
73
+ def __init__(
74
+ self,
75
+ cfg: TokenizersConfig,
76
+ ) -> None:
77
+ super().__init__()
78
+ logger.info(f"Tokenizers Config: {cfg.__dict__}")
79
+
80
+ self.cfg = cfg
81
+
82
+ self.embed = cfg.embed_dim
83
+ self.post_extract_proj = (
84
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
85
+ if self.embed != cfg.encoder_embed_dim
86
+ else None
87
+ )
88
+
89
+ self.input_patch_size = cfg.input_patch_size
90
+ self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
91
+ bias=cfg.conv_bias)
92
+
93
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
94
+
95
+ assert not cfg.deep_norm or not cfg.layer_norm_first
96
+ self.encoder = TransformerEncoder(cfg)
97
+ self.layer_norm = LayerNorm(self.embed)
98
+
99
+ self.quantize = NormEMAVectorQuantizer(
100
+ n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
101
+ )
102
+ self.quant_n = cfg.quant_n
103
+ self.quantize_layer = nn.Sequential(
104
+ nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
105
+ nn.Tanh(),
106
+ nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
107
+ )
108
+
109
+ def forward_padding_mask(
110
+ self,
111
+ features: torch.Tensor,
112
+ padding_mask: torch.Tensor,
113
+ ) -> torch.Tensor:
114
+ extra = padding_mask.size(1) % features.size(1)
115
+ if extra > 0:
116
+ padding_mask = padding_mask[:, :-extra]
117
+ padding_mask = padding_mask.view(
118
+ padding_mask.size(0), features.size(1), -1
119
+ )
120
+ padding_mask = padding_mask.all(-1)
121
+ return padding_mask
122
+
123
+ def preprocess(
124
+ self,
125
+ source: torch.Tensor,
126
+ fbank_mean: float = 15.41663,
127
+ fbank_std: float = 6.55582,
128
+ ) -> torch.Tensor:
129
+ fbanks = []
130
+ for waveform in source:
131
+ waveform = waveform.unsqueeze(0) * 2 ** 15
132
+ fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
133
+ fbanks.append(fbank)
134
+ fbank = torch.stack(fbanks, dim=0)
135
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
136
+ return fbank
137
+
138
+ def extract_labels(
139
+ self,
140
+ source: torch.Tensor,
141
+ padding_mask: Optional[torch.Tensor] = None,
142
+ fbank_mean: float = 15.41663,
143
+ fbank_std: float = 6.55582,
144
+ ):
145
+ fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
146
+
147
+ if padding_mask is not None:
148
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
149
+
150
+ fbank = fbank.unsqueeze(1)
151
+ features = self.patch_embedding(fbank)
152
+ features = features.reshape(features.shape[0], features.shape[1], -1)
153
+ features = features.transpose(1, 2)
154
+ features = self.layer_norm(features)
155
+
156
+ if padding_mask is not None:
157
+ padding_mask = self.forward_padding_mask(features, padding_mask)
158
+
159
+ if self.post_extract_proj is not None:
160
+ features = self.post_extract_proj(features)
161
+
162
+ x = self.dropout_input(features)
163
+
164
+ x, layer_results = self.encoder(
165
+ x,
166
+ padding_mask=padding_mask,
167
+ )
168
+
169
+ quantize_input = self.quantize_layer(x)
170
+ quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
171
+
172
+ return embed_ind
173
+
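The tokenizer counterpart maps the same filterbank features to discrete codebook ids through extract_labels. A sketch in the same spirit as above (random weights and an assumed patch size, so the ids are only meaningful after loading a pretrained tokenizer checkpoint):

import torch
from slam_llm.models.BEATs.Tokenizers import Tokenizers, TokenizersConfig

cfg = TokenizersConfig({"input_patch_size": 16})  # assumed patch size, normally taken from the checkpoint
tokenizer = Tokenizers(cfg).eval()
wav = torch.randn(2, 16000)
with torch.no_grad():
    labels = tokenizer.extract_labels(wav)        # one codebook id in [0, quant_n) per patch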
slam_llm/models/BEATs/backbone.py ADDED
@@ -0,0 +1,783 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import numpy as np
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ import torch.nn.functional as F
16
+ from torch.nn import LayerNorm, Parameter
17
+ from .modules import (
18
+ GradMultiply,
19
+ SamePad,
20
+ get_activation_fn,
21
+ GLU_Linear,
22
+ quant_noise,
23
+ )
24
+
25
+
26
+ class TransformerEncoder(nn.Module):
27
+ def __init__(self, args):
28
+ super().__init__()
29
+
30
+ self.dropout = args.dropout
31
+ self.embedding_dim = args.encoder_embed_dim
32
+
33
+ self.pos_conv = nn.Conv1d(
34
+ self.embedding_dim,
35
+ self.embedding_dim,
36
+ kernel_size=args.conv_pos,
37
+ padding=args.conv_pos // 2,
38
+ groups=args.conv_pos_groups,
39
+ )
40
+ dropout = 0
41
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
42
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
43
+ nn.init.constant_(self.pos_conv.bias, 0)
44
+
45
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
46
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
47
+
48
+ if hasattr(args, "relative_position_embedding"):
49
+ self.relative_position_embedding = args.relative_position_embedding
50
+ self.num_buckets = args.num_buckets
51
+ self.max_distance = args.max_distance
52
+ else:
53
+ self.relative_position_embedding = False
54
+ self.num_buckets = 0
55
+ self.max_distance = 0
56
+
57
+ self.layers = nn.ModuleList(
58
+ [
59
+ TransformerSentenceEncoderLayer(
60
+ embedding_dim=self.embedding_dim,
61
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
62
+ num_attention_heads=args.encoder_attention_heads,
63
+ dropout=self.dropout,
64
+ attention_dropout=args.attention_dropout,
65
+ activation_dropout=args.activation_dropout,
66
+ activation_fn=args.activation_fn,
67
+ layer_norm_first=args.layer_norm_first,
68
+ deep_norm=args.deep_norm,
69
+ has_relative_attention_bias=self.relative_position_embedding,
70
+ num_buckets=self.num_buckets,
71
+ max_distance=self.max_distance,
72
+ gru_rel_pos=args.gru_rel_pos,
73
+ encoder_layers=args.encoder_layers,
74
+ )
75
+ for i in range(args.encoder_layers)
76
+ ]
77
+ )
78
+ if self.relative_position_embedding:
79
+ for i in range(1, args.encoder_layers):
80
+ del self.layers[i].self_attn.relative_attention_bias
81
+ self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
82
+
83
+ self.layer_norm_first = args.layer_norm_first
84
+ self.layer_norm = LayerNorm(self.embedding_dim)
85
+ self.layerdrop = args.encoder_layerdrop
86
+
87
+ self.apply(init_bert_params)
88
+
89
+ if args.deep_norm:
90
+ deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
91
+ for i in range(args.encoder_layers):
92
+ nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
93
+ nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
94
+ nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
95
+ nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
96
+ nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
97
+ nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
98
+
99
+ self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
100
+
101
+ def forward(self, x, padding_mask=None, layer=None):
102
+ x, layer_results = self.extract_features(x, padding_mask, layer)
103
+
104
+ if self.layer_norm_first and layer is None:
105
+ x = self.layer_norm(x)
106
+
107
+ return x, layer_results
108
+
109
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
110
+
111
+ if padding_mask is not None:
112
+ x[padding_mask] = 0
113
+
114
+ x_conv = self.pos_conv(x.transpose(1, 2))
115
+ x_conv = x_conv.transpose(1, 2)
116
+ x = x + x_conv
117
+
118
+ if not self.layer_norm_first:
119
+ x = self.layer_norm(x)
120
+
121
+ x = F.dropout(x, p=self.dropout, training=self.training)
122
+
123
+ # B x T x C -> T x B x C
124
+ x = x.transpose(0, 1)
125
+
126
+ layer_results = []
127
+ z = None
128
+ if tgt_layer is not None:
129
+ layer_results.append((x, z))
130
+ r = None
131
+ pos_bias = None
132
+ for i, layer in enumerate(self.layers):
133
+ if self.layer_wise_gradient_decay_ratio != 1.0:
134
+ x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
135
+ dropout_probability = np.random.random()
136
+ if not self.training or (dropout_probability > self.layerdrop):
137
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
138
+ if tgt_layer is not None:
139
+ layer_results.append((x, z))
140
+ if i == tgt_layer:
141
+ r = x
142
+ break
143
+
144
+ if r is not None:
145
+ x = r
146
+
147
+ # T x B x C -> B x T x C
148
+ x = x.transpose(0, 1)
149
+
150
+ return x, layer_results
151
+
152
+
153
+ class TransformerSentenceEncoderLayer(nn.Module):
154
+ def __init__(
155
+ self,
156
+ embedding_dim: float = 768,
157
+ ffn_embedding_dim: float = 3072,
158
+ num_attention_heads: float = 8,
159
+ dropout: float = 0.1,
160
+ attention_dropout: float = 0.1,
161
+ activation_dropout: float = 0.1,
162
+ activation_fn: str = "relu",
163
+ layer_norm_first: bool = False,
164
+ deep_norm: bool = False,
165
+ has_relative_attention_bias: bool = False,
166
+ num_buckets: int = 0,
167
+ max_distance: int = 0,
168
+ rescale_init: bool = False,
169
+ gru_rel_pos: bool = False,
170
+ encoder_layers: int = 0,
171
+ ) -> None:
172
+
173
+ super().__init__()
174
+ self.embedding_dim = embedding_dim
175
+ self.dropout = dropout
176
+ self.activation_dropout = activation_dropout
177
+
178
+ self.activation_name = activation_fn
179
+ self.activation_fn = get_activation_fn(activation_fn)
180
+ self.self_attn = MultiheadAttention(
181
+ self.embedding_dim,
182
+ num_attention_heads,
183
+ dropout=attention_dropout,
184
+ self_attention=True,
185
+ has_relative_attention_bias=has_relative_attention_bias,
186
+ num_buckets=num_buckets,
187
+ max_distance=max_distance,
188
+ rescale_init=rescale_init,
189
+ gru_rel_pos=gru_rel_pos,
190
+ )
191
+
192
+ self.dropout1 = nn.Dropout(dropout)
193
+ self.dropout2 = nn.Dropout(self.activation_dropout)
194
+ self.dropout3 = nn.Dropout(dropout)
195
+
196
+ self.layer_norm_first = layer_norm_first
197
+
198
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
199
+
200
+ if self.activation_name == "glu":
201
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
202
+ else:
203
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
204
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
205
+
206
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
207
+
208
+ self.deep_norm = deep_norm
209
+ if self.deep_norm:
210
+ self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
211
+ else:
212
+ self.deep_norm_alpha = 1
213
+
214
+ def forward(
215
+ self,
216
+ x: torch.Tensor,
217
+ self_attn_mask: torch.Tensor = None,
218
+ self_attn_padding_mask: torch.Tensor = None,
219
+ need_weights: bool = False,
220
+ pos_bias=None
221
+ ):
222
+ residual = x
223
+
224
+ if self.layer_norm_first:
225
+ x = self.self_attn_layer_norm(x)
226
+ x, attn, pos_bias = self.self_attn(
227
+ query=x,
228
+ key=x,
229
+ value=x,
230
+ key_padding_mask=self_attn_padding_mask,
231
+ need_weights=False,
232
+ attn_mask=self_attn_mask,
233
+ position_bias=pos_bias
234
+ )
235
+ x = self.dropout1(x)
236
+ x = residual + x
237
+
238
+ residual = x
239
+ x = self.final_layer_norm(x)
240
+ if self.activation_name == "glu":
241
+ x = self.fc1(x)
242
+ else:
243
+ x = self.activation_fn(self.fc1(x))
244
+ x = self.dropout2(x)
245
+ x = self.fc2(x)
246
+ x = self.dropout3(x)
247
+ x = residual + x
248
+ else:
249
+ x, attn, pos_bias = self.self_attn(
250
+ query=x,
251
+ key=x,
252
+ value=x,
253
+ key_padding_mask=self_attn_padding_mask,
254
+ need_weights=need_weights,
255
+ attn_mask=self_attn_mask,
256
+ position_bias=pos_bias
257
+ )
258
+
259
+ x = self.dropout1(x)
260
+ x = residual * self.deep_norm_alpha + x
261
+
262
+ x = self.self_attn_layer_norm(x)
263
+
264
+ residual = x
265
+ if self.activation_name == "glu":
266
+ x = self.fc1(x)
267
+ else:
268
+ x = self.activation_fn(self.fc1(x))
269
+ x = self.dropout2(x)
270
+ x = self.fc2(x)
271
+ x = self.dropout3(x)
272
+ x = residual * self.deep_norm_alpha + x
273
+ x = self.final_layer_norm(x)
274
+
275
+ return x, attn, pos_bias
276
+
277
+
278
+ class MultiheadAttention(nn.Module):
279
+ """Multi-headed attention.
280
+
281
+ See "Attention Is All You Need" for more details.
282
+ """
283
+
284
+ def __init__(
285
+ self,
286
+ embed_dim,
287
+ num_heads,
288
+ kdim=None,
289
+ vdim=None,
290
+ dropout=0.0,
291
+ bias=True,
292
+ add_bias_kv=False,
293
+ add_zero_attn=False,
294
+ self_attention=False,
295
+ encoder_decoder_attention=False,
296
+ q_noise=0.0,
297
+ qn_block_size=8,
298
+ has_relative_attention_bias=False,
299
+ num_buckets=32,
300
+ max_distance=128,
301
+ gru_rel_pos=False,
302
+ rescale_init=False,
303
+ ):
304
+ super().__init__()
305
+ self.embed_dim = embed_dim
306
+ self.kdim = kdim if kdim is not None else embed_dim
307
+ self.vdim = vdim if vdim is not None else embed_dim
308
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
309
+
310
+ self.num_heads = num_heads
311
+ self.dropout_module = nn.Dropout(dropout)
312
+
313
+ self.has_relative_attention_bias = has_relative_attention_bias
314
+ self.num_buckets = num_buckets
315
+ self.max_distance = max_distance
316
+ if self.has_relative_attention_bias:
317
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
318
+
319
+ self.head_dim = embed_dim // num_heads
320
+ self.q_head_dim = self.head_dim
321
+ self.k_head_dim = self.head_dim
322
+ assert (
323
+ self.head_dim * num_heads == self.embed_dim
324
+ ), "embed_dim must be divisible by num_heads"
325
+ self.scaling = self.head_dim ** -0.5
326
+
327
+ self.self_attention = self_attention
328
+ self.encoder_decoder_attention = encoder_decoder_attention
329
+
330
+ assert not self.self_attention or self.qkv_same_dim, (
331
+ "Self-attention requires query, key and " "value to be of the same size"
332
+ )
333
+
334
+ k_bias = True
335
+ if rescale_init:
336
+ k_bias = False
337
+
338
+ k_embed_dim = embed_dim
339
+ q_embed_dim = embed_dim
340
+
341
+ self.k_proj = quant_noise(
342
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
343
+ )
344
+ self.v_proj = quant_noise(
345
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
346
+ )
347
+ self.q_proj = quant_noise(
348
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
349
+ )
350
+
351
+ self.out_proj = quant_noise(
352
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
353
+ )
354
+
355
+ if add_bias_kv:
356
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
357
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
358
+ else:
359
+ self.bias_k = self.bias_v = None
360
+
361
+ self.add_zero_attn = add_zero_attn
362
+
363
+ self.gru_rel_pos = gru_rel_pos
364
+ if self.gru_rel_pos:
365
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
366
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
367
+
368
+ self.reset_parameters()
369
+
370
+ def reset_parameters(self):
371
+ if self.qkv_same_dim:
372
+ # Empirically observed the convergence to be much better with
373
+ # the scaled initialization
374
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
375
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
376
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
377
+ else:
378
+ nn.init.xavier_uniform_(self.k_proj.weight)
379
+ nn.init.xavier_uniform_(self.v_proj.weight)
380
+ nn.init.xavier_uniform_(self.q_proj.weight)
381
+
382
+ nn.init.xavier_uniform_(self.out_proj.weight)
383
+ if self.out_proj.bias is not None:
384
+ nn.init.constant_(self.out_proj.bias, 0.0)
385
+ if self.bias_k is not None:
386
+ nn.init.xavier_normal_(self.bias_k)
387
+ if self.bias_v is not None:
388
+ nn.init.xavier_normal_(self.bias_v)
389
+ if self.has_relative_attention_bias:
390
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
391
+
392
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
393
+ num_buckets = self.num_buckets
394
+ max_distance = self.max_distance
395
+ relative_buckets = 0
396
+
397
+ if bidirectional:
398
+ num_buckets = num_buckets // 2
399
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
400
+ relative_positions = torch.abs(relative_positions)
401
+ else:
402
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
403
+
404
+ max_exact = num_buckets // 2
405
+ is_small = relative_positions < max_exact
406
+
407
+ relative_postion_if_large = max_exact + (
408
+ torch.log(relative_positions.float() / max_exact)
409
+ / math.log(max_distance / max_exact)
410
+ * (num_buckets - max_exact)
411
+ ).to(torch.long)
412
+ relative_postion_if_large = torch.min(
413
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
414
+ )
415
+
416
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
417
+ return relative_buckets
418
+
419
+ def compute_bias(self, query_length, key_length):
420
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
421
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
422
+ relative_position = memory_position - context_position
423
+ relative_position_bucket = self._relative_positions_bucket(
424
+ relative_position,
425
+ bidirectional=True
426
+ )
427
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
428
+ values = self.relative_attention_bias(relative_position_bucket)
429
+ values = values.permute([2, 0, 1])
430
+ return values
431
+
432
+ def forward(
433
+ self,
434
+ query,
435
+ key: Optional[Tensor],
436
+ value: Optional[Tensor],
437
+ key_padding_mask: Optional[Tensor] = None,
438
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
439
+ need_weights: bool = True,
440
+ static_kv: bool = False,
441
+ attn_mask: Optional[Tensor] = None,
442
+ before_softmax: bool = False,
443
+ need_head_weights: bool = False,
444
+ position_bias: Optional[Tensor] = None
445
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
446
+ """Input shape: Time x Batch x Channel
447
+
448
+ Args:
449
+ key_padding_mask (ByteTensor, optional): mask to exclude
450
+ keys that are pads, of shape `(batch, src_len)`, where
451
+ padding elements are indicated by 1s.
452
+ need_weights (bool, optional): return the attention weights,
453
+ averaged over heads (default: False).
454
+ attn_mask (ByteTensor, optional): typically used to
455
+ implement causal attention, where the mask prevents the
456
+ attention from looking forward in time (default: None).
457
+ before_softmax (bool, optional): return the raw attention
458
+ weights and values before the attention softmax.
459
+ need_head_weights (bool, optional): return the attention
460
+ weights for each head. Implies *need_weights*. Default:
461
+ return the average attention weights over all heads.
462
+ """
463
+ if need_head_weights:
464
+ need_weights = True
465
+
466
+ is_tpu = query.device.type == "xla"
467
+
468
+ tgt_len, bsz, embed_dim = query.size()
469
+ src_len = tgt_len
470
+ assert embed_dim == self.embed_dim
471
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
472
+ if key is not None:
473
+ src_len, key_bsz, _ = key.size()
474
+ if not torch.jit.is_scripting():
475
+ assert key_bsz == bsz
476
+ assert value is not None
477
+ assert (src_len, bsz) == value.shape[:2]
478
+
479
+ if self.has_relative_attention_bias and position_bias is None:
480
+ position_bias = self.compute_bias(tgt_len, src_len)
481
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
482
+
483
+ if incremental_state is not None:
484
+ saved_state = self._get_input_buffer(incremental_state)
485
+ if saved_state is not None and "prev_key" in saved_state:
486
+ # previous time steps are cached - no need to recompute
487
+ # key and value if they are static
488
+ if static_kv:
489
+ assert self.encoder_decoder_attention and not self.self_attention
490
+ key = value = None
491
+ else:
492
+ saved_state = None
493
+
494
+ if self.self_attention:
495
+ q = self.q_proj(query)
496
+ k = self.k_proj(query)
497
+ v = self.v_proj(query)
498
+ elif self.encoder_decoder_attention:
499
+ # encoder-decoder attention
500
+ q = self.q_proj(query)
501
+ if key is None:
502
+ assert value is None
503
+ k = v = None
504
+ else:
505
+ k = self.k_proj(key)
506
+ v = self.v_proj(key)
507
+
508
+ else:
509
+ assert key is not None and value is not None
510
+ q = self.q_proj(query)
511
+ k = self.k_proj(key)
512
+ v = self.v_proj(value)
513
+ q *= self.scaling
514
+ alpha = 32
515
+ q *= 1 / alpha
516
+
517
+ if self.bias_k is not None:
518
+ assert self.bias_v is not None
519
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
520
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
521
+ if attn_mask is not None:
522
+ attn_mask = torch.cat(
523
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
524
+ )
525
+ if key_padding_mask is not None:
526
+ key_padding_mask = torch.cat(
527
+ [
528
+ key_padding_mask,
529
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
530
+ ],
531
+ dim=1,
532
+ )
533
+
534
+ q = (
535
+ q.contiguous()
536
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
537
+ .transpose(0, 1)
538
+ )
539
+ if k is not None:
540
+ k = (
541
+ k.contiguous()
542
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
543
+ .transpose(0, 1)
544
+ )
545
+ if v is not None:
546
+ v = (
547
+ v.contiguous()
548
+ .view(-1, bsz * self.num_heads, self.head_dim)
549
+ .transpose(0, 1)
550
+ )
551
+
552
+ if saved_state is not None:
553
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
554
+ if "prev_key" in saved_state:
555
+ _prev_key = saved_state["prev_key"]
556
+ assert _prev_key is not None
557
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
558
+ if static_kv:
559
+ k = prev_key
560
+ else:
561
+ assert k is not None
562
+ k = torch.cat([prev_key, k], dim=1)
563
+ src_len = k.size(1)
564
+ if "prev_value" in saved_state:
565
+ _prev_value = saved_state["prev_value"]
566
+ assert _prev_value is not None
567
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
568
+ if static_kv:
569
+ v = prev_value
570
+ else:
571
+ assert v is not None
572
+ v = torch.cat([prev_value, v], dim=1)
573
+ prev_key_padding_mask: Optional[Tensor] = None
574
+ if "prev_key_padding_mask" in saved_state:
575
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
576
+ assert k is not None and v is not None
577
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
578
+ key_padding_mask=key_padding_mask,
579
+ prev_key_padding_mask=prev_key_padding_mask,
580
+ batch_size=bsz,
581
+ src_len=k.size(1),
582
+ static_kv=static_kv,
583
+ )
584
+
585
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
586
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
587
+ saved_state["prev_key_padding_mask"] = key_padding_mask
588
+ # In this branch incremental_state is never None
589
+ assert incremental_state is not None
590
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
591
+ assert k is not None
592
+ assert k.size(1) == src_len
593
+
594
+ # This is part of a workaround to get around fork/join parallelism
595
+ # not supporting Optional types.
596
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
597
+ key_padding_mask = None
598
+
599
+ if key_padding_mask is not None:
600
+ assert key_padding_mask.size(0) == bsz
601
+ assert key_padding_mask.size(1) == src_len
602
+
603
+ if self.add_zero_attn:
604
+ assert v is not None
605
+ src_len += 1
606
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
607
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
608
+ if attn_mask is not None:
609
+ attn_mask = torch.cat(
610
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
611
+ )
612
+ if key_padding_mask is not None:
613
+ key_padding_mask = torch.cat(
614
+ [
615
+ key_padding_mask,
616
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
617
+ key_padding_mask
618
+ ),
619
+ ],
620
+ dim=1,
621
+ )
622
+
623
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
624
+ attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
625
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
626
+
627
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
628
+
629
+ if attn_mask is not None:
630
+ attn_mask = attn_mask.unsqueeze(0)
631
+ attn_weights += attn_mask
632
+
633
+ if key_padding_mask is not None:
634
+ # don't attend to padding symbols
635
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
636
+ if not is_tpu:
637
+ attn_weights = attn_weights.masked_fill(
638
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
639
+ float("-inf"),
640
+ )
641
+ else:
642
+ attn_weights = attn_weights.transpose(0, 2)
643
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
644
+ attn_weights = attn_weights.transpose(0, 2)
645
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
646
+
647
+ if before_softmax:
648
+ return attn_weights, v, position_bias
649
+
650
+ if position_bias is not None:
651
+ attn_mask_rel_pos = position_bias
652
+ if self.gru_rel_pos == 1:
653
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
654
+ _B, _H, _L, __ = query_layer.size()
655
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
656
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
657
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
658
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
659
+
660
+ attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
661
+
662
+ attn_weights = attn_weights + attn_mask_rel_pos
663
+
664
+ attn_weights_float = F.softmax(
665
+ attn_weights, dim=-1
666
+ )
667
+ attn_weights = attn_weights_float.type_as(attn_weights)
668
+ attn_probs = self.dropout_module(attn_weights)
669
+
670
+ assert v is not None
671
+ attn = torch.bmm(attn_probs, v)
672
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
673
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
674
+ attn = self.out_proj(attn)
675
+ attn_weights: Optional[Tensor] = None
676
+ if need_weights:
677
+ attn_weights = attn_weights_float.view(
678
+ bsz, self.num_heads, tgt_len, src_len
679
+ ).transpose(1, 0)
680
+ if not need_head_weights:
681
+ # average attention weights over heads
682
+ attn_weights = attn_weights.mean(dim=0)
683
+
684
+ return attn, attn_weights, position_bias
685
+
686
+ @staticmethod
687
+ def _append_prev_key_padding_mask(
688
+ key_padding_mask: Optional[Tensor],
689
+ prev_key_padding_mask: Optional[Tensor],
690
+ batch_size: int,
691
+ src_len: int,
692
+ static_kv: bool,
693
+ ) -> Optional[Tensor]:
694
+ # saved key padding masks have shape (bsz, seq_len)
695
+ if prev_key_padding_mask is not None and static_kv:
696
+ new_key_padding_mask = prev_key_padding_mask
697
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
698
+ new_key_padding_mask = torch.cat(
699
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
700
+ )
701
+ # During incremental decoding, as the padding token enters and
702
+ # leaves the frame, there will be a time when prev or current
703
+ # is None
704
+ elif prev_key_padding_mask is not None:
705
+ if src_len > prev_key_padding_mask.size(1):
706
+ filler = torch.zeros(
707
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
708
+ device=prev_key_padding_mask.device,
709
+ )
710
+ new_key_padding_mask = torch.cat(
711
+ [prev_key_padding_mask.float(), filler.float()], dim=1
712
+ )
713
+ else:
714
+ new_key_padding_mask = prev_key_padding_mask.float()
715
+ elif key_padding_mask is not None:
716
+ if src_len > key_padding_mask.size(1):
717
+ filler = torch.zeros(
718
+ (batch_size, src_len - key_padding_mask.size(1)),
719
+ device=key_padding_mask.device,
720
+ )
721
+ new_key_padding_mask = torch.cat(
722
+ [filler.float(), key_padding_mask.float()], dim=1
723
+ )
724
+ else:
725
+ new_key_padding_mask = key_padding_mask.float()
726
+ else:
727
+ new_key_padding_mask = prev_key_padding_mask
728
+ return new_key_padding_mask
729
+
730
+ def _get_input_buffer(
731
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
732
+ ) -> Dict[str, Optional[Tensor]]:
733
+ result = self.get_incremental_state(incremental_state, "attn_state")
734
+ if result is not None:
735
+ return result
736
+ else:
737
+ empty_result: Dict[str, Optional[Tensor]] = {}
738
+ return empty_result
739
+
740
+ def _set_input_buffer(
741
+ self,
742
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
743
+ buffer: Dict[str, Optional[Tensor]],
744
+ ):
745
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
746
+
747
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
748
+ return attn_weights
749
+
750
+
751
+ def init_bert_params(module):
752
+ """
753
+ Initialize the weights specific to the BERT Model.
754
+ This overrides the default initializations depending on the specified arguments.
755
+ 1. If normal_init_linear_weights is set then weights of linear
756
+ layer will be initialized using the normal distribution and
757
+ bias will be set to the specified value.
758
+ 2. If normal_init_embed_weights is set then weights of embedding
759
+ layer will be initialized using the normal distribution.
760
+ 3. If normal_init_proj_weights is set then weights of
761
+ in_project_weight for MultiHeadAttention are initialized using
762
+ the normal distribution (to be validated).
763
+ """
764
+
765
+ def normal_(data):
766
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
767
+ # so that the RNG is consistent with and without FSDP
768
+ data.copy_(
769
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
770
+ )
771
+
772
+ if isinstance(module, nn.Linear):
773
+ normal_(module.weight.data)
774
+ if module.bias is not None:
775
+ module.bias.data.zero_()
776
+ if isinstance(module, nn.Embedding):
777
+ normal_(module.weight.data)
778
+ if module.padding_idx is not None:
779
+ module.weight.data[module.padding_idx].zero_()
780
+ if isinstance(module, MultiheadAttention):
781
+ normal_(module.q_proj.weight.data)
782
+ normal_(module.k_proj.weight.data)
783
+ normal_(module.v_proj.weight.data)
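The MultiheadAttention above adds a T5-style bucketed relative position bias: _relative_positions_bucket maps signed distances into a fixed number of buckets (exact buckets for short offsets, log-spaced ones up to max_distance), and compute_bias looks the buckets up in a small embedding that is shared across layers. A standalone sketch of the bidirectional bucketing, for illustration only (not part of the upload); the clamp is added only to avoid log(0) on values that the where() discards anyway:

import math
import torch

def relative_position_bucket(relative_positions, num_buckets=32, max_distance=128):
    # Mirrors MultiheadAttention._relative_positions_bucket (bidirectional branch).
    num_buckets = num_buckets // 2
    buckets = (relative_positions > 0).long() * num_buckets        # sign goes into the upper half
    rel = relative_positions.abs()
    max_exact = num_buckets // 2
    is_small = rel < max_exact                                      # short offsets keep exact buckets
    large = max_exact + (
        torch.log(rel.float().clamp(min=1) / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    large = torch.minimum(large, torch.full_like(large, num_buckets - 1))
    return buckets + torch.where(is_small, rel, large)

pos = torch.arange(6)
rel = pos[None, :] - pos[:, None]      # memory_position - context_position, as in compute_bias
print(relative_position_bucket(rel))   # 6x6 matrix of bucket ids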
slam_llm/models/BEATs/modules.py ADDED
@@ -0,0 +1,219 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ import torch
13
+ from torch import Tensor, nn
14
+ import torch.nn.functional as F
15
+
16
+
17
+ class GradMultiply(torch.autograd.Function):
18
+ @staticmethod
19
+ def forward(ctx, x, scale):
20
+ ctx.scale = scale
21
+ res = x.new(x)
22
+ return res
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad):
26
+ return grad * ctx.scale, None
27
+
28
+
29
+ class SamePad(nn.Module):
30
+ def __init__(self, kernel_size, causal=False):
31
+ super().__init__()
32
+ if causal:
33
+ self.remove = kernel_size - 1
34
+ else:
35
+ self.remove = 1 if kernel_size % 2 == 0 else 0
36
+
37
+ def forward(self, x):
38
+ if self.remove > 0:
39
+ x = x[:, :, : -self.remove]
40
+ return x
41
+
42
+
43
+ class Swish(nn.Module):
44
+ def __init__(self):
45
+ super(Swish, self).__init__()
46
+ self.act = torch.nn.Sigmoid()
47
+
48
+ def forward(self, x):
49
+ return x * self.act(x)
50
+
51
+
52
+ class GLU_Linear(nn.Module):
53
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
54
+ super(GLU_Linear, self).__init__()
55
+
56
+ self.glu_type = glu_type
57
+ self.output_dim = output_dim
58
+
59
+ if glu_type == "sigmoid":
60
+ self.glu_act = torch.nn.Sigmoid()
61
+ elif glu_type == "swish":
62
+ self.glu_act = Swish()
63
+ elif glu_type == "relu":
64
+ self.glu_act = torch.nn.ReLU()
65
+ elif glu_type == "gelu":
66
+ self.glu_act = torch.nn.GELU()
67
+
68
+ if bias_in_glu:
69
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
70
+ else:
71
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
72
+
73
+ def forward(self, x):
74
+ # to stay consistent with GLU_Linear, we assume the channel (dim) axis is always the last dimension of the input tensor, so the dimensions would need to be switched first for the 1D-Conv case
75
+ x = self.linear(x)
76
+
77
+ if self.glu_type == "bilinear":
78
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
79
+ else:
80
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
81
+
82
+ return x
83
+
84
+
85
+ def gelu_accurate(x):
86
+ if not hasattr(gelu_accurate, "_a"):
87
+ gelu_accurate._a = math.sqrt(2 / math.pi)
88
+ return (
89
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
90
+ )
91
+
92
+
93
+ def gelu(x: torch.Tensor) -> torch.Tensor:
94
+ return torch.nn.functional.gelu(x.float()).type_as(x)
95
+
96
+
97
+ def get_activation_fn(activation: str):
98
+ """Returns the activation function corresponding to `activation`"""
99
+
100
+ if activation == "relu":
101
+ return F.relu
102
+ elif activation == "gelu":
103
+ return gelu
104
+ elif activation == "gelu_fast":
105
+ warnings.warn(
106
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
107
+ )
108
+ return gelu_accurate
109
+ elif activation == "gelu_accurate":
110
+ return gelu_accurate
111
+ elif activation == "tanh":
112
+ return torch.tanh
113
+ elif activation == "linear":
114
+ return lambda x: x
115
+ elif activation == "glu":
116
+ return lambda x: x
117
+ else:
118
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
119
+
120
+
121
+ def quant_noise(module, p, block_size):
122
+ """
123
+ Wraps modules and applies quantization noise to the weights for
124
+ subsequent quantization with Iterative Product Quantization as
125
+ described in "Training with Quantization Noise for Extreme Model Compression"
126
+
127
+ Args:
128
+ - module: nn.Module
129
+ - p: amount of Quantization Noise
130
+ - block_size: size of the blocks for subsequent quantization with iPQ
131
+
132
+ Remarks:
133
+ - Module weights must have the right sizes wrt the block size
134
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
135
+ - For more detail on how to quantize by blocks with convolutional weights,
136
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
137
+ - We implement the simplest form of noise here as stated in the paper
138
+ which consists in randomly dropping blocks
139
+ """
140
+
141
+ # if no quantization noise, don't register hook
142
+ if p <= 0:
143
+ return module
144
+
145
+ # supported modules
146
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
147
+
148
+ # test whether module.weight has the right sizes wrt block_size
149
+ is_conv = module.weight.ndim == 4
150
+
151
+ # 2D matrix
152
+ if not is_conv:
153
+ assert (
154
+ module.weight.size(1) % block_size == 0
155
+ ), "Input features must be a multiple of block sizes"
156
+
157
+ # 4D matrix
158
+ else:
159
+ # 1x1 convolutions
160
+ if module.kernel_size == (1, 1):
161
+ assert (
162
+ module.in_channels % block_size == 0
163
+ ), "Input channels must be a multiple of block sizes"
164
+ # regular convolutions
165
+ else:
166
+ k = module.kernel_size[0] * module.kernel_size[1]
167
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
168
+
169
+ def _forward_pre_hook(mod, input):
170
+ # no noise for evaluation
171
+ if mod.training:
172
+ if not is_conv:
173
+ # gather weight and sizes
174
+ weight = mod.weight
175
+ in_features = weight.size(1)
176
+ out_features = weight.size(0)
177
+
178
+ # split weight matrix into blocks and randomly drop selected blocks
179
+ mask = torch.zeros(
180
+ in_features // block_size * out_features, device=weight.device
181
+ )
182
+ mask.bernoulli_(p)
183
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
184
+
185
+ else:
186
+ # gather weight and sizes
187
+ weight = mod.weight
188
+ in_channels = mod.in_channels
189
+ out_channels = mod.out_channels
190
+
191
+ # split weight matrix into blocks and randomly drop selected blocks
192
+ if mod.kernel_size == (1, 1):
193
+ mask = torch.zeros(
194
+ int(in_channels // block_size * out_channels),
195
+ device=weight.device,
196
+ )
197
+ mask.bernoulli_(p)
198
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
199
+ else:
200
+ mask = torch.zeros(
201
+ weight.size(0), weight.size(1), device=weight.device
202
+ )
203
+ mask.bernoulli_(p)
204
+ mask = (
205
+ mask.unsqueeze(2)
206
+ .unsqueeze(3)
207
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
208
+ )
209
+
210
+ # scale weights and apply mask
211
+ mask = mask.to(
212
+ torch.bool
213
+ ) # x.bool() is not currently supported in TorchScript
214
+ s = 1 / (1 - p)
215
+ mod.weight.data = s * weight.masked_fill(mask, 0)
216
+
217
+ module.register_forward_pre_hook(_forward_pre_hook)
218
+ return module
219
+
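quant_noise above implements the scheme from "Training with Quantization Noise for Extreme Model Compression": a forward pre-hook that, only during training, zeroes random block_size-wide blocks of the weight matrix and rescales the survivors by 1/(1-p). A small hedged sketch using the helper as uploaded in this commit:

import torch
from torch import nn
from slam_llm.models.BEATs.modules import quant_noise  # path as uploaded here

layer = quant_noise(nn.Linear(64, 64), p=0.2, block_size=8)   # in_features must divide by block_size

layer.train()
_ = layer(torch.randn(2, 64))                      # hook drops ~20% of the 8-wide blocks, rescales the rest
print((layer.weight == 0).float().mean().item())   # roughly 0.2 after a training-mode forward

layer.eval()
_ = layer(torch.randn(2, 64))                      # hook is a no-op at evaluation time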
slam_llm/models/BEATs/quantizer.py ADDED
@@ -0,0 +1,215 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on VQGAN code bases
7
+ # https://github.com/CompVis/taming-transformers
8
+ # --------------------------------------------------------'
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.distributed as distributed
14
+
15
+ try:
16
+ from einops import rearrange, repeat
17
+ except ImportError:
18
+ pass
19
+
20
+
21
+ def l2norm(t):
22
+ return F.normalize(t, p=2, dim=-1)
23
+
24
+
25
+ def ema_inplace(moving_avg, new, decay):
26
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
27
+
28
+
29
+ def sample_vectors(samples, num):
30
+ num_samples, device = samples.shape[0], samples.device
31
+
32
+ if num_samples >= num:
33
+ indices = torch.randperm(num_samples, device=device)[:num]
34
+ else:
35
+ indices = torch.randint(0, num_samples, (num,), device=device)
36
+
37
+ return samples[indices]
38
+
39
+
40
+ def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
41
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
42
+
43
+ means = sample_vectors(samples, num_clusters)
44
+
45
+ for _ in range(num_iters):
46
+ if use_cosine_sim:
47
+ dists = samples @ means.t()
48
+ else:
49
+ diffs = rearrange(samples, 'n d -> n () d') \
50
+ - rearrange(means, 'c d -> () c d')
51
+ dists = -(diffs ** 2).sum(dim=-1)
52
+
53
+ buckets = dists.max(dim=-1).indices
54
+ bins = torch.bincount(buckets, minlength=num_clusters)
55
+ zero_mask = bins == 0
56
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
57
+
58
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
59
+ new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
60
+ new_means = new_means / bins_min_clamped[..., None]
61
+
62
+ if use_cosine_sim:
63
+ new_means = l2norm(new_means)
64
+
65
+ means = torch.where(zero_mask[..., None], means, new_means)
66
+
67
+ return means, bins
68
+
69
+
70
+ class EmbeddingEMA(nn.Module):
71
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
72
+ super().__init__()
73
+ self.num_tokens = num_tokens
74
+ self.codebook_dim = codebook_dim
75
+ self.decay = decay
76
+ self.eps = eps
77
+ if codebook_init_path == '':
78
+ if not kmeans_init:
79
+ weight = torch.randn(num_tokens, codebook_dim)
80
+ weight = l2norm(weight)
81
+ else:
82
+ weight = torch.zeros(num_tokens, codebook_dim)
83
+ self.register_buffer('initted', torch.Tensor([not kmeans_init]))
84
+ else:
85
+ print(f"load init codebook weight from {codebook_init_path}")
86
+ codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
87
+ weight = codebook_ckpt_weight.clone()
88
+ self.register_buffer('initted', torch.Tensor([True]))
89
+
90
+ self.weight = nn.Parameter(weight, requires_grad=False)
91
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
92
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
93
+ # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
94
+ self.update = True
95
+
96
+ @torch.jit.ignore
97
+ def init_embed_(self, data):
98
+ if self.initted:
99
+ return
100
+ print("Performing kmeans init for codebook")
101
+ embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
102
+ self.weight.data.copy_(embed)
103
+ self.cluster_size.data.copy_(cluster_size)
104
+ self.initted.data.copy_(torch.Tensor([True]))
105
+
106
+ def forward(self, embed_id):
107
+ return F.embedding(embed_id, self.weight)
108
+
109
+ def cluster_size_ema_update(self, new_cluster_size):
110
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
111
+
112
+ def embed_avg_ema_update(self, new_embed_avg):
113
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
114
+
115
+ def weight_update(self, num_tokens):
116
+ n = self.cluster_size.sum()
117
+ smoothed_cluster_size = (
118
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
119
+ )
120
+ # normalize embedding average with smoothed cluster size
121
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
122
+ # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
123
+ self.weight.data.copy_(embed_normalized)
124
+
125
+
126
+ def norm_ema_inplace(moving_avg, new, decay):
127
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
128
+ moving_avg.data.copy_(l2norm(moving_avg.data))
129
+
130
+
131
+ class NormEMAVectorQuantizer(nn.Module):
132
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
133
+ statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
134
+ super().__init__()
135
+ self.codebook_dim = embedding_dim
136
+ self.num_tokens = n_embed
137
+ self.beta = beta
138
+ self.decay = decay
139
+
140
+ # learnable = True if orthogonal_reg_weight > 0 else False
141
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
142
+
143
+ self.statistic_code_usage = statistic_code_usage
144
+ if statistic_code_usage:
145
+ self.register_buffer('cluster_size', torch.zeros(n_embed))
146
+ if distributed.is_available() and distributed.is_initialized():
147
+ print("ddp is enabled, so use all_reduce to sync the statistic_code_usage across GPUs!")
148
+ self.all_reduce_fn = distributed.all_reduce
149
+ else:
150
+ self.all_reduce_fn = nn.Identity()
151
+
152
+ def reset_cluster_size(self, device):
153
+ if self.statistic_code_usage:
154
+ self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
155
+ self.cluster_size = self.cluster_size.to(device)
156
+
157
+ def forward(self, z):
158
+ # reshape z -> (batch, height, width, channel) and flatten
159
+ # z, 'b c h w -> b h w c'
160
+ # z = rearrange(z, 'b c h w -> b h w c')
161
+ # z = z.transpose(1, 2)
162
+ z = l2norm(z)
163
+ z_flattened = z.reshape(-1, self.codebook_dim)
164
+
165
+ self.embedding.init_embed_(z_flattened)
166
+
167
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
168
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
169
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
170
+
171
+ encoding_indices = torch.argmin(d, dim=1)
172
+
173
+ z_q = self.embedding(encoding_indices).view(z.shape)
174
+
175
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
176
+
177
+ if not self.training:
178
+ with torch.no_grad():
179
+ cluster_size = encodings.sum(0)
180
+ self.all_reduce_fn(cluster_size)
181
+ ema_inplace(self.cluster_size, cluster_size, self.decay)
182
+
183
+ if self.training and self.embedding.update:
184
+ # EMA cluster size
185
+
186
+ bins = encodings.sum(0)
187
+ self.all_reduce_fn(bins)
188
+
189
+ # self.embedding.cluster_size_ema_update(bins)
190
+ ema_inplace(self.cluster_size, bins, self.decay)
191
+
192
+ zero_mask = (bins == 0)
193
+ bins = bins.masked_fill(zero_mask, 1.)
194
+
195
+ embed_sum = z_flattened.t() @ encodings
196
+ self.all_reduce_fn(embed_sum)
197
+
198
+ embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
199
+ embed_normalized = l2norm(embed_normalized)
200
+
201
+ embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
202
+ embed_normalized)
203
+ norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
204
+
205
+ # compute loss for embedding
206
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
207
+
208
+ # preserve gradients
209
+ z_q = z + (z_q - z).detach()
210
+
211
+ # reshape back to match original input shape
212
+ # z_q, 'b h w c -> b c h w'
213
+ # z_q = rearrange(z_q, 'b h w c -> b c h w')
214
+ # z_q = z_q.transpose(1, 2)
215
+ return z_q, loss, encoding_indices
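NormEMAVectorQuantizer L2-normalizes its input, snaps each vector to the nearest codebook entry, updates the codebook by exponential moving average while training, and returns the straight-through quantized features together with a commitment loss and the selected indices. A minimal sketch, assuming this commit's module path; kmeans_init=False keeps the example independent of einops:

import torch
from slam_llm.models.BEATs.quantizer import NormEMAVectorQuantizer  # path as uploaded here

vq = NormEMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=1.0, kmeans_init=False)

z = torch.randn(8, 100, 256)           # e.g. (batch, frames, dim) encoder output
z_q, commit_loss, codes = vq(z)

print(z_q.shape)      # torch.Size([8, 100, 256]) -- quantized features, same shape as the input
print(codes.shape)    # torch.Size([800])         -- one flat codebook index per input vector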
slam_llm/models/EAT/EAT.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+ import torchaudio
3
+ import random
4
+
5
+ def EAT_preprocess(source, norm_mean = -4.268, norm_std = 4.569, target_length = 1024, fixed_length = False, random_crop = False):
6
+ source = source - source.mean()
7
+ source = source.unsqueeze(dim=0)
8
+
9
+ source = torchaudio.compliance.kaldi.fbank(source, htk_compat=True, sample_frequency=16000, use_energy=False,
10
+ window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10).unsqueeze(dim=0)
11
+
12
+ n_frames = source.shape[1]
13
+ if not fixed_length:
14
+ target_length = n_frames
15
+ if target_length % 16 != 0:
16
+ target_length = n_frames + (16 - n_frames % 16)
17
+ diff = target_length - n_frames
18
+ if diff > 0:
19
+ m = torch.nn.ZeroPad2d((0, 0, 0, diff))
20
+ source = m(source)
21
+ elif diff < 0:
22
+ if random_crop:
23
+ start_index = random.randint(0, n_frames - target_length)
24
+ source = source[:,start_index: start_index+target_length, :]
25
+ else:
26
+ source = source[:,0:target_length, :]
27
+
28
+ # Normalize the mel spectrogram
29
+ source = (source - norm_mean) / (norm_std * 2)
30
+ source = source.squeeze()
31
+
32
+ return source
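EAT_preprocess turns a mono 16 kHz waveform into a mean-removed, 128-bin Kaldi fbank, normalizes it with the given statistics, and pads (or crops) the time axis so the frame count is a multiple of 16, or exactly target_length when fixed_length=True. A quick sketch, assuming this commit's module path:

import torch
from slam_llm.models.EAT.EAT import EAT_preprocess  # path as uploaded here

wav = torch.randn(16000 * 10)     # 10 s of 16 kHz mono audio (1-D tensor)
feats = EAT_preprocess(wav)       # 998 fbank frames padded up to the next multiple of 16
print(feats.shape)                # torch.Size([1008, 128])

feats_fixed = EAT_preprocess(wav, target_length=1024, fixed_length=True)
print(feats_fixed.shape)          # torch.Size([1024, 128])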
slam_llm/models/SpatialAST/SpatialAST.py ADDED
@@ -0,0 +1,122 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torchlibrosa.stft import STFT, LogmelFilterBank
5
+ from timm.models.layers import to_2tuple
6
+
7
+ from .vision_transformer import VisionTransformer as _VisionTransformer
8
+
9
+ def conv3x3(in_channels, out_channels, stride=1):
10
+ "3x3 convolution with padding"
11
+ return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
12
+
13
+ class PatchEmbed_new(nn.Module):
14
+ """ Flexible Image to Patch Embedding
15
+ """
16
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10):
17
+ super().__init__()
18
+ img_size = to_2tuple(img_size)
19
+ patch_size = to_2tuple(patch_size)
20
+ stride = to_2tuple(stride)
21
+
22
+ self.img_size = img_size
23
+ self.patch_size = patch_size
24
+ self.in_chans = in_chans
25
+
26
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches
27
+ _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
28
+ self.patch_hw = (h, w)
29
+ self.num_patches = h*w
30
+
31
+ def get_output_shape(self, img_size):
32
+ return self.proj(torch.randn(1, self.in_chans, img_size[0], img_size[1])).shape
33
+
34
+ def forward(self, x):
35
+ B, C, H, W = x.shape
36
+
37
+ x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12
38
+ x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212
39
+ x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768
40
+ return x
41
+
42
+ class BinauralEncoder(_VisionTransformer):
43
+ """ Spatial Audio Spectrogram Transformer designed for Sound Event Localization and Detection
44
+ --------------------------------------------------------
45
+ References:
46
+ Spatial-AST from BAT: https://github.com/zszheng147/Spatial-AST and https://arxiv.org/abs/2402.01591
47
+ --------------------------------------------------------
48
+ """
49
+ def __init__(self, num_cls_tokens=3, **kwargs):
50
+ super(BinauralEncoder, self).__init__(**kwargs)
51
+ img_size = (1024, 128) # 1024, 128
52
+ in_chans = 1
53
+ emb_dim = 768
54
+
55
+ del self.cls_token
56
+ self.num_cls_tokens = num_cls_tokens
57
+ self.cls_tokens = nn.Parameter(torch.zeros(1, num_cls_tokens, emb_dim))
58
+
59
+ self.patch_embed = PatchEmbed_new(
60
+ img_size=img_size, patch_size=(16, 16),
61
+ in_chans=in_chans, embed_dim=emb_dim, stride=16
62
+ ) # no overlap. stride=img_size=16
63
+
64
+ num_patches = self.patch_embed.num_patches
65
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False) # fixed sin-cos embedding
66
+
67
+ self.spectrogram_extractor = STFT(
68
+ n_fft=1024, hop_length=320, win_length=1024, window='hann',
69
+ center=True, pad_mode='reflect', freeze_parameters=True
70
+ )
71
+ self.logmel_extractor = LogmelFilterBank(
72
+ sr=32000, n_fft=1024, n_mels=128, fmin=50,
73
+ fmax=14000, ref=1.0, amin=1e-10, top_db=None, freeze_parameters=True
74
+ )
75
+
76
+ self.conv_downsample = nn.Sequential(
77
+ conv3x3(4, 1),
78
+ nn.BatchNorm2d(1),
79
+ nn.GELU(),
80
+ )
81
+
82
+ self.bn = nn.BatchNorm2d(2, affine=False)
83
+ del self.norm # remove the original norm
84
+
85
+ self.target_frame = 1024
86
+
87
+ def forward_features_mask(self, x):
88
+ B = x.shape[0] #bsz, 512, 768 (unmasked)
89
+
90
+ x = x + self.pos_embed[:, 1:, :]
91
+
92
+ cls_tokens = self.cls_tokens
93
+ cls_tokens = cls_tokens.expand(B, -1, -1)
94
+ x = torch.cat([cls_tokens, x], dim=1) # bsz, 512 + 2 + 10, 768
95
+ x = self.pos_drop(x)
96
+
97
+ for blk in self.blocks:
98
+ x = blk(x)
99
+
100
+ return x
101
+
102
+ @torch.no_grad()
103
+ def forward(self, waveforms):
104
+ B, C, T = waveforms.shape
105
+
106
+ waveforms = waveforms.reshape(B * C, T)
107
+ real, imag = self.spectrogram_extractor(waveforms)
108
+
109
+ log_mel = self.logmel_extractor(torch.sqrt(real**2 + imag**2)).reshape(B, C, -1, 128)
110
+ log_mel = self.bn(log_mel)
111
+
112
+ IPD = torch.atan2(imag[1::2], real[1::2]) - torch.atan2(imag[::2], real[::2])
113
+ x = torch.cat([log_mel, torch.matmul(torch.cat([torch.cos(IPD), torch.sin(IPD)], dim=1), self.logmel_extractor.melW)], dim=1)
114
+
115
+ if x.shape[2] < self.target_frame:
116
+ x = nn.functional.interpolate(x, (self.target_frame, x.shape[3]), mode="bicubic", align_corners=True)
117
+
118
+ x = self.conv_downsample(x)
119
+ x = self.patch_embed(x)
120
+ x = self.forward_features_mask(x)
121
+
122
+ return x
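BinauralEncoder consumes stereo 32 kHz waveforms: it builds per-channel log-mel spectrograms plus cos/sin inter-channel phase difference (IPD) maps projected onto the mel basis, stacks the four channels, downsamples them to one channel, and runs a frozen ViT over 16x16 patches with num_cls_tokens learnable tokens prepended. A hedged construction sketch (requires timm and torchlibrosa); the ViT hyperparameters below are assumptions in a ViT-Base-like configuration, not values taken from this repo:

import torch
from slam_llm.models.SpatialAST.SpatialAST import BinauralEncoder  # path as uploaded here

model = BinauralEncoder(
    num_cls_tokens=3, embed_dim=768, depth=12, num_heads=12,   # assumed ViT-Base-like settings
    mlp_ratio=4.0, qkv_bias=True,
).eval()

wav = torch.randn(2, 2, 32000 * 10)    # (batch, binaural channels, 10 s at 32 kHz)
tokens = model(wav)                    # forward is wrapped in torch.no_grad()
print(tokens.shape)                    # torch.Size([2, 515, 768]) -- 3 cls tokens + 512 patches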
slam_llm/models/SpatialAST/vision_transformer.py ADDED
@@ -0,0 +1,239 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from timm.models.layers import to_2tuple, DropPath, trunc_normal_
5
+
6
+
7
+ class HybridEmbed(nn.Module):
8
+ """ CNN Feature Map Embedding
9
+ Extract feature map from CNN, flatten, project to embedding dim.
10
+ """
11
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
12
+ super().__init__()
13
+ assert isinstance(backbone, nn.Module)
14
+ img_size = to_2tuple(img_size)
15
+ self.img_size = img_size
16
+ self.backbone = backbone
17
+ if feature_size is None:
18
+ with torch.no_grad():
19
+ # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
20
+ # map for all networks, the feature metadata has reliable channel and stride info, but using
21
+ # stride to calc feature dim requires info about padding of each stage that isn't captured.
22
+ training = backbone.training
23
+ if training:
24
+ backbone.eval()
25
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
26
+ feature_size = o.shape[-2:]
27
+ feature_dim = o.shape[1]
28
+ backbone.train(training)
29
+ else:
30
+ feature_size = to_2tuple(feature_size)
31
+ feature_dim = self.backbone.feature_info.channels()[-1]
32
+ self.num_patches = feature_size[0] * feature_size[1]
33
+ self.proj = nn.Linear(feature_dim, embed_dim)
34
+
35
+ def forward(self, x):
36
+ x = self.backbone(x)[-1]
37
+ x = x.flatten(2).transpose(1, 2)
38
+ x = self.proj(x)
39
+ return x
40
+
41
+ class Mlp(nn.Module):
42
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
43
+ super().__init__()
44
+ out_features = out_features or in_features
45
+ hidden_features = hidden_features or in_features
46
+ self.fc1 = nn.Linear(in_features, hidden_features)
47
+ self.act = act_layer()
48
+ self.fc2 = nn.Linear(hidden_features, out_features)
49
+ self.drop = nn.Dropout(drop)
50
+
51
+ def forward(self, x):
52
+ x = self.fc1(x)
53
+ x = self.act(x)
54
+ x = self.drop(x)
55
+ x = self.fc2(x)
56
+ x = self.drop(x)
57
+ return x
58
+
59
+ class Attention(nn.Module):
60
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
61
+ super().__init__()
62
+ self.num_heads = num_heads
63
+ head_dim = dim // num_heads
64
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
65
+ self.scale = qk_scale or head_dim ** -0.5
66
+
67
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
68
+ self.attn_drop = nn.Dropout(attn_drop)
69
+ self.proj = nn.Linear(dim, dim)
70
+ self.proj_drop = nn.Dropout(proj_drop)
71
+
72
+
73
+ def forward(self, x):
74
+ B, N, C = x.shape
75
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
76
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
77
+
78
+ attn = (q @ k.transpose(-2, -1)) * self.scale
79
+ attn = attn.softmax(dim=-1)
80
+ attn = self.attn_drop(attn)
81
+
82
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
83
+ x = self.proj(x)
84
+ x = self.proj_drop(x)
85
+ return x
86
+
87
+ class Block(nn.Module):
88
+
89
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
90
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
91
+ super().__init__()
92
+ self.norm1 = norm_layer(dim)
93
+ self.attn = Attention(
94
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
95
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
96
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
97
+ self.norm2 = norm_layer(dim)
98
+ mlp_hidden_dim = int(dim * mlp_ratio)
99
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
100
+
101
+ def forward(self, x):
102
+ x = x + self.drop_path(self.attn(self.norm1(x)))
103
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
104
+ return x
105
+
106
+ class PatchEmbed(nn.Module):
107
+ """ Image to Patch Embedding
108
+ """
109
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
110
+ super().__init__()
111
+ img_size = to_2tuple(img_size)
112
+ patch_size = to_2tuple(patch_size)
113
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
114
+ self.img_size = img_size
115
+ self.patch_size = patch_size
116
+ self.num_patches = num_patches
117
+
118
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
119
+
120
+ def forward(self, x):
121
+ B, C, H, W = x.shape
122
+ # FIXME look at relaxing size constraints
123
+ assert H == self.img_size[0] and W == self.img_size[1], \
124
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
125
+ x = self.proj(x).flatten(2).transpose(1, 2)
126
+ return x
127
+
128
+
129
+ class PatchEmbed_new(nn.Module):
130
+ """ Flexible Image to Patch Embedding
131
+ """
132
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10):
133
+ super().__init__()
134
+ img_size = to_2tuple(img_size)
135
+ patch_size = to_2tuple(patch_size)
136
+ stride = to_2tuple(stride)
137
+
138
+ self.img_size = img_size
139
+ self.patch_size = patch_size
140
+ self.in_chans = in_chans
141
+
142
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches
143
+ _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
144
+ self.patch_hw = (h, w)
145
+ self.num_patches = h*w
146
+
147
+ def get_output_shape(self, img_size):
148
+ return self.proj(torch.randn(1, self.in_chans, img_size[0], img_size[1])).shape
149
+
150
+ def forward(self, x):
151
+ B, C, H, W = x.shape
152
+
153
+ x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12
154
+ x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212
155
+ x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768
156
+ return x
157
+
158
+
159
+ class VisionTransformer(nn.Module):
160
+ """ Vision Transformer with support for patch or hybrid CNN input stage
161
+ """
162
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
163
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
164
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
165
+ super().__init__()
166
+ self.num_classes = num_classes
167
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
168
+
169
+ if hybrid_backbone is not None:
170
+ self.patch_embed = HybridEmbed(
171
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
172
+ else:
173
+ self.patch_embed = PatchEmbed(
174
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
175
+ num_patches = self.patch_embed.num_patches
176
+
177
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
178
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
179
+ self.pos_drop = nn.Dropout(p=drop_rate)
180
+
181
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
182
+ self.blocks = nn.ModuleList([
183
+ Block(
184
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
185
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
186
+ for i in range(depth)])
187
+
188
+ self.norm = norm_layer(embed_dim)
189
+
190
+ # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
191
+ #self.repr = nn.Linear(embed_dim, representation_size)
192
+ #self.repr_act = nn.Tanh()
193
+
194
+ # Classifier head
195
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
196
+
197
+ trunc_normal_(self.pos_embed, std=.02)
198
+ trunc_normal_(self.cls_token, std=.02)
199
+ self.apply(self._init_weights)
200
+
201
+ def _init_weights(self, m):
202
+ if isinstance(m, nn.Linear):
203
+ trunc_normal_(m.weight, std=.02)
204
+ if isinstance(m, nn.Linear) and m.bias is not None:
205
+ nn.init.constant_(m.bias, 0)
206
+ elif isinstance(m, nn.LayerNorm):
207
+ nn.init.constant_(m.bias, 0)
208
+ nn.init.constant_(m.weight, 1.0)
209
+
210
+ @torch.jit.ignore
211
+ def no_weight_decay(self):
212
+ return {'pos_embed', 'cls_token'}
213
+
214
+ def get_classifier(self):
215
+ return self.head
216
+
217
+ def reset_classifier(self, num_classes, global_pool=''):
218
+ self.num_classes = num_classes
219
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
220
+
221
+ def forward_features(self, x):
222
+ B = x.shape[0]
223
+ x = self.patch_embed(x)
224
+
225
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
226
+ x = torch.cat((cls_tokens, x), dim=1)
227
+ x = x + self.pos_embed
228
+ x = self.pos_drop(x)
229
+
230
+ for blk in self.blocks:
231
+ x = blk(x)
232
+
233
+ x = self.norm(x)
234
+ return x[:, 0]
235
+
236
+ def forward(self, x):
237
+ x = self.forward_features(x)
238
+ x = self.head(x)
239
+ return x
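A minimal, self-contained sketch of the overlapped patch embedding that PatchEmbed_new above implements. The input layout (1 channel, 1024 x 128 spectrogram) and the kernel/stride values follow the inline shape comments and defaults in the diff; they are assumptions for illustration, not values fixed by this file.

import torch
import torch.nn as nn

# Overlapped patch projection: 16x16 kernel with stride 10, mirroring PatchEmbed_new's defaults.
proj = nn.Conv2d(in_channels=1, out_channels=768, kernel_size=16, stride=10)

spec = torch.randn(2, 1, 1024, 128)          # a batch of log-mel spectrograms (assumed layout)
patches = proj(spec)                          # -> (2, 768, 101, 12)
tokens = patches.flatten(2).transpose(1, 2)   # -> (2, 1212, 768): (batch, num_patches, embed_dim)
print(tokens.shape)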
slam_llm/models/avhubert/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .hubert import * # noqa
7
+ from .hubert_asr import * # noqa
8
+ from .hubert_dataset import *
9
+ from .hubert_pretraining import *
10
+ from .hubert_criterion import *
slam_llm/models/avhubert/decoder.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from argparse import Namespace
8
+ import contextlib
9
+ import copy
10
+ import math
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from dataclasses import dataclass, field
16
+ from omegaconf import MISSING, II, open_dict
17
+ from typing import Any, Optional
18
+
19
+ from fairseq import checkpoint_utils, tasks, utils
20
+ from fairseq.dataclass import FairseqDataclass
21
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
22
+ from fairseq.tasks import FairseqTask
23
+ from fairseq.models import (
24
+ BaseFairseqModel,
25
+ FairseqEncoder,
26
+ FairseqEncoderDecoderModel,
27
+ FairseqIncrementalDecoder,
28
+ register_model,
29
+ )
30
+ # from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
31
+ from fairseq.modules import (
32
+ LayerNorm,
33
+ PositionalEmbedding,
34
+ TransformerDecoderLayer,
35
+ )
36
+
37
+
38
+ class TransformerDecoder(FairseqIncrementalDecoder):
39
+ """
40
+ Transformer decoder consisting of *args.decoder_layers* layers. Each layer
41
+ is a :class:`TransformerDecoderLayer`.
42
+
43
+ Args:
44
+ args (argparse.Namespace): parsed command-line arguments
45
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
46
+ embed_tokens (torch.nn.Embedding): output embedding
47
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
48
+ (default: False).
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ cfg,
54
+ dictionary,
55
+ embed_tokens,
56
+ no_encoder_attn=False,
57
+ ):
58
+ super().__init__(dictionary)
59
+
60
+ self.dropout = cfg.decoder_dropout
61
+ self.share_input_output_embed = cfg.share_decoder_input_output_embed
62
+
63
+ input_embed_dim = embed_tokens.embedding_dim
64
+ embed_dim = cfg.decoder_embed_dim
65
+ self.output_embed_dim = cfg.decoder_embed_dim
66
+
67
+ self.layerdrop = cfg.decoder_layerdrop
68
+
69
+ padding_idx = embed_tokens.padding_idx
70
+ self.max_target_positions = cfg.max_target_positions
71
+
72
+ self.embed_tokens = embed_tokens
73
+ # self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
74
+ self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
75
+
76
+ self.project_in_dim = (
77
+ Linear(input_embed_dim, embed_dim, bias=False)
78
+ if embed_dim != input_embed_dim
79
+ else None
80
+ )
81
+
82
+ self.embed_positions = (
83
+ PositionalEmbedding(
84
+ cfg.max_target_positions,
85
+ embed_dim,
86
+ padding_idx,
87
+ learned=cfg.decoder_learned_pos,
88
+ )
89
+ if not cfg.no_token_positional_embeddings
90
+ else None
91
+ )
92
+
93
+ # TODO: update this when transformer gets converted to dataclass configs
94
+ transformer_cfg = copy.deepcopy(cfg)
95
+ # with open_dict(transformer_cfg):
96
+ transformer_cfg.dropout = transformer_cfg.decoder_dropout
97
+ transformer_cfg.attention_dropout = (
98
+ transformer_cfg.decoder_attention_dropout
99
+ )
100
+ transformer_cfg.activation_dropout = (
101
+ transformer_cfg.decoder_activation_dropout
102
+ )
103
+
104
+ self.layers = nn.ModuleList([])
105
+ self.layers.extend(
106
+ [
107
+ TransformerDecoderLayer(transformer_cfg, no_encoder_attn)
108
+ for _ in range(transformer_cfg.decoder_layers)
109
+ ]
110
+ )
111
+
112
+ if not self.share_input_output_embed:
113
+ self.embed_out = nn.Parameter(
114
+ torch.Tensor(len(dictionary), self.output_embed_dim)
115
+ )
116
+ nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)
117
+
118
+ if transformer_cfg.decoder_normalize_before:
119
+ self.layer_norm = LayerNorm(embed_dim)
120
+ else:
121
+ self.layer_norm = None
122
+
123
+ def forward(
124
+ self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
125
+ ):
126
+ """
127
+ Args:
128
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
129
+ `(batch, tgt_len)`, for teacher forcing
130
+ encoder_out (Tensor, optional): output from the encoder, used for
131
+ encoder-side attention
132
+ incremental_state (dict): dictionary used for storing state during
133
+ :ref:`Incremental decoding`
134
+
135
+ Returns:
136
+ tuple:
137
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
138
+ - a dictionary with any model-specific outputs
139
+ """
140
+ prev_output_tokens = prev_output_tokens.long()
141
+ x, extra = self.extract_features(
142
+ prev_output_tokens, encoder_out, incremental_state
143
+ )
144
+ x = self.output_layer(x)
145
+ return x, extra
146
+
147
+ def extract_features(
148
+ self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
149
+ ):
150
+ """
151
+ Similar to *forward* but only return features.
152
+
153
+ Returns:
154
+ tuple:
155
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
156
+ - a dictionary with any model-specific outputs
157
+ """
158
+
159
+ # embed positions
160
+ positions = (
161
+ self.embed_positions(
162
+ prev_output_tokens, incremental_state=incremental_state
163
+ )
164
+ if self.embed_positions is not None
165
+ else None
166
+ )
167
+
168
+ if incremental_state is not None:
169
+ prev_output_tokens = prev_output_tokens[:, -1:]
170
+ if positions is not None:
171
+ positions = positions[:, -1:]
172
+
173
+ # embed tokens and positions
174
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
175
+
176
+ if self.project_in_dim is not None:
177
+ x = self.project_in_dim(x)
178
+
179
+ if positions is not None:
180
+ x += positions
181
+ x = F.dropout(x, p=self.dropout, training=self.training)
182
+
183
+ # B x T x C -> T x B x C
184
+ x = x.transpose(0, 1)
185
+ attn = None
186
+
187
+ inner_states = [x]
188
+
189
+ # decoder layers
190
+ for layer in self.layers:
191
+ dropout_probability = np.random.random()
192
+ if not self.training or (dropout_probability > self.layerdrop):
193
+ x, attn, _ = layer(
194
+ x,
195
+ encoder_out["encoder_out"] if encoder_out is not None else None,
196
+ encoder_out["padding_mask"] if encoder_out is not None else None,
197
+ incremental_state,
198
+ self_attn_mask=self.buffered_future_mask(x)
199
+ if incremental_state is None
200
+ else None,
201
+ )
202
+ inner_states.append(x)
203
+
204
+ if self.layer_norm:
205
+ x = self.layer_norm(x)
206
+
207
+ # T x B x C -> B x T x C
208
+ x = x.transpose(0, 1)
209
+
210
+ return x, {"attn": attn, "inner_states": inner_states}
211
+
212
+ def output_layer(self, features, **kwargs):
213
+ """Project features to the vocabulary size."""
214
+ # project back to size of vocabulary
215
+ emb_mat = self.embed_tokens.weight if self.share_input_output_embed else self.embed_out
216
+ return torch.matmul(features, emb_mat.transpose(0, 1))
217
+ # if self.share_input_output_embed:
218
+ # return F.linear(features, self.embed_tokens.weight)
219
+ # else:
220
+ # return F.linear(features, self.embed_out)
221
+
222
+ def max_positions(self):
223
+ """Maximum output length supported by the decoder."""
224
+ if self.embed_positions is None:
225
+ return self.max_target_positions
226
+ return min(self.max_target_positions, self.embed_positions.max_positions)
227
+
228
+ def buffered_future_mask(self, tensor):
229
+ dim = tensor.size(0)
230
+ if (
231
+ not hasattr(self, "_future_mask")
232
+ or self._future_mask is None
233
+ or self._future_mask.device != tensor.device
234
+ or self._future_mask.size(0) < dim
235
+ ):
236
+ self._future_mask = torch.triu(
237
+ utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
238
+ )
239
+ return self._future_mask[:dim, :dim]
240
+
241
+ def upgrade_state_dict_named(self, state_dict, name):
242
+ return state_dict
243
+
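A standalone sketch of the causal mask that TransformerDecoder.buffered_future_mask builds above: an upper-triangular matrix of -inf (excluding the diagonal) is added to the self-attention scores so position t cannot attend to later positions. The toy sequence length is arbitrary.

import torch

def future_mask(dim: int) -> torch.Tensor:
    # -inf above the diagonal, 0 on and below it
    return torch.triu(torch.full((dim, dim), float("-inf")), diagonal=1)

scores = torch.randn(4, 4)            # toy self-attention scores for a length-4 sequence
attn = (scores + future_mask(4)).softmax(dim=-1)
print(attn)                           # future positions receive exactly zero probability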
slam_llm/models/avhubert/hubert.py ADDED
@@ -0,0 +1,792 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os,sys
8
+ import logging
9
+ from typing import Dict, List, Optional, Tuple
10
+
11
+ import numpy as np
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from dataclasses import dataclass, field
16
+ from fairseq import utils
17
+ from fairseq.data.data_utils import compute_mask_indices
18
+ from fairseq.data.dictionary import Dictionary
19
+ from fairseq.dataclass import ChoiceEnum, FairseqDataclass
20
+ from fairseq.models import BaseFairseqModel, register_model
21
+ from fairseq.models.wav2vec.wav2vec2 import (
22
+ ConvFeatureExtractionModel,
23
+ TransformerEncoder,
24
+ )
25
+ from fairseq.modules import GradMultiply, LayerNorm
26
+ from copy import deepcopy
27
+
28
+ DBG=True if len(sys.argv) == 1 else False
29
+
30
+ if DBG:
31
+ from hubert_pretraining import (
32
+ AVHubertPretrainingConfig,
33
+ AVHubertPretrainingTask,
34
+ )
35
+ from resnet import ResEncoder
36
+ logging.basicConfig(
37
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
38
+ datefmt="%Y-%m-%d %H:%M:%S",
39
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
40
+ stream=sys.stdout,
41
+ )
42
+ from utils import compute_mask_indices
43
+ from decoder import TransformerDecoder
44
+
45
+ else:
46
+ from .hubert_pretraining import (
47
+ AVHubertPretrainingConfig,
48
+ AVHubertPretrainingTask,
49
+ )
50
+ from .resnet import ResEncoder
51
+ from .utils import compute_mask_indices
52
+ from .decoder import TransformerDecoder
53
+
54
+ from omegaconf import II
55
+
56
+ logger = logging.getLogger(__name__)
57
+
58
+ EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
59
+ MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(
60
+ ["static", "uniform", "normal", "poisson"]
61
+ )
62
+ # LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer", "trf_adp"])
63
+
64
+
65
+ @dataclass
66
+ class AVHubertConfig(FairseqDataclass):
67
+ label_rate: int = II("task.label_rate")
68
+ input_modality: str = II("task.input_modality")
69
+ extractor_mode: EXTRACTOR_MODE_CHOICES = field(
70
+ default="default",
71
+ metadata={
72
+ "help": "mode for feature extractor. default has a single group "
73
+ "norm with d groups in the first conv block, whereas layer_norm "
74
+ "has layer norms in every block (meant to use with normalize=True)"
75
+ },
76
+ )
77
+ encoder_layers: int = field(
78
+ default=12, metadata={"help": "num encoder layers in the transformer"}
79
+ )
80
+ encoder_embed_dim: int = field(
81
+ default=768, metadata={"help": "encoder embedding dimension"}
82
+ )
83
+ encoder_ffn_embed_dim: int = field(
84
+ default=3072, metadata={"help": "encoder embedding dimension for FFN"}
85
+ )
86
+ encoder_attention_heads: int = field(
87
+ default=12, metadata={"help": "num encoder attention heads"}
88
+ )
89
+ activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
90
+ default="gelu", metadata={"help": "activation function to use"}
91
+ )
92
+
93
+ # dropouts
94
+ dropout: float = field(
95
+ default=0.1,
96
+ metadata={"help": "dropout probability for the transformer"},
97
+ )
98
+ attention_dropout: float = field(
99
+ default=0.1,
100
+ metadata={"help": "dropout probability for attention weights"},
101
+ )
102
+ activation_dropout: float = field(
103
+ default=0.0,
104
+ metadata={"help": "dropout probability after activation in FFN"},
105
+ )
106
+ encoder_layerdrop: float = field(
107
+ default=0.0,
108
+ metadata={"help": "probability of dropping a transformer layer"},
109
+ )
110
+ dropout_input: float = field(
111
+ default=0.0,
112
+ metadata={"help": "dropout to apply to the input (after feat extr)"},
113
+ )
114
+ dropout_features: float = field(
115
+ default=0.0,
116
+ metadata={
117
+ "help": "dropout to apply to the features (after feat extr)"
118
+ },
119
+ )
120
+
121
+ final_dim: int = field(
122
+ default=0,
123
+ metadata={
124
+ "help": "project final representations and targets to this many "
125
+ "dimensions. set to encoder_embed_dim is <= 0"
126
+ },
127
+ )
128
+ untie_final_proj: bool = field(
129
+ default=False,
130
+ metadata={"help": "use separate projection for each target"},
131
+ )
132
+ layer_norm_first: bool = field(
133
+ default=False,
134
+ metadata={"help": "apply layernorm first in the transformer"},
135
+ )
136
+ conv_feature_layers: str = field(
137
+ default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
138
+ metadata={
139
+ "help": "string describing convolutional feature extraction "
140
+ "layers in form of a python list that contains "
141
+ "[(dim, kernel_size, stride), ...]"
142
+ },
143
+ )
144
+ conv_bias: bool = field(
145
+ default=False, metadata={"help": "include bias in conv encoder"}
146
+ )
147
+ logit_temp: float = field(
148
+ default=0.1, metadata={"help": "temperature to divide logits by"}
149
+ )
150
+ target_glu: bool = field(
151
+ default=False, metadata={"help": "adds projection + glu to targets"}
152
+ )
153
+ feature_grad_mult: float = field(
154
+ default=1.0,
155
+ metadata={"help": "multiply feature extractor var grads by this"},
156
+ )
157
+
158
+ # masking
159
+ mask_length_audio: int = field(default=10, metadata={"help": "mask length"})
160
+ mask_prob_audio: float = field(
161
+ default=0.65,
162
+ metadata={"help": "probability of replacing a token with mask"},
163
+ )
164
+ mask_length_image: int = field(default=10, metadata={"help": "mask length"})
165
+ mask_prob_image: float = field(
166
+ default=0.65,
167
+ metadata={"help": "probability of replacing a token with mask"},
168
+ )
169
+ mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
170
+ default="static", metadata={"help": "how to choose mask length"}
171
+ )
172
+ mask_other: float = field(
173
+ default=0,
174
+ metadata={
175
+ "help": "secondary mask argument "
176
+ "(used for more complex distributions), "
177
+ "see help in compute_mask_indices"
178
+ },
179
+ )
180
+ no_mask_overlap: bool = field(
181
+ default=False, metadata={"help": "whether to allow masks to overlap"}
182
+ )
183
+ mask_min_space: int = field(
184
+ default=1,
185
+ metadata={
186
+ "help": "min space between spans (if no overlap is enabled)"
187
+ },
188
+ )
189
+
190
+ # channel masking
191
+ mask_channel_length: int = field(
192
+ default=10,
193
+ metadata={"help": "length of the mask for features (channels)"},
194
+ )
195
+ mask_channel_prob: float = field(
196
+ default=0.0,
197
+ metadata={"help": "probability of replacing a feature with 0"},
198
+ )
199
+ mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
200
+ default="static",
201
+ metadata={"help": "how to choose mask length for channel masking"},
202
+ )
203
+ mask_channel_other: float = field(
204
+ default=0,
205
+ metadata={
206
+ "help": "secondary mask argument "
207
+ "(used for more complex distributions), "
208
+ "see help in compute_mask_indices"
209
+ },
210
+ )
211
+ no_mask_channel_overlap: bool = field(
212
+ default=False,
213
+ metadata={"help": "whether to allow channel masks to overlap"},
214
+ )
215
+ mask_channel_min_space: int = field(
216
+ default=1,
217
+ metadata={
218
+ "help": "min space between spans (if no overlap is enabled)"
219
+ },
220
+ )
221
+
222
+ # positional embeddings
223
+ conv_pos: int = field(
224
+ default=128,
225
+ metadata={
226
+ "help": "number of filters for convolutional positional embeddings"
227
+ },
228
+ )
229
+ conv_pos_groups: int = field(
230
+ default=16,
231
+ metadata={
232
+ "help": "number of groups for convolutional positional embedding"
233
+ },
234
+ )
235
+
236
+ latent_temp: Tuple[float, float, float] = field(
237
+ default=(2, 0.5, 0.999995),
238
+ metadata={"help": "legacy (to be removed)"},
239
+ )
240
+
241
+ # loss computation
242
+ skip_masked: bool = field(
243
+ default=False,
244
+ metadata={"help": "skip computing losses over masked frames"},
245
+ )
246
+ skip_nomask: bool = field(
247
+ default=False,
248
+ metadata={"help": "skip computing losses over unmasked frames"},
249
+ )
250
+ resnet_relu_type: str = field(default='prelu', metadata={"help": 'relu type for resnet'})
251
+ resnet_weights: Optional[str] = field(default=None, metadata={"help": 'resnet weights'})
252
+ sim_type: str = field(default='cosine', metadata={"help": 'similarity type'})
253
+
254
+ sub_encoder_layers: int = field(default=0, metadata={'help': 'number of transformer layers for single modality'})
255
+ audio_feat_dim: int = field(default=-1, metadata={'help': 'audio feature dimension'})
256
+ modality_dropout: float = field(default=0, metadata={'help': 'drop one modality'})
257
+ audio_dropout: float = field(default=0, metadata={'help': 'drop audio feature'})
258
+ modality_fuse: str = field(default='concat', metadata={'help': 'fusing two modalities: add,concat'})
259
+ selection_type : str = field(default='same_other_seq', metadata={'help': 'type of selecting images, same_other_seq: replace masked span with span from another sequence, same_seq: replace masked span with span of the same sequence'})
260
+ masking_type : str = field(default='input', metadata={'help': 'input or feature masking'})
261
+
262
+ decoder_embed_dim: int = field(
263
+ default=768, metadata={"help": "decoder embedding dimension"}
264
+ )
265
+ decoder_ffn_embed_dim: int = field(
266
+ default=3072, metadata={"help": "decoder embedding dimension for FFN"}
267
+ )
268
+ decoder_layers: int = field(
269
+ default=6, metadata={"help": "num of decoder layers"}
270
+ )
271
+ decoder_layerdrop: float = field(
272
+ default=0.0, metadata={"help": "decoder layerdrop chance"}
273
+ )
274
+ decoder_attention_heads: int = field(
275
+ default=4, metadata={"help": "num decoder attention heads"}
276
+ )
277
+ decoder_learned_pos: bool = field(
278
+ default=False,
279
+ metadata={"help": "use learned positional embeddings in the decoder"},
280
+ )
281
+ decoder_normalize_before: bool = field(
282
+ default=False,
283
+ metadata={"help": "apply layernorm before each decoder block"},
284
+ )
285
+ no_token_positional_embeddings: bool = field(
286
+ default=False,
287
+ metadata={
288
+ "help": "if set, disables positional embeddings "
289
+ "(outside self attention)"
290
+ },
291
+ )
292
+ decoder_dropout: float = field(
293
+ default=0.1, metadata={"help": "dropout probability in the decoder"}
294
+ )
295
+ decoder_attention_dropout: float = field(
296
+ default=0.1,
297
+ metadata={
298
+ "help": "dropout probability for attention weights "
299
+ "inside the decoder"
300
+ },
301
+ )
302
+ decoder_activation_dropout: float = field(
303
+ default=0.0,
304
+ metadata={
305
+ "help": "dropout probability after activation in FFN "
306
+ "inside the decoder"
307
+ },
308
+ )
309
+ max_target_positions: int = field(
310
+ default=2048, metadata={"help": "max target positions"}
311
+ )
312
+ share_decoder_input_output_embed: bool = field(
313
+ default=False,
314
+ metadata={"help": "share decoder input and output embeddings"},
315
+ )
316
+ no_scale_embedding: bool = field(default=True, metadata={'help': 'scale embedding'})
317
+
318
+ # # new fairseq
319
+ # required_seq_len_multiple: int = field(
320
+ # default=1,
321
+ # metadata={
322
+ # "help": "pad the input to encoder such that the sequence length is divisible by multiple"
323
+ # },
324
+ # )
325
+
326
+ # layer_type: LAYER_TYPE_CHOICES = field(
327
+ # default="transformer", metadata={"help": "layer type in encoder"}
328
+ # )
329
+
330
+ class SubModel(nn.Module):
331
+ def __init__(self, resnet=None, input_dim=None, cfg=None):
332
+ super().__init__()
333
+ self.resnet = resnet
334
+ self.proj = nn.Linear(input_dim, cfg.encoder_embed_dim)
335
+ self.encoder = TransformerEncoder(cfg) if cfg.encoder_layers > 0 else None
336
+
337
+ def forward(self, x): #torch.Size([1, 1, 106, 112, 112])
338
+ if self.resnet is not None:
339
+ x = self.resnet(x) #torch.Size([1, 512, 106]) #torch.Size([12, 26, 314])
340
+ x = self.proj(x.transpose(1, 2)) # for audio this is Linear(in_features=104, out_features=1024, bias=True), a surprisingly small input dimension
341
+ if self.encoder is not None:
342
+ x = self.encoder(x)[0].transpose(1, 2)
343
+ else: #
344
+ x = x.transpose(1, 2)
345
+ return x #torch.Size([1, 1024, 106])
346
+
347
+ @register_model("av_hubert", dataclass=AVHubertConfig)
348
+ class AVHubertModel(BaseFairseqModel):
349
+ def __init__(
350
+ self,
351
+ cfg: AVHubertConfig,
352
+ task_cfg: AVHubertPretrainingConfig,
353
+ dictionaries: List[Dictionary],
354
+ **kwargs
355
+ ) -> None:
356
+ super().__init__()
357
+ logger.info(f"HubertModel Config: {cfg}")
358
+
359
+ feature_ds_rate = 1
360
+ self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
361
+ sub_cfg = deepcopy(cfg)
362
+ sub_cfg.encoder_layers = sub_cfg.sub_encoder_layers
363
+ resnet = ResEncoder(relu_type=cfg.resnet_relu_type, weights=cfg.resnet_weights)
364
+ self.feature_extractor_audio = SubModel(resnet=None, input_dim=cfg.audio_feat_dim, cfg=sub_cfg)
365
+ self.feature_extractor_video = SubModel(resnet=resnet, input_dim=resnet.backend_out, cfg=sub_cfg)
366
+ self.modality_dropout, self.audio_dropout = cfg.modality_dropout, cfg.audio_dropout
367
+ self.modality_fuse = cfg.modality_fuse
368
+ self.encoder_embed_dim = cfg.encoder_embed_dim
369
+ if self.modality_fuse == 'concat':
370
+ self.embed = cfg.encoder_embed_dim * 2
371
+ elif self.modality_fuse == 'add':
372
+ self.embed = cfg.encoder_embed_dim
373
+ self.post_extract_proj = (
374
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
375
+ if self.embed != cfg.encoder_embed_dim
376
+ else None
377
+ )
378
+
379
+ self.mask_prob_image, self.mask_prob_audio = cfg.mask_prob_image, cfg.mask_prob_audio
380
+ self.mask_selection = cfg.mask_selection
381
+ self.mask_other = cfg.mask_other
382
+ self.mask_length_image, self.mask_length_audio = cfg.mask_length_image, cfg.mask_length_audio
383
+ self.no_mask_overlap = cfg.no_mask_overlap
384
+ self.mask_min_space = cfg.mask_min_space
385
+
386
+ self.mask_channel_prob = cfg.mask_channel_prob
387
+ self.mask_channel_selection = cfg.mask_channel_selection
388
+ self.mask_channel_other = cfg.mask_channel_other
389
+ self.mask_channel_length = cfg.mask_channel_length
390
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
391
+ self.mask_channel_min_space = cfg.mask_channel_min_space
392
+
393
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
394
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
395
+
396
+ self.feature_grad_mult = cfg.feature_grad_mult
397
+ self.logit_temp = cfg.logit_temp
398
+ self.skip_masked = cfg.skip_masked
399
+ self.skip_nomask = cfg.skip_nomask
400
+ self.sim_type = cfg.sim_type
401
+ self.selection_type = cfg.selection_type
402
+ self.masking_type = cfg.masking_type
403
+
404
+ final_dim = (
405
+ cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
406
+ )
407
+
408
+ self.mask_emb = nn.Parameter(
409
+ torch.FloatTensor(cfg.audio_feat_dim).uniform_() if self.masking_type == 'input' else torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
410
+ )
411
+
412
+ self.encoder = TransformerEncoder(cfg)
413
+ self.layer_norm = LayerNorm(self.embed)
414
+
415
+ self.target_glu = None
416
+ if cfg.target_glu:
417
+ self.target_glu = nn.Sequential(
418
+ nn.Linear(final_dim, final_dim * 2), nn.GLU()
419
+ )
420
+
421
+ self.untie_final_proj = cfg.untie_final_proj
422
+ if self.untie_final_proj:
423
+ self.final_proj = nn.Linear(
424
+ cfg.encoder_embed_dim, final_dim * len(dictionaries)
425
+ )
426
+ else:
427
+ self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
428
+
429
+ # modules below are not needed during fine-tuning
430
+ if any([d is None for d in dictionaries]):
431
+ logger.info(
432
+ "cannot find dictionary. assume will be used for fine-tuning"
433
+ )
434
+ else:
435
+ self.num_classes = [len(d) for d in dictionaries]
436
+ self.label_embs_concat = nn.Parameter(
437
+ torch.FloatTensor(sum(self.num_classes), final_dim)
438
+ )
439
+ nn.init.uniform_(self.label_embs_concat)
440
+
441
+ def upgrade_state_dict_named(self, state_dict, name):
442
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
443
+
444
+ super().upgrade_state_dict_named(state_dict, name)
445
+ return state_dict
446
+
447
+ @classmethod
448
+ def build_model(cls, cfg: AVHubertConfig, task: AVHubertPretrainingTask):
449
+ """Build a new model instance."""
450
+
451
+ kwargs = {}
452
+ model = AVHubertModel(cfg, task.cfg, task.dictionaries, **kwargs)
453
+ return model
454
+
455
+ def apply_input_mask(self, x, padding_mask, target_list):
456
+ B, C, T = x.shape[:3]
457
+ is_audio = True if len(x.shape) == 3 else False
458
+ if is_audio:
459
+ mask_prob, mask_length = self.mask_prob_audio, self.mask_length_audio
460
+ else:
461
+ mask_prob, mask_length = self.mask_prob_image, self.mask_length_image
462
+ if mask_prob > 0:
463
+
464
+ mask_indices, starts, ends, batch_indexes = compute_mask_indices(
465
+ (B, T),
466
+ padding_mask,
467
+ mask_prob,
468
+ mask_length,
469
+ self.mask_selection,
470
+ self.mask_other,
471
+ min_masks=2,
472
+ no_overlap=self.no_mask_overlap,
473
+ min_space=self.mask_min_space,
474
+ )
475
+ mask_indices_np = mask_indices
476
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
477
+ x = x.transpose(1, 2).contiguous() # [B, T, C, H, W]
478
+ if B == 1:
479
+ x[mask_indices] = 0
480
+ elif is_audio:
481
+ x[mask_indices] = self.mask_emb
482
+ elif self.selection_type == 'same_other_seq':
483
+ perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B
484
+ x_perm = x[perm]
485
+ x[mask_indices] = x_perm[mask_indices]
486
+ elif self.selection_type == 'same_seq':
487
+ batch_indexes_, other_indexes = [], []
488
+ for batch_index, start, end in zip(batch_indexes, starts, ends):
489
+ length = end-start
490
+ other_start = np.setdiff1d(np.arange(T), np.arange(max(0, start-length), end))
491
+ if len(other_start) > 0:
492
+ other_start = np.random.choice(other_start, size=1)
493
+ else:
494
+ other_start = 0
495
+ other_end = other_start + length
496
+ other_indexes.append(np.arange(other_start, other_end).clip(max=T-1))
497
+ batch_indexes_.append(np.zeros([length], dtype=np.int64)+batch_index)
498
+ batch_indexes, other_indexes = np.concatenate(batch_indexes_), np.concatenate(other_indexes)
499
+ x[mask_indices] = x[batch_indexes, other_indexes]
500
+
501
+ x = x.transpose(1, 2).contiguous()
502
+ else:
503
+ mask_indices = None
504
+
505
+ if self.mask_channel_prob > 0:
506
+ logger.info(f"No mask channel prob for input masking")
507
+ return x, mask_indices
508
+
509
+ def apply_feature_mask(self, x, padding_mask, target_list):
510
+ B, T, C = x.shape
511
+ assert self.mask_prob_audio == self.mask_prob_image and self.mask_length_audio == self.mask_length_image, "masking prob/length must be the same for audio and image when using feature masking"
512
+ mask_prob, mask_length = self.mask_prob_audio, self.mask_length_image
513
+ if mask_prob > 0:
514
+ mask_indices, _, _, _ = compute_mask_indices(
515
+ (B, T),
516
+ padding_mask,
517
+ mask_prob,
518
+ mask_length,
519
+ self.mask_selection,
520
+ self.mask_other,
521
+ min_masks=2,
522
+ no_overlap=self.no_mask_overlap,
523
+ min_space=self.mask_min_space,
524
+ )
525
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
526
+ x[mask_indices] = self.mask_emb
527
+ else:
528
+ mask_indices = None
529
+
530
+ if self.mask_channel_prob > 0:
531
+ mask_channel_indices, _, _, _ = compute_mask_indices(
532
+ (B, C),
533
+ None,
534
+ self.mask_channel_prob,
535
+ self.mask_channel_length,
536
+ self.mask_channel_selection,
537
+ self.mask_channel_other,
538
+ no_overlap=self.no_mask_channel_overlap,
539
+ min_space=self.mask_channel_min_space,
540
+ )
541
+ mask_channel_indices = (
542
+ torch.from_numpy(mask_channel_indices)
543
+ .to(x.device)
544
+ .unsqueeze(1)
545
+ .expand(-1, T, -1)
546
+ )
547
+ x[mask_channel_indices] = 0
548
+
549
+ return x, mask_indices
550
+
551
+ def forward_features(self, source: torch.Tensor, modality: str) -> torch.Tensor:
552
+ extractor = eval(f"self.feature_extractor_{modality}")
553
+ if self.feature_grad_mult > 0:
554
+ features = extractor(source)
555
+ if self.feature_grad_mult != 1.0:
556
+ features = GradMultiply.apply(features, self.feature_grad_mult)
557
+ else:
558
+ with torch.no_grad():
559
+ features = extractor(source)
560
+ return features
561
+
562
+ def forward_targets(
563
+ self, features: torch.Tensor, mask_indices: torch.Tensor, target_list: List[torch.Tensor],
564
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
565
+ # Trim features to ensure labels exist and then get aligned labels
566
+ feat_tsz = features.size(2)
567
+ targ_tsz = min([t.size(1) for t in target_list])
568
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
569
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
570
+ features = features[..., :feat_tsz]
571
+ if mask_indices is not None:
572
+ mask_indices = mask_indices[..., :feat_tsz]
573
+ target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
574
+ target_list = [t[:, target_inds.long()] for t in target_list]
575
+ return features, mask_indices, target_list
576
+
577
+ def forward_padding_mask(
578
+ self, features: torch.Tensor, padding_mask: torch.Tensor,
579
+ ) -> torch.Tensor:
580
+ extra = padding_mask.size(1) % features.size(1)
581
+ if extra > 0:
582
+ padding_mask = padding_mask[:, :-extra]
583
+ padding_mask = padding_mask.view(
584
+ padding_mask.size(0), features.size(1), -1
585
+ )
586
+ padding_mask = padding_mask.all(-1)
587
+ return padding_mask
588
+
589
+ def compute_logits(self, feats, emb_mat):
590
+ # feats: [B, T, F], emb_mat: [V, F]
591
+ if self.sim_type == 'dot':
592
+ logits = torch.matmul(feats, emb_mat.transpose(0, 1))
593
+ elif self.sim_type == 'cosine':
594
+ batch_size, timesteps, emb_dim = feats.size()
595
+ feats_ = feats.view(-1, emb_dim)
596
+ nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1) # [B*T, V]
597
+ denom = (feats_**2).sum(dim=-1).sqrt().unsqueeze(dim=1) * (emb_mat**2).sum(dim=-1).sqrt().unsqueeze(dim=0) # [B*T, V]
598
+ logits = (nom/denom.clamp(min=1e-6)).view(batch_size, timesteps, -1)
599
+ else:
600
+ raise NotImplementedError
601
+ logits = logits / self.logit_temp
602
+ return logits
603
+
604
+ def forward(
605
+ self,
606
+ source: torch.Tensor,
607
+ target_list: Optional[List[torch.Tensor]] = None,
608
+ padding_mask: Optional[torch.Tensor] = None,
609
+ mask: bool = True,
610
+ features_only: bool = False,
611
+ output_layer: Optional[int] = None
612
+ ) -> Dict[str, torch.Tensor]:
613
+ """output layer is 1-based"""
614
+ src_audio, src_video = source['audio'], source['video']
615
+ if mask and self.masking_type == 'input':
616
+ src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list)
617
+ src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list)
618
+ mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video)
619
+ else:
620
+ src_audio, src_video, mask_indices = src_audio, src_video, None
621
+
622
+ features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T]
623
+ features_video = self.forward_features(src_video, modality='video')
624
+ modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random()
625
+ if self.training:
626
+ if modality_drop_prob < self.modality_dropout:
627
+ if audio_drop_prob < self.audio_dropout:
628
+ features_audio = 0 * features_audio
629
+ else:
630
+ features_video = 0 * features_video
631
+ if self.modality_fuse == 'concat':
632
+ features = torch.cat([features_audio, features_video], dim=1)
633
+ elif self.modality_fuse == 'add':
634
+ features = features_audio + features_video
635
+ if target_list is not None:
636
+ features, mask_indices, target_list = self.forward_targets(features, mask_indices, target_list)
637
+
638
+ features_pen = features.float().pow(2).mean()
639
+
640
+ features = features.transpose(1, 2)
641
+ features = self.layer_norm(features)
642
+
643
+ if padding_mask is not None:
644
+ padding_mask = self.forward_padding_mask(features, padding_mask)
645
+
646
+ if self.post_extract_proj is not None:
647
+ features = self.post_extract_proj(features)
648
+
649
+ features = self.dropout_input(features)
650
+ if self.masking_type == 'feature' and mask:
651
+ x, mask_indices = self.apply_feature_mask(features, padding_mask, target_list)
652
+ else:
653
+ x = features
654
+
655
+ # feature: (B, T, D), float
656
+ # target: (B, T), long
657
+ # x: (B, T, D), float
658
+ # padding_mask: (B, T), bool
659
+ # mask_indices: (B, T), bool
660
+ x, _ = self.encoder(
661
+ x,
662
+ padding_mask=padding_mask,
663
+ layer=None if output_layer is None else output_layer - 1
664
+ )
665
+
666
+ if features_only:
667
+ return {"x": x, "padding_mask": padding_mask, "features": features}
668
+
669
+ label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
670
+ proj_x = self.final_proj(x)
671
+ if self.untie_final_proj:
672
+ proj_x_list = proj_x.chunk(len(self.num_classes), dim=-1)
673
+ else:
674
+ proj_x_list = [proj_x for _ in self.num_classes]
675
+ logit_list = [self.compute_logits(proj, emb).view(-1, num_class) for proj, emb, num_class in zip(proj_x_list, label_embs_list, self.num_classes)] # [[B*T, V]]
676
+ mask, unmask = torch.logical_and(mask_indices, ~padding_mask).view(-1), torch.logical_and(~mask_indices, ~padding_mask).view(-1) # [B*T]
677
+ logit_m_list, logit_u_list = [logit[mask] for logit in logit_list], [logit[unmask] for logit in logit_list]
678
+ target_m_list, target_u_list = [target.view(-1)[mask].long() for target in target_list], [target.view(-1)[unmask].long() for target in target_list]
679
+ result = {
680
+ "logit_m_list": logit_m_list,
681
+ "logit_u_list": logit_u_list,
682
+ "target_m_list": target_m_list,
683
+ "target_u_list": target_u_list,
684
+ "padding_mask": padding_mask,
685
+ "features_pen": features_pen,
686
+ }
687
+ return result
688
+
689
+ def extract_features(
690
+ self,
691
+ source: torch.Tensor,
692
+ padding_mask: Optional[torch.Tensor] = None,
693
+ mask: bool = False,
694
+ ret_conv: bool = False,
695
+ output_layer: Optional[int] = None,
696
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
697
+ res = self.forward(
698
+ source,
699
+ padding_mask=padding_mask,
700
+ mask=mask,
701
+ features_only=True,
702
+ output_layer=output_layer,
703
+ )
704
+ feature = res["features"] if ret_conv else res["x"]
705
+ return feature, res["padding_mask"]
706
+
707
+ def extract_finetune(self, source, padding_mask=None, mask=False, ret_conv=False, output_layer=None):
708
+ src_audio, src_video = source['audio'], source['video'] #torch.Size([1, 1, 106, 112, 112])
709
+ if mask and self.masking_type == 'input':
710
+ src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list=None)
711
+ src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list=None)
712
+ mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video) # mask_indices not used in fine-tuning
713
+ else: #
714
+ src_audio, src_video, mask_indices = src_audio, src_video, None
715
+
716
+ if src_audio is not None and src_video is None:
717
+ features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T]
718
+ features_video = features_audio.new_zeros(features_audio.size(0), self.encoder_embed_dim, features_audio.size(-1))
719
+ elif src_audio is None and src_video is not None:
720
+ features_video = self.forward_features(src_video, modality='video')
721
+ features_audio = features_video.new_zeros(features_video.size(0), self.encoder_embed_dim, features_video.size(-1)) # all zeros
722
+ elif src_audio is not None and src_video is not None:
723
+ features_video = self.forward_features(src_video, modality='video') #torch.Size([1, 1024, 106]) #scr torch.Size([12, 1, 314, 88, 88])
724
+ features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T] #torch.Size([12, 26, 314])
725
+
726
+ if self.modality_fuse == 'concat': #
727
+ features = torch.cat([features_audio, features_video], dim=1) #torch.Size([1, 2048, 106])
728
+ elif self.modality_fuse == 'add':
729
+ features = features_audio + features_video
730
+ features_pen = features.float().pow(2).mean()
731
+
732
+ features = features.transpose(1, 2)
733
+ features = self.layer_norm(features)
734
+ unmasked_features = features.clone()
735
+
736
+ if padding_mask is not None: #features:torch.Size([1, 106, 2048])
737
+ padding_mask = self.forward_padding_mask(features, padding_mask) #torch.Size([4, 154])
738
+
739
+ if self.post_extract_proj is not None:
740
+ features = self.post_extract_proj(features) #torch.Size([1, 106, 1024])
741
+
742
+ features = self.dropout_input(features)
743
+ unmasked_features = self.dropout_features(unmasked_features)
744
+ x = features
745
+ mask_indices = None
746
+
747
+ # feature: (B, T, D), float
748
+ # target: (B, T), long
749
+ # x: (B, T, D), float
750
+ # padding_mask: (B, T), bool
751
+ # mask_indices: (B, T), bool
752
+ x, _ = self.encoder(
753
+ x,
754
+ padding_mask=padding_mask,
755
+ layer=None if output_layer is None else output_layer - 1
756
+ )
757
+
758
+ return x, padding_mask #torch.Size([1, 106, 1024]), None
759
+
760
+
761
+ def get_extra_losses(self, net_output):
762
+ extra_losses = []
763
+ names = []
764
+ if "features_pen" in net_output:
765
+ extra_losses.append(net_output["features_pen"])
766
+ names.append("features_pen")
767
+
768
+ return extra_losses, names
769
+
770
+ def remove_pretraining_modules(self):
771
+ self.target_glu = None
772
+ self.final_proj = None
773
+
774
+ def get_logits(self, net_output, is_masked=True):
775
+ raise NotImplementedError
776
+
777
+ def get_targets(self, net_output, is_masked=True):
778
+ raise NotImplementedError
779
+
780
+ def compute_nce(self, x, pos, negs):
781
+ neg_is_pos = (pos == negs).all(-1)
782
+ pos = pos.unsqueeze(0)
783
+ targets = torch.cat([pos, negs], dim=0)
784
+
785
+ logits = torch.cosine_similarity(
786
+ x.float(), targets.float(), dim=-1
787
+ ).type_as(x)
788
+ logits /= self.logit_temp
789
+ if neg_is_pos.any():
790
+ logits[1:][neg_is_pos] = float("-inf")
791
+ logits = logits.transpose(0, 1) # (num_x, num_cls+1)
792
+ return logits
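A standalone sketch of the cosine-similarity branch of AVHubertModel.compute_logits above: frame features are compared against the label/codebook embeddings with a length-normalized dot product and divided by the logit temperature. All sizes here are toy values chosen for illustration.

import torch

B, T, F, V = 2, 5, 16, 8                   # toy sizes (assumed)
feats = torch.randn(B, T, F)               # frame features [B, T, F]
emb_mat = torch.randn(V, F)                # label embeddings [V, F]
logit_temp = 0.1

feats_ = feats.view(-1, F)                                         # [B*T, F]
nom = (feats_.unsqueeze(1) * emb_mat.unsqueeze(0)).sum(dim=-1)     # [B*T, V]
denom = feats_.norm(dim=-1, keepdim=True) * emb_mat.norm(dim=-1)   # [B*T, V] via broadcasting
logits = (nom / denom.clamp(min=1e-6)).view(B, T, V) / logit_temp
print(logits.shape)                        # torch.Size([2, 5, 8])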
slam_llm/models/avhubert/hubert_asr.py ADDED
@@ -0,0 +1,523 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import sys,logging
8
+ import contextlib
9
+ import tempfile
10
+ from argparse import Namespace
11
+ from typing import Any, Optional
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from dataclasses import dataclass, field
16
+ from fairseq import checkpoint_utils, tasks, utils
17
+ from fairseq.dataclass import FairseqDataclass
18
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
19
+ from fairseq.models import BaseFairseqModel, FairseqEncoder, FairseqEncoderDecoderModel, register_model
20
+ from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES
21
+ from fairseq.tasks import FairseqTask
22
+ from omegaconf import II, MISSING
23
+
24
+ DBG=True if len(sys.argv) == 1 else False
25
+
26
+ if DBG:
27
+ from hubert import AVHubertModel
28
+ from decoder import TransformerDecoder
29
+ else:
30
+ from .hubert import AVHubertModel
31
+ from .decoder import TransformerDecoder
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ @dataclass
37
+ class AVHubertAsrConfig(FairseqDataclass):
38
+ w2v_path: str = field(
39
+ default=MISSING, metadata={"help": "path to hubert model"}
40
+ )
41
+ no_pretrained_weights: bool = field(
42
+ default=False,
43
+ metadata={"help": "if true, does not load pretrained weights"},
44
+ )
45
+ dropout_input: float = field(
46
+ default=0.0,
47
+ metadata={"help": "dropout to apply to the input (after feat extr)"},
48
+ )
49
+ final_dropout: float = field(
50
+ default=0.0,
51
+ metadata={
52
+ "help": "dropout after transformer and before final projection"
53
+ },
54
+ )
55
+ dropout: float = field(
56
+ default=0.0,
57
+ metadata={"help": "dropout probability inside hubert model"},
58
+ )
59
+ attention_dropout: float = field(
60
+ default=0.0,
61
+ metadata={
62
+ "help": "dropout probability for attention weights "
63
+ "inside hubert model"
64
+ },
65
+ )
66
+ activation_dropout: float = field(
67
+ default=0.0,
68
+ metadata={
69
+ "help": "dropout probability after activation in FFN "
70
+ "inside hubert model"
71
+ },
72
+ )
73
+
74
+ # masking
75
+ apply_mask: bool = field(
76
+ default=False, metadata={"help": "apply masking during fine-tuning"}
77
+ )
78
+ mask_length: int = field(
79
+ default=10, metadata={"help": "repeat the mask indices multiple times"}
80
+ )
81
+ mask_prob: float = field(
82
+ default=0.5,
83
+ metadata={
84
+ "help": "probability of replacing a token with mask "
85
+ "(normalized by length)"
86
+ },
87
+ )
88
+ mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
89
+ default="static", metadata={"help": "how to choose masks"}
90
+ )
91
+ mask_other: float = field(
92
+ default=0,
93
+ metadata={
94
+ "help": "secondary mask argument "
95
+ "(used for more complex distributions), "
96
+ "see help in compute_mask_indices"
97
+ },
98
+ )
99
+ no_mask_overlap: bool = field(
100
+ default=False, metadata={"help": "whether to allow masks to overlap"}
101
+ )
102
+
103
+ # channel masking
104
+ mask_channel_length: int = field(
105
+ default=10,
106
+ metadata={"help": "length of the mask for features (channels)"},
107
+ )
108
+ mask_channel_prob: float = field(
109
+ default=0.0,
110
+ metadata={"help": "probability of replacing a feature with 0"},
111
+ )
112
+ mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
113
+ default="static",
114
+ metadata={"help": "how to choose mask length for channel masking"},
115
+ )
116
+ mask_channel_other: float = field(
117
+ default=0,
118
+ metadata={
119
+ "help": "secondary mask argument "
120
+ "(used for more complex distributions), "
121
+ "see help in compute_mask_indices"
122
+ },
123
+ )
124
+ no_mask_channel_overlap: bool = field(
125
+ default=False,
126
+ metadata={"help": "whether to allow channel masks to overlap"},
127
+ )
128
+ freeze_finetune_updates: int = field(
129
+ default=0,
130
+ metadata={"help": "don't fine-tune hubert for this many updates"},
131
+ )
132
+ feature_grad_mult: float = field(
133
+ default=0.0,
134
+ metadata={"help": "reset feature grad mult in hubert to this"},
135
+ )
136
+ layerdrop: float = field(
137
+ default=0.0,
138
+ metadata={"help": "probability of dropping a layer in hubert"},
139
+ )
140
+ normalize: bool = II("task.normalize")
141
+ data: str = II("task.data")
142
+
143
+ # this holds the loaded hubert args
144
+ w2v_args: Any = None
145
+
146
+
147
+ @dataclass
148
+ class AVHubertCtcConfig(AVHubertAsrConfig):
149
+ pass
150
+
151
+
152
+ @register_model("av_hubert_ctc", dataclass=AVHubertCtcConfig)
153
+ class AVHubertCtc(BaseFairseqModel):
154
+ def __init__(self, cfg: AVHubertCtcConfig, w2v_encoder: BaseFairseqModel):
155
+ super().__init__()
156
+ self.cfg = cfg
157
+ self.w2v_encoder = w2v_encoder
158
+
159
+ def upgrade_state_dict_named(self, state_dict, name):
160
+ super().upgrade_state_dict_named(state_dict, name)
161
+ return state_dict
162
+
163
+ @classmethod
164
+ def build_model(cls, cfg: AVHubertCtcConfig, task: FairseqTask):
165
+ """Build a new model instance."""
166
+ w2v_encoder = HubertEncoder(cfg, task.target_dictionary)
167
+ return cls(cfg, w2v_encoder)
168
+
169
+ def get_normalized_probs(self, net_output, log_probs):
170
+ """Get normalized probabilities (or log probs) from a net's output."""
171
+
172
+ logits = net_output["encoder_out"]
173
+ if log_probs:
174
+ return utils.log_softmax(logits.float(), dim=-1)
175
+ else:
176
+ return utils.softmax(logits.float(), dim=-1)
177
+
178
+ def get_logits(self, net_output):
179
+ logits = net_output["encoder_out"]
180
+ padding = net_output["encoder_padding_mask"]
181
+ if padding is not None and padding.any():
182
+ padding = padding.T
183
+ logits[padding][..., 0] = 0
184
+ logits[padding][..., 1:] = float("-inf")
185
+
186
+ return logits
187
+
188
+ def forward(self, **kwargs):
189
+ x = self.w2v_encoder(**kwargs)
190
+ return x
191
+
192
+
193
+ @dataclass
194
+ class AVHubertSeq2SeqConfig(AVHubertAsrConfig):
195
+ decoder_embed_dim: int = field(
196
+ default=768, metadata={"help": "decoder embedding dimension"}
197
+ )
198
+ decoder_ffn_embed_dim: int = field(
199
+ default=3072, metadata={"help": "decoder embedding dimension for FFN"}
200
+ )
201
+ decoder_layers: int = field(
202
+ default=6, metadata={"help": "num of decoder layers"}
203
+ )
204
+ decoder_layerdrop: float = field(
205
+ default=0.0, metadata={"help": "decoder layerdrop chance"}
206
+ )
207
+ decoder_attention_heads: int = field(
208
+ default=4, metadata={"help": "num decoder attention heads"}
209
+ )
210
+ decoder_learned_pos: bool = field(
211
+ default=False,
212
+ metadata={"help": "use learned positional embeddings in the decoder"},
213
+ )
214
+ decoder_normalize_before: bool = field(
215
+ default=False,
216
+ metadata={"help": "apply layernorm before each decoder block"},
217
+ )
218
+ no_token_positional_embeddings: bool = field(
219
+ default=False,
220
+ metadata={
221
+ "help": "if set, disables positional embeddings "
222
+ "(outside self attention)"
223
+ },
224
+ )
225
+ decoder_dropout: float = field(
226
+ default=0.0, metadata={"help": "dropout probability in the decoder"}
227
+ )
228
+ decoder_attention_dropout: float = field(
229
+ default=0.0,
230
+ metadata={
231
+ "help": "dropout probability for attention weights "
232
+ "inside the decoder"
233
+ },
234
+ )
235
+ decoder_activation_dropout: float = field(
236
+ default=0.0,
237
+ metadata={
238
+ "help": "dropout probability after activation in FFN "
239
+ "inside the decoder"
240
+ },
241
+ )
242
+ max_target_positions: int = field(
243
+ default=2048, metadata={"help": "max target positions"}
244
+ )
245
+ share_decoder_input_output_embed: bool = field(
246
+ default=False,
247
+ metadata={"help": "share decoder input and output embeddings"},
248
+ )
249
+ no_scale_embedding: bool = field(default=True, metadata={'help': 'scale embedding'})
250
+
251
+ class HubertEncoder(FairseqEncoder):
252
+ def __init__(self, cfg: AVHubertAsrConfig, tgt_dict=None):
253
+ self.apply_mask = cfg.apply_mask
254
+
255
+ arg_overrides = {
256
+ "dropout": cfg.dropout,
257
+ "activation_dropout": cfg.activation_dropout,
258
+ "dropout_input": cfg.dropout_input,
259
+ "attention_dropout": cfg.attention_dropout,
260
+ "mask_length": cfg.mask_length,
261
+ "mask_prob": cfg.mask_prob,
262
+ "mask_selection": cfg.mask_selection,
263
+ "mask_other": cfg.mask_other,
264
+ "no_mask_overlap": cfg.no_mask_overlap,
265
+ "mask_channel_length": cfg.mask_channel_length,
266
+ "mask_channel_prob": cfg.mask_channel_prob,
267
+ "mask_channel_selection": cfg.mask_channel_selection,
268
+ "mask_channel_other": cfg.mask_channel_other,
269
+ "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
270
+ "encoder_layerdrop": cfg.layerdrop,
271
+ "feature_grad_mult": cfg.feature_grad_mult,
272
+ }
273
+
274
+ if cfg.w2v_args is None:
275
+ state = checkpoint_utils.load_checkpoint_to_cpu(
276
+ cfg.w2v_path, arg_overrides
277
+ )
278
+ w2v_args = state.get("cfg", None)
279
+ if w2v_args is None:
280
+ w2v_args = convert_namespace_to_omegaconf(state["args"])
281
+ cfg.w2v_args = w2v_args
282
+ else:
283
+ state = None
284
+ w2v_args = cfg.w2v_args
285
+ if isinstance(w2v_args, Namespace):
286
+ cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
287
+ w2v_args
288
+ )
289
+
290
+ assert cfg.normalize == w2v_args.task.normalize, (
291
+ "Fine-tuning works best when data normalization is the same. "
292
+ "Please check that --normalize is set or unset for "
293
+ "both pre-training and here"
294
+ )
295
+
296
+ w2v_args.task.data = cfg.data
297
+
298
+ task = tasks.setup_task(w2v_args.task)
299
+ model = task.build_model(w2v_args.model)
300
+
301
+ if state is not None and not cfg.no_pretrained_weights:
302
+ # set strict=False because we omit some modules
303
+ model.load_state_dict(state["model"], strict=False)
304
+
305
+ model.remove_pretraining_modules()
306
+
307
+ super().__init__(task.source_dictionary)
308
+
309
+ d = model.encoder.embedding_dim
310
+
311
+ self.w2v_model = model
312
+
313
+ self.final_dropout = nn.Dropout(cfg.final_dropout)
314
+ self.freeze_finetune_updates = cfg.freeze_finetune_updates
315
+ self.num_updates = 0
316
+
317
+ if tgt_dict is not None:
318
+ self.proj = Linear(d, len(tgt_dict))
319
+ elif getattr(cfg, "decoder_embed_dim", d) != d:
320
+ self.proj = Linear(d, cfg.decoder_embed_dim)
321
+ else:
322
+ self.proj = None
323
+
324
+ def set_num_updates(self, num_updates):
325
+ """Set the number of parameters updates."""
326
+ super().set_num_updates(num_updates)
327
+ self.num_updates = num_updates
328
+
329
+ def forward(self, source, padding_mask, tbc=True, **kwargs):
330
+
331
+ w2v_args = {
332
+ "source": source,
333
+ "padding_mask": padding_mask,
334
+ "mask": self.apply_mask and self.training,
335
+ }
336
+ ft = self.freeze_finetune_updates <= self.num_updates
337
+
338
+ with torch.no_grad() if not ft else contextlib.ExitStack():
339
+ x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
340
+
341
+ if tbc:
342
+ # B x T x C -> T x B x C
343
+ x = x.transpose(0, 1)
344
+
345
+ x = self.final_dropout(x)
346
+
347
+ if self.proj:
348
+ x = self.proj(x)
349
+
350
+ return {
351
+ "encoder_out": x, # T x B x C
352
+ "encoder_padding_mask": padding_mask, # B x T
353
+ "padding_mask": padding_mask,
354
+ }
355
+
356
+ def reorder_encoder_out(self, encoder_out, new_order):
357
+ if encoder_out["encoder_out"] is not None:
358
+ encoder_out["encoder_out"] = encoder_out[
359
+ "encoder_out"
360
+ ].index_select(1, new_order)
361
+ if encoder_out["encoder_padding_mask"] is not None:
362
+ encoder_out["encoder_padding_mask"] = encoder_out[
363
+ "encoder_padding_mask"
364
+ ].index_select(0, new_order)
365
+ return encoder_out
366
+
367
+ def max_positions(self):
368
+ """Maximum input length supported by the encoder."""
369
+ return None
370
+
371
+ def upgrade_state_dict_named(self, state_dict, name):
372
+ return state_dict
373
+
374
+
375
+ class HubertEncoderWrapper(FairseqEncoder):
376
+ def __init__(self, w2v_model):
377
+ super().__init__(None)
378
+ self.w2v_model = w2v_model
379
+
380
+ def forward(self, source, padding_mask, **kwargs):
381
+ w2v_args = {
382
+ "source": source,
383
+ "padding_mask": padding_mask,
384
+ }
385
+
386
+ x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
387
+ # B x T x C -> T x B x C
388
+ x = x.transpose(0, 1)  # e.g. [T, B, C] = [106, 1, 1024] in a sample run
389
+
390
+ return {
391
+ "encoder_out": x, # T x B x C
392
+ "encoder_padding_mask": padding_mask, # B x T
393
+ "padding_mask": padding_mask
394
+ }
395
+
396
+ def reorder_encoder_out(self, encoder_out, new_order):
397
+ if encoder_out["encoder_out"] is not None:
398
+ encoder_out["encoder_out"] = encoder_out[
399
+ "encoder_out"
400
+ ].index_select(1, new_order)
401
+ if encoder_out["encoder_padding_mask"] is not None:
402
+ encoder_out["encoder_padding_mask"] = encoder_out[
403
+ "encoder_padding_mask"
404
+ ].index_select(0, new_order)
405
+ if encoder_out["padding_mask"] is not None:
406
+ encoder_out["padding_mask"] = encoder_out[
407
+ "padding_mask"
408
+ ].index_select(0, new_order)
409
+ return encoder_out
410
+
411
+ @register_model("av_hubert_seq2seq", dataclass=AVHubertSeq2SeqConfig)
412
+ class AVHubertSeq2Seq(FairseqEncoderDecoderModel):
413
+ def __init__(self, encoder, decoder, tgt_dict, cfg):
414
+ super().__init__(encoder, decoder)
415
+ self.cfg = cfg
416
+ self.freeze_finetune_updates = cfg.freeze_finetune_updates
417
+
418
+ @classmethod
419
+ def build_model(cls, cfg, task):
420
+ """Build a new model instance."""
421
+
422
+ arg_overrides = {
423
+ "dropout": cfg.dropout,
424
+ "activation_dropout": cfg.activation_dropout,
425
+ "dropout_input": cfg.dropout_input,
426
+ "attention_dropout": cfg.attention_dropout,
427
+ "mask_length": cfg.mask_length,
428
+ "mask_prob": cfg.mask_prob,
429
+ "mask_selection": cfg.mask_selection,
430
+ "mask_other": cfg.mask_other,
431
+ "no_mask_overlap": cfg.no_mask_overlap,
432
+ "mask_channel_length": cfg.mask_channel_length,
433
+ "mask_channel_prob": cfg.mask_channel_prob,
434
+ "mask_channel_selection": cfg.mask_channel_selection,
435
+ "mask_channel_other": cfg.mask_channel_other,
436
+ "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
437
+ "encoder_layerdrop": cfg.layerdrop,
438
+ "feature_grad_mult": cfg.feature_grad_mult,
439
+ }
440
+
441
+ if cfg.w2v_args is None:
442
+ state = checkpoint_utils.load_checkpoint_to_cpu(
443
+ cfg.w2v_path, arg_overrides
444
+ )
445
+ w2v_args = state.get("cfg", None)
446
+ if w2v_args is None:
447
+ w2v_args = convert_namespace_to_omegaconf(state["args"])
448
+ cfg.w2v_args = w2v_args
449
+ else:
450
+ state = None
451
+ w2v_args = cfg.w2v_args
452
+ if isinstance(w2v_args, Namespace):
453
+ cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
454
+ w2v_args
455
+ )
456
+
457
+ assert cfg.normalize == w2v_args.task.normalize, (
458
+ "Fine-tuning works best when data normalization is the same. "
459
+ "Please check that --normalize is set or unset for "
460
+ "both pre-training and here"
461
+ )
462
+
463
+ w2v_args.task.data = cfg.data
464
+
465
+ task_pretrain = tasks.setup_task(w2v_args.task)
466
+ if state is not None:
467
+ task_pretrain.load_state_dict(state['task_state'])
468
+
469
+ encoder_ = task_pretrain.build_model(w2v_args.model)
470
+
471
+ encoder = HubertEncoderWrapper(encoder_)
472
+ if state is not None and not cfg.no_pretrained_weights:
473
+ # set strict=False because we omit some modules
474
+ del state['model']['mask_emb']
475
+ encoder.w2v_model.load_state_dict(state["model"], strict=False)
476
+
477
+ encoder.w2v_model.remove_pretraining_modules()
478
+
479
+ src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
480
+
481
+ def build_embedding(dictionary, embed_dim):
482
+ num_embeddings = len(dictionary)
483
+ padding_idx = dictionary.pad()
484
+ emb = Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
485
+ return emb
486
+
487
+ decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)
488
+ decoder = TransformerDecoder(cfg, tgt_dict, decoder_embed_tokens)
489
+
490
+ return AVHubertSeq2Seq(encoder, decoder, tgt_dict, cfg)
491
+
492
+
493
+ def forward(self, **kwargs):
494
+ # ft = self.freeze_finetune_updates <= self.num_updates
495
+ # with torch.no_grad() if not ft else contextlib.ExitStack():
496
+ # output = self.encoder(**kwargs)
497
+ with torch.no_grad():
498
+ output = self.encoder(**kwargs)  # dict with encoder_out, encoder_padding_mask, padding_mask; encoder runs frozen here
499
+ # decoder_out = self.decoder(prev_output_tokens=kwargs['prev_output_tokens'], encoder_out=output)
500
+ return output
501
+
502
+ def upgrade_state_dict_named(self, state_dict, name):
503
+ super().upgrade_state_dict_named(state_dict, name)
504
+ return state_dict
505
+
506
+ def set_num_updates(self, num_updates):
507
+ """Set the number of parameters updates."""
508
+ super().set_num_updates(num_updates)
509
+ self.num_updates = num_updates
510
+
511
+ def Embedding(num_embeddings, embedding_dim, padding_idx):
512
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
513
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
514
+ nn.init.constant_(m.weight[padding_idx], 0)
515
+ return m
516
+
517
+
518
+ def Linear(in_features, out_features, bias=True):
519
+ m = nn.Linear(in_features, out_features, bias)
520
+ nn.init.xavier_uniform_(m.weight)
521
+ if bias:
522
+ nn.init.constant_(m.bias, 0.0)
523
+ return m
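For reference, the Embedding and Linear helpers above give the decoder scaled-normal token embeddings (with the padding row zeroed) and Xavier-initialized projections. Below is a minimal standalone sketch of that initialization pattern; the vocabulary size, dimensions, and token ids are hypothetical placeholders, not values from the actual task config.

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only; the real values come from the
# task's target dictionary and AVHubertSeq2SeqConfig.
vocab_size, embed_dim, pad_idx = 1000, 8, 1

emb = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
nn.init.normal_(emb.weight, mean=0, std=embed_dim ** -0.5)  # scaled-normal init
nn.init.constant_(emb.weight[pad_idx], 0)                   # zero the padding row

proj = nn.Linear(embed_dim, 16)
nn.init.xavier_uniform_(proj.weight)
nn.init.constant_(proj.bias, 0.0)

tokens = torch.tensor([[2, 5, pad_idx]])    # a padded sequence of token ids
out = proj(emb(tokens))                     # shape [1, 3, 16]
assert torch.all(emb.weight[pad_idx] == 0)  # padding embedding stays zero
print(out.shape)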
slam_llm/models/avhubert/hubert_criterion.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import re
9
+ from dataclasses import dataclass, field
10
+ from typing import List, Optional
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from fairseq import metrics, utils
15
+ from fairseq.criterions import FairseqCriterion, register_criterion
16
+ from fairseq.dataclass import FairseqDataclass
17
+
18
+
19
+ @dataclass
20
+ class AVHubertCriterionConfig(FairseqDataclass):
21
+ pred_masked_weight: float = field(
22
+ default=1.0,
23
+ metadata={"help": "weight for predictive loss for masked frames"},
24
+ )
25
+ pred_nomask_weight: float = field(
26
+ default=0.0,
27
+ metadata={"help": "weight for predictive loss for unmasked frames"},
28
+ )
29
+ loss_weights: Optional[List[float]] = field(
30
+ default=None,
31
+ metadata={"help": "weights for additional loss terms (not first one)"},
32
+ )
33
+ log_keys: List[str] = field(
34
+ default_factory=lambda: [],
35
+ metadata={"help": "output keys to log"},
36
+ )
37
+
38
+
39
+ @register_criterion("av_hubert", dataclass=AVHubertCriterionConfig)
40
+ class AVHubertCriterion(FairseqCriterion):
41
+ def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
42
+ super().__init__(task)
43
+ self.pred_masked_weight = pred_masked_weight
44
+ self.pred_nomask_weight = pred_nomask_weight
45
+ self.loss_weights = loss_weights
46
+ self.log_keys = [] if log_keys is None else log_keys
47
+
48
+ def forward(self, model, sample, reduce=True, log_pred=False):
49
+ """Compute the loss for the given sample.
50
+ Returns a tuple with three elements:
51
+ 1) the loss
52
+ 2) the sample size, which is used as the denominator for the gradient
53
+ 3) logging outputs to display while training
54
+ """
55
+ net_output = model(target_list=sample["target_list"], **sample["net_input"])
56
+ loss = 0.
57
+ sample_size = 0
58
+ logging_output = {}
59
+ reduction = "sum" if reduce else "none"
60
+
61
+ loss_m_list = []
62
+ logp_m_list, targ_m_list = net_output['logit_m_list'], net_output['target_m_list']
63
+ for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
64
+ loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
65
+ loss_m_list.append(loss_m)
66
+ logging_output[f"loss_m_{i}"] = loss_m.detach().item()
67
+ if self.pred_masked_weight > 0:
68
+ loss += self.pred_masked_weight * sum(loss_m_list)
69
+ sample_size += targ_m_list[0].numel()
70
+
71
+ loss_u_list = []
72
+ logp_u_list, targ_u_list = net_output['logit_u_list'], net_output['target_u_list']
73
+ for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
74
+ loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
75
+ loss_u_list.append(loss_u)
76
+ logging_output[f"loss_u_{i}"] = loss_u.detach().item()
77
+ if self.pred_nomask_weight > 0:
78
+ loss += self.pred_nomask_weight * sum(loss_u_list)
79
+ sample_size += targ_u_list[0].numel()
80
+
81
+ if self.loss_weights is not None:
82
+ assert hasattr(model, "get_extra_losses")
83
+ extra_losses, names = model.get_extra_losses(net_output)
84
+ if torch.is_tensor(extra_losses):
85
+ extra_losses = [extra_losses]
86
+ names = [names]
87
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
88
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
89
+ assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
90
+ for p, n, coef in zip(extra_losses, names, self.loss_weights):
91
+ if coef != 0 and p is not None:
92
+ p = coef * p.float() * sample_size
93
+ loss += p
94
+ logging_output[f"loss_{n}"] = p.item()
95
+
96
+ logging_output = {
97
+ "loss": loss.item() if reduce else loss,
98
+ "ntokens": sample_size,
99
+ "nsentences": sample["id"].numel(),
100
+ "sample_size": sample_size,
101
+ **logging_output,
102
+ }
103
+
104
+ for lk in self.log_keys:
105
+ if lk in net_output:
106
+ logging_output[lk] = float((net_output[lk]))
107
+
108
+ with torch.no_grad():
109
+ for i, logp_m in enumerate(logp_m_list):
110
+ # corr_m, count_m = compute_correct(logp_m)
111
+ if logp_m.numel() == 0:
112
+ corr_m, count_m = 0, 0
113
+ else:
114
+ corr_m, count_m = (logp_m.argmax(dim=-1)==targ_m_list[i]).sum().item(), len(targ_m_list[i])
115
+ logging_output[f"correct_m_{i}"] = corr_m
116
+ logging_output[f"count_m_{i}"] = count_m
117
+
118
+ for i, logp_u in enumerate(logp_u_list):
119
+ if logp_u.numel() == 0:
120
+ corr_u, count_u = 0, 0
121
+ else:
122
+ corr_u, count_u = (logp_u.argmax(dim=-1)==targ_u_list[i]).sum().item(), len(targ_u_list[i])
123
+ logging_output[f"correct_u_{i}"] = corr_u
124
+ logging_output[f"count_u_{i}"] = count_u
125
+
126
+ return loss, sample_size, logging_output
127
+
128
+ @staticmethod
129
+ def reduce_metrics(logging_outputs) -> None:
130
+ """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
131
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
132
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
133
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
134
+
135
+ metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
136
+ if sample_size != ntokens:
137
+ metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
138
+ metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
139
+ else:
140
+ metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
141
+
142
+ counts = {}
143
+ for lk in logging_outputs[0].keys():
144
+ if lk.startswith("count_"):
145
+ val = sum(log[lk] for log in logging_outputs)
146
+ metrics.log_scalar(lk, val)
147
+ counts[lk] = val
148
+
149
+ for lk in logging_outputs[0].keys():
150
+ if lk.startswith("loss_"):
151
+ val = sum(log[lk] for log in logging_outputs)
152
+ metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
153
+ elif lk.startswith("correct_"):
154
+ val = sum(log[lk] for log in logging_outputs)
155
+ metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
156
+
157
+ @staticmethod
158
+ def aggregate_logging_outputs(logging_outputs):
159
+ """Aggregate logging outputs from data parallel training."""
160
+ raise NotImplementedError()
161
+
162
+ @staticmethod
163
+ def logging_outputs_can_be_summed() -> bool:
164
+ """
165
+ Whether the logging outputs returned by `forward` can be summed
166
+ across workers prior to calling `reduce_metrics`. Setting this
167
+ to True will improves distributed training speed.
168
+ """
169
+ return False
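As a quick illustration of the weighting logic in AVHubertCriterion.forward above, the loss reduces to a weighted sum of cross-entropy terms over masked and unmasked frames, with the masked-frame count used as the sample size under the default weights. The following is a toy sketch with randomly generated logits and targets standing in for the model outputs; shapes and values are illustrative only.

import torch
import torch.nn.functional as F

pred_masked_weight, pred_nomask_weight = 1.0, 0.0   # defaults in AVHubertCriterionConfig

# Toy stand-ins for net_output['logit_m_list'][0] / ['target_m_list'][0], etc.
logit_m = torch.randn(4, 10)            # 4 masked frames, 10 target classes
targ_m = torch.randint(0, 10, (4,))
logit_u = torch.randn(6, 10)            # 6 unmasked frames
targ_u = torch.randint(0, 10, (6,))

loss = torch.tensor(0.0)
sample_size = 0
if pred_masked_weight > 0:
    loss = loss + pred_masked_weight * F.cross_entropy(logit_m, targ_m, reduction="sum")
    sample_size += targ_m.numel()
if pred_nomask_weight > 0:
    loss = loss + pred_nomask_weight * F.cross_entropy(logit_u, targ_u, reduction="sum")
    sample_size += targ_u.numel()
print(loss.item(), sample_size)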
slam_llm/models/avhubert/hubert_dataset.py ADDED
@@ -0,0 +1,529 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import itertools
8
+ import logging
9
+ import os
10
+ import sys
11
+ import time
12
+ from typing import Any, List, Optional, Union
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from fairseq.data import data_utils
19
+ from fairseq.data.fairseq_dataset import FairseqDataset
20
+ from python_speech_features import logfbank
21
+ from scipy.io import wavfile
22
+
23
+ DBG=True if len(sys.argv) == 1 else False
24
+
25
+ if DBG:
26
+ import utils as custom_utils
27
+ logging.basicConfig(
28
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
29
+ datefmt="%Y-%m-%d %H:%M:%S",
30
+ level=os.environ.get("LOGLEVEL", "DEBUG").upper(),
31
+ stream=sys.stdout,
32
+ )
33
+ else:
34
+ from . import utils as custom_utils
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ def load_audio_visual(manifest_path, max_keep, min_keep, frame_rate, label_paths, label_rates, tol=0.1):
40
+ def is_audio_label_aligned(audio_dur, label_durs):
41
+ return all([abs(audio_dur - label_dur)<tol for label_dur in label_durs])
42
+
43
+ n_long, n_short, n_unaligned = 0, 0, 0
44
+ names, inds, sizes = [], [], []
45
+ dur_from_label_list = []
46
+ is_seq_label = any([x==-1 for x in label_rates])
47
+ for label_path, label_rate in zip(label_paths, label_rates):
48
+ label_lengths = [len(line.rstrip().split())/label_rate for line in open(label_path).readlines()]
49
+ dur_from_label_list.append(label_lengths)
50
+ dur_from_label_list = list(zip(*dur_from_label_list))
51
+
52
+ with open(manifest_path) as f:
53
+ root = f.readline().strip()
54
+ for ind, line in enumerate(f):
55
+ items = line.strip().split("\t")
56
+ sz = int(items[-2]) #
57
+ if min_keep is not None and sz < min_keep:
58
+ n_short += 1
59
+ elif max_keep is not None and sz > max_keep:
60
+ n_long += 1
61
+ elif (not is_seq_label) and (not is_audio_label_aligned(sz/frame_rate, dur_from_label_list[ind])):
62
+ n_unaligned += 1
63
+ else:
64
+ video_path = items[1]
65
+ audio_path = items[2]
66
+ audio_id = items[0]
67
+ names.append((video_path, audio_path+':'+audio_id))
68
+ inds.append(ind)
69
+ sizes.append(sz)
70
+ tot = ind + 1
71
+ logger.info(
72
+ (
73
+ f"max_keep={max_keep}, min_keep={min_keep}, "
74
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long and {n_unaligned} unaligned, "
75
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
76
+ )
77
+ )
78
+ return root, names, inds, tot, sizes
79
+
80
+ def load_label(label_path, inds, tot):
81
+ with open(label_path) as f:
82
+ labels = [line.rstrip() for line in f]
83
+ assert (
84
+ len(labels) == tot
85
+ ), f"number of labels does not match ({len(labels)} != {tot})"
86
+ labels = [labels[i] for i in inds]
87
+ return labels
88
+
89
+
90
+ def load_label_offset(label_path, inds, tot):
91
+ with open(label_path) as f:
92
+ code_lengths = [len(line.encode("utf-8")) for line in f]
93
+ assert (
94
+ len(code_lengths) == tot
95
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
96
+ offsets = list(itertools.accumulate([0] + code_lengths))
97
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
98
+ return offsets
99
+
100
+
101
+ def verify_label_lengths(
102
+ audio_sizes,
103
+ audio_rate,
104
+ label_path,
105
+ label_rate,
106
+ inds,
107
+ tot,
108
+ tol=0.1, # tolerance in seconds
109
+ ):
110
+ if label_rate < 0:
111
+ logger.info(f"{label_path} is sequence label. skipped")
112
+ return
113
+
114
+ with open(label_path) as f:
115
+ lengths = [len(line.rstrip().split()) for line in f]
116
+ assert len(lengths) == tot
117
+ lengths = [lengths[i] for i in inds]
118
+ num_invalid = 0
119
+ for i, ind in enumerate(inds):
120
+ dur_from_audio = audio_sizes[i] / audio_rate
121
+ dur_from_label = lengths[i] / label_rate
122
+ if abs(dur_from_audio - dur_from_label) > tol:
123
+ logger.warning(
124
+ (
125
+ f"audio and label duration differ too much "
126
+ f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
127
+ f"in line {ind+1} of {label_path}. Check if `label_rate` "
128
+ f"is correctly set (currently {label_rate}). "
129
+ f"num. of samples = {audio_sizes[i]}; "
130
+ f"label length = {lengths[i]}"
131
+ )
132
+ )
133
+ num_invalid += 1
134
+ if num_invalid > 0:
135
+ logger.warning(
136
+ f"total {num_invalid} (audio, label) pairs with mismatched lengths"
137
+ )
138
+
139
+
140
+ class AVHubertDataset(FairseqDataset):
141
+ def __init__(
142
+ self,
143
+ manifest_path: str,
144
+ sample_rate: float,
145
+ label_paths: List[str],
146
+ label_rates: Union[List[float], float], # -1 for sequence labels
147
+ pad_list: List[str],
148
+ eos_list: List[str],
149
+ label_processors: Optional[List[Any]] = None,
150
+ max_keep_sample_size: Optional[int] = None,
151
+ min_keep_sample_size: Optional[int] = None,
152
+ max_sample_size: Optional[int] = None,
153
+ shuffle: bool = True,
154
+ pad_audio: bool = False,
155
+ normalize: bool = False,
156
+ store_labels: bool = True,
157
+ random_crop: bool = False,
158
+ single_target: bool = False,
159
+ stack_order_audio: int=1,
160
+ skip_verify: bool=False,
161
+ image_mean: float=0,
162
+ image_std: float=1,
163
+ image_crop_size: int=88,
164
+ image_aug: bool=False,
165
+ modalities: Optional[List[str]]=None,
166
+ is_s2s=False,
167
+ noise_fn=None,
168
+ noise_prob=0,
169
+ noise_snr=0,
170
+ noise_num=1
171
+ ):
172
+ self.label_rates = (
173
+ [label_rates for _ in range(len(label_paths))]
174
+ if isinstance(label_rates, int)
175
+ else label_rates
176
+ )
177
+ self.modalities = set(modalities)
178
+ self.audio_root, self.names, inds, tot, self.sizes = load_audio_visual(manifest_path, max_keep_sample_size, min_keep_sample_size, frame_rate=sample_rate, label_paths=label_paths, label_rates=self.label_rates)
179
+ self.sample_rate = sample_rate
180
+ self.stack_order_audio = stack_order_audio
181
+ self.shuffle = shuffle
182
+ self.random_crop = random_crop
183
+
184
+ self.num_labels = len(label_paths)
185
+ self.pad_list = pad_list
186
+ self.eos_list = eos_list
187
+ self.label_processors = label_processors
188
+ self.single_target = single_target
189
+ self.store_labels = store_labels
190
+ self.is_s2s = is_s2s
191
+ self.noise_wav, self.noise_prob, self.noise_snr, self.noise_num = [ln.strip() for ln in open(noise_fn).readlines()] if noise_fn is not None else [], noise_prob, noise_snr, noise_num
192
+
193
+ assert self.single_target == (self.label_rates[0] == -1), f"single target should be equivalent to sequence label (label_rate==-1)"
194
+ if store_labels:
195
+ self.label_list = [load_label(p, inds, tot) for p in label_paths]
196
+ else:
197
+ self.label_paths = label_paths
198
+ self.label_offsets_list = [
199
+ load_label_offset(p, inds, tot) for p in label_paths
200
+ ]
201
+ assert (
202
+ label_processors is None
203
+ or len(label_processors) == self.num_labels
204
+ )
205
+ if not skip_verify:
206
+ for label_path, label_rate in zip(label_paths, self.label_rates):
207
+ verify_label_lengths(self.sizes, self.sample_rate, label_path, label_rate, inds, tot)
208
+ else:
209
+ logger.info(f"Skip label alignment verifying")
210
+
211
+ self.max_sample_size = (
212
+ max_sample_size if max_sample_size is not None else sys.maxsize
213
+ )
214
+ self.pad_audio = pad_audio
215
+ self.normalize = normalize
216
+ if image_aug:
217
+ self.transform = custom_utils.Compose([
218
+ custom_utils.Normalize( 0.0,255.0 ),
219
+ custom_utils.RandomCrop((image_crop_size, image_crop_size)),
220
+ custom_utils.HorizontalFlip(0.5),
221
+ custom_utils.Normalize(image_mean, image_std) ])
222
+ else:
223
+ self.transform = custom_utils.Compose([
224
+ custom_utils.Normalize( 0.0,255.0 ),
225
+ custom_utils.CenterCrop((image_crop_size, image_crop_size)),
226
+ custom_utils.Normalize(image_mean, image_std) ])
227
+ logger.info(f"image transform: {self.transform}")
228
+
229
+ logger.info(
230
+ f"pad_audio={pad_audio}, random_crop={random_crop}, "
231
+ f"normalize={normalize}, max_sample_size={self.max_sample_size}, "
232
+ f"seqs2seq data={self.is_s2s},")
233
+ logger.info(
234
+ f"Noise wav: {noise_fn}->{len(self.noise_wav)} wav, Prob: {self.noise_prob}, SNR: {self.noise_snr}, Number of mixture: {self.noise_num}"
235
+ )
236
+
237
+ def get_label(self, index, label_idx):
238
+ if self.store_labels:
239
+ label = self.label_list[label_idx][index]
240
+ else:
241
+ with open(self.label_paths[label_idx]) as f:
242
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
243
+ f.seek(offset_s)
244
+ label = f.read(offset_e - offset_s)
245
+
246
+ if self.label_processors is not None:
247
+ label = self.label_processors[label_idx](label)
248
+ return label
249
+
250
+ def get_labels(self, index):
251
+ return [self.get_label(index, i) for i in range(self.num_labels)]
252
+
253
+ def load_feature(self, mix_name):
254
+ """
255
+ Load image and audio feature
256
+ Returns:
257
+ video_feats: numpy.ndarray of shape [T, H, W, 1], audio_feats: numpy.ndarray of shape [T, F]
258
+ """
259
+ def stacker(feats, stack_order):
260
+ """
261
+ Concatenating consecutive audio frames
262
+ Args:
263
+ feats - numpy.ndarray of shape [T, F]
264
+ stack_order - int (number of neighboring frames to concatenate)
265
+ Returns:
266
+ feats - numpy.ndarray of shape [T', F']
267
+ """
268
+ feat_dim = feats.shape[1]
269
+ if len(feats) % stack_order != 0:
270
+ res = stack_order - len(feats) % stack_order
271
+ res = np.zeros([res, feat_dim]).astype(feats.dtype)
272
+ feats = np.concatenate([feats, res], axis=0)
273
+ feats = feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order*feat_dim)
274
+ return feats
275
+ video_fn, audio_fn = mix_name
276
+ if 'video' in self.modalities:
277
+ video_feats = self.load_video(video_fn) # [T, H, W, 1]
278
+ else:
279
+ video_feats = None
280
+ if 'audio' in self.modalities:
281
+ audio_fn = audio_fn.split(':')[0]
282
+ sample_rate, wav_data = wavfile.read(audio_fn)
283
+ assert sample_rate == 16_000 and len(wav_data.shape) == 1
284
+ if np.random.rand() < self.noise_prob:
285
+ wav_data = self.add_noise(wav_data)
286
+ audio_feats = logfbank(wav_data, samplerate=sample_rate).astype(np.float32) # [T, F]
287
+ audio_feats = stacker(audio_feats, self.stack_order_audio) # [T/stack_order_audio, F*stack_order_audio]
288
+ else:
289
+ audio_feats = None
290
+ if audio_feats is not None and video_feats is not None:
291
+ diff = len(audio_feats) - len(video_feats)
292
+ if diff < 0:
293
+ audio_feats = np.concatenate([audio_feats, np.zeros([-diff, audio_feats.shape[-1]], dtype=audio_feats.dtype)])
294
+ elif diff > 0:
295
+ audio_feats = audio_feats[:-diff]
296
+ return video_feats, audio_feats
297
+
298
+ def load_video(self, audio_name):
299
+ feats = custom_utils.load_video(os.path.join(self.audio_root, audio_name))
300
+ feats = self.transform(feats)
301
+ feats = np.expand_dims(feats, axis=-1)
302
+ return feats
303
+
304
+ def select_noise(self):
305
+ rand_indexes = np.random.randint(0, len(self.noise_wav), size=self.noise_num)
306
+ noise_wav = []
307
+ for x in rand_indexes:
308
+ noise_wav.append(wavfile.read(self.noise_wav[x])[1].astype(np.float32))
309
+ if self.noise_num == 1:
310
+ return noise_wav[0]
311
+ else:
312
+ min_len = min([len(x) for x in noise_wav])
313
+ noise_wav = [x[:min_len] for x in noise_wav]
314
+ noise_wav = np.floor(np.stack(noise_wav).mean(axis=0))
315
+ return noise_wav
316
+
317
+ def add_noise(self, clean_wav):
318
+ clean_wav = clean_wav.astype(np.float32)
319
+ noise_wav = self.select_noise()
320
+ if type(self.noise_snr) == int or type(self.noise_snr) == float:
321
+ snr = self.noise_snr
322
+ elif type(self.noise_snr) == tuple:
323
+ snr = np.random.randint(self.noise_snr[0], self.noise_snr[1]+1)
324
+ clean_rms = np.sqrt(np.mean(np.square(clean_wav), axis=-1))
325
+ if len(clean_wav) > len(noise_wav):
326
+ ratio = int(np.ceil(len(clean_wav)/len(noise_wav)))
327
+ noise_wav = np.concatenate([noise_wav for _ in range(ratio)])
328
+ if len(clean_wav) < len(noise_wav):
329
+ start = 0
330
+ noise_wav = noise_wav[start: start + len(clean_wav)]
331
+ noise_rms = np.sqrt(np.mean(np.square(noise_wav), axis=-1))
332
+ adjusted_noise_rms = clean_rms / (10**(snr/20))
333
+ adjusted_noise_wav = noise_wav * (adjusted_noise_rms / noise_rms)
334
+ mixed = clean_wav + adjusted_noise_wav
335
+
336
+ #Avoid clipping noise
337
+ max_int16 = np.iinfo(np.int16).max
338
+ min_int16 = np.iinfo(np.int16).min
339
+ if mixed.max(axis=0) > max_int16 or mixed.min(axis=0) < min_int16:
340
+ if mixed.max(axis=0) >= abs(mixed.min(axis=0)):
341
+ reduction_rate = max_int16 / mixed.max(axis=0)
342
+ else :
343
+ reduction_rate = min_int16 / mixed.min(axis=0)
344
+ mixed = mixed * (reduction_rate)
345
+ mixed = mixed.astype(np.int16)
346
+ return mixed
347
+
348
+ def __getitem__(self, index):
349
+ video_feats, audio_feats = self.load_feature(self.names[index])
350
+ audio_feats, video_feats = torch.from_numpy(audio_feats.astype(np.float32)) if audio_feats is not None else None, torch.from_numpy(video_feats.astype(np.float32)) if video_feats is not None else None
351
+ if self.normalize and 'audio' in self.modalities:
352
+ with torch.no_grad():
353
+ audio_feats = F.layer_norm(audio_feats, audio_feats.shape[1:])
354
+ labels = self.get_labels(index)
355
+ fid = self.names[index][1].split(':')[1]
356
+ return {"id": index, 'fid': fid, "video_source": video_feats, 'audio_source': audio_feats, "label_list": labels}
357
+
358
+ def __len__(self):
359
+ return len(self.sizes)
360
+
361
+ def crop_to_max_size(self, wav, target_size, start=None):
362
+ size = len(wav)
363
+ diff = size - target_size
364
+ if diff <= 0:
365
+ return wav, 0
366
+ # longer utterances
367
+ if start is None:
368
+ start, end = 0, target_size
369
+ if self.random_crop:
370
+ start = np.random.randint(0, diff + 1)
371
+ end = size - diff + start
372
+ else:
373
+ end = start + target_size
374
+ return wav[start:end], start
375
+
376
+ def collater(self, samples):
377
+ samples = [s for s in samples if s["id"] is not None]
378
+ if len(samples) == 0:
379
+ return {}
380
+
381
+ audio_source, video_source = [s["audio_source"] for s in samples], [s["video_source"] for s in samples]
382
+ if audio_source[0] is None:
383
+ audio_source = None
384
+ if video_source[0] is None:
385
+ video_source = None
386
+ if audio_source is not None:
387
+ audio_sizes = [len(s) for s in audio_source]
388
+ else:
389
+ audio_sizes = [len(s) for s in video_source]
390
+ if self.pad_audio:
391
+ audio_size = min(max(audio_sizes), self.max_sample_size)
392
+ else:
393
+ audio_size = min(min(audio_sizes), self.max_sample_size)
394
+ if audio_source is not None:
395
+ collated_audios, padding_mask, audio_starts = self.collater_audio(audio_source, audio_size)
396
+ else:
397
+ collated_audios, audio_starts = None, None
398
+ if video_source is not None:
399
+ collated_videos, padding_mask, audio_starts = self.collater_audio(video_source, audio_size, audio_starts)
400
+ else:
401
+ collated_videos = None
402
+ targets_by_label = [
403
+ [s["label_list"][i] for s in samples]
404
+ for i in range(self.num_labels)
405
+ ]
406
+ targets_list, lengths_list, ntokens_list = self.collater_label(
407
+ targets_by_label, audio_size, audio_starts
408
+ )
409
+ source = {"audio": collated_audios, "video": collated_videos}
410
+ net_input = {"source": source, "padding_mask": padding_mask}
411
+ batch = {
412
+ "id": torch.LongTensor([s["id"] for s in samples]),
413
+ "net_input": net_input,
414
+ "utt_id": [s['fid'] for s in samples]
415
+ }
416
+
417
+ if self.single_target:
418
+ batch["target_lengths"] = lengths_list[0]
419
+ batch["ntokens"] = ntokens_list[0]
420
+ if self.is_s2s:
421
+ batch['target'], net_input['prev_output_tokens'] = targets_list[0][0], targets_list[0][1]
422
+ else:
423
+ batch["target"] = targets_list[0]
424
+ else:
425
+ batch["target_lengths_list"] = lengths_list
426
+ batch["ntokens_list"] = ntokens_list
427
+ batch["target_list"] = targets_list
428
+ return batch
429
+
430
+ def collater_audio(self, audios, audio_size, audio_starts=None):
431
+ audio_feat_shape = list(audios[0].shape[1:])
432
+ collated_audios = audios[0].new_zeros([len(audios), audio_size]+audio_feat_shape)
433
+ padding_mask = (
434
+ torch.BoolTensor(len(audios), audio_size).fill_(False) #
435
+ )
436
+ start_known = audio_starts is not None
437
+ audio_starts = [0 for _ in audios] if not start_known else audio_starts
438
+ for i, audio in enumerate(audios):
439
+ diff = len(audio) - audio_size
440
+ if diff == 0:
441
+ collated_audios[i] = audio
442
+ elif diff < 0:
443
+ assert self.pad_audio
444
+ collated_audios[i] = torch.cat(
445
+ [audio, audio.new_full([-diff]+audio_feat_shape, 0.0)]
446
+ )
447
+ padding_mask[i, diff:] = True
448
+ else:
449
+ collated_audios[i], audio_starts[i] = self.crop_to_max_size(
450
+ audio, audio_size, audio_starts[i] if start_known else None
451
+ )
452
+ if len(audios[0].shape) == 2:
453
+ collated_audios = collated_audios.transpose(1, 2) # [B, T, F] -> [B, F, T]
454
+ else:
455
+ collated_audios = collated_audios.permute((0, 4, 1, 2, 3)).contiguous() # [B, T, H, W, C] -> [B, C, T, H, W]
456
+ return collated_audios, padding_mask, audio_starts
457
+
458
+ def collater_frm_label(
459
+ self, targets, audio_size, audio_starts, label_rate, pad
460
+ ):
461
+ assert label_rate > 0
462
+ s2f = label_rate / self.sample_rate # num label per sample
463
+ frm_starts = [int(round(s * s2f)) for s in audio_starts]
464
+ frm_size = int(round(audio_size * s2f))
465
+ if not self.pad_audio:
466
+ rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
467
+ frm_size = min(frm_size, *rem_size)
468
+ targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)]
469
+ logger.debug(f"audio_starts={audio_starts}")
470
+ logger.debug(f"frame_starts={frm_starts}")
471
+ logger.debug(f"frame_size={frm_size}")
472
+
473
+ lengths = torch.LongTensor([len(t) for t in targets])
474
+ ntokens = lengths.sum().item()
475
+ targets = data_utils.collate_tokens(
476
+ targets, pad_idx=pad, left_pad=False
477
+ )
478
+ return targets, lengths, ntokens
479
+
480
+ def collater_seq_label(self, targets, pad):
481
+ lengths = torch.LongTensor([len(t) for t in targets])
482
+ ntokens = lengths.sum().item()
483
+ targets = data_utils.collate_tokens(
484
+ targets, pad_idx=pad, left_pad=False
485
+ )
486
+ return targets, lengths, ntokens
487
+
488
+ def collater_seq_label_s2s(self, targets, pad):
489
+ lengths = torch.LongTensor([len(t) for t in targets])
490
+ ntokens = lengths.sum().item()
491
+ pad, eos = self.label_processors[0].dictionary.pad(), self.label_processors[0].dictionary.eos()
492
+ targets_ = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False)
493
+ prev_output_tokens = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False, move_eos_to_beginning=True)
494
+ return (targets_, prev_output_tokens), lengths, ntokens
495
+
496
+ def collater_label(self, targets_by_label, audio_size, audio_starts):
497
+ targets_list, lengths_list, ntokens_list = [], [], []
498
+ itr = zip(targets_by_label, self.label_rates, self.pad_list)
499
+ for targets, label_rate, pad in itr:
500
+ if label_rate == -1:
501
+ if self.is_s2s:
502
+ targets, lengths, ntokens = self.collater_seq_label_s2s(targets, pad)
503
+ else:
504
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
505
+ else:
506
+ targets, lengths, ntokens = self.collater_frm_label(
507
+ targets, audio_size, audio_starts, label_rate, pad
508
+ )
509
+ targets_list.append(targets)
510
+ lengths_list.append(lengths)
511
+ ntokens_list.append(ntokens)
512
+ return targets_list, lengths_list, ntokens_list
513
+
514
+ def num_tokens(self, index):
515
+ return self.size(index)
516
+
517
+ def size(self, index):
518
+ if self.pad_audio:
519
+ return self.sizes[index]
520
+ return min(self.sizes[index], self.max_sample_size)
521
+
522
+ def ordered_indices(self):
523
+ if self.shuffle:
524
+ order = [np.random.permutation(len(self))]
525
+ else:
526
+ order = [np.arange(len(self))]
527
+
528
+ order.append(self.sizes)
529
+ return np.lexsort(order)[::-1]
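The stacker helper inside load_feature above turns [T, F] log-filterbank features into [T/stack_order, F*stack_order] by zero-padding T up to a multiple of stack_order and concatenating neighboring frames. Here is a self-contained numpy sketch of that reshaping; the frame count and feature dimension are toy values, not taken from any real manifest.

import numpy as np

def stack_frames(feats: np.ndarray, stack_order: int) -> np.ndarray:
    """Concatenate stack_order consecutive frames along the feature axis,
    zero-padding the tail so T becomes a multiple of stack_order."""
    feat_dim = feats.shape[1]
    if len(feats) % stack_order != 0:
        pad = stack_order - len(feats) % stack_order
        feats = np.concatenate([feats, np.zeros([pad, feat_dim], dtype=feats.dtype)], axis=0)
    return feats.reshape(-1, stack_order * feat_dim)

fbank = np.random.rand(103, 26).astype(np.float32)  # e.g. 103 frames of 26-dim logfbank
stacked = stack_frames(fbank, stack_order=4)
print(stacked.shape)                                 # (26, 104): 103 padded to 104 frames, stacked by 4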
slam_llm/models/avhubert/hubert_pretraining.py ADDED
@@ -0,0 +1,401 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os, glob
9
+ import sys
10
+ from typing import Dict, List, Optional, Tuple
11
+
12
+ import numpy as np
13
+
14
+ from dataclasses import dataclass, field
15
+ from fairseq import metrics, search
16
+ from fairseq.data import Dictionary, encoders
17
+ from fairseq.dataclass.configs import FairseqDataclass
18
+ from fairseq.tasks import register_task
19
+ from fairseq.tasks.fairseq_task import FairseqTask
20
+ from omegaconf import MISSING, II
21
+ import numpy as np
22
+ from argparse import Namespace
23
+
24
+ DBG=True if len(sys.argv) == 1 else False
25
+
26
+ if DBG:
27
+ from hubert_dataset import AVHubertDataset
28
+ from sequence_generator import SequenceGenerator
29
+ else:
30
+ from .hubert_dataset import AVHubertDataset
31
+ from .sequence_generator import SequenceGenerator
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ class LabelEncoder(object):
37
+ def __init__(self, dictionary: Dictionary) -> None:
38
+ self.dictionary = dictionary
39
+
40
+ def __call__(self, label: str) -> List[str]:
41
+ return self.dictionary.encode_line(
42
+ label, append_eos=False, add_if_not_exist=False,
43
+ )
44
+
45
+ class LabelEncoderS2SToken(object):
46
+ def __init__(self, dictionary: Dictionary, bpe_tokenizer) -> None:
47
+ self.bpe_tokenizer = bpe_tokenizer
48
+ self.dictionary = dictionary
49
+
50
+ def __call__(self, label: str) -> List[str]:
51
+ label = self.bpe_tokenizer.encode(label.lower())
52
+ return self.dictionary.encode_line(
53
+ label, append_eos=True, add_if_not_exist=False,
54
+ ).long()
55
+
56
+ def decode(self, tok, symbols_ignore=None):
57
+ tok = self.dictionary.string(tok, extra_symbols_to_ignore=symbols_ignore)
58
+ if self.bpe_tokenizer:
59
+ tok = self.bpe_tokenizer.decode(tok)
60
+ return tok
61
+
62
+ @dataclass
63
+ class AVHubertPretrainingConfig(FairseqDataclass):
64
+ input_modality: str = II("task.input_modality") #??
65
+ data: str = field(
66
+ default=MISSING, metadata={"help": "path to data directory"}
67
+ )
68
+ labels: List[str] = field(
69
+ default_factory=lambda: ["ltr"],
70
+ metadata={
71
+ "help": (
72
+ "extension of the label files to load, frame-level labels for"
73
+ " pre-training, and sequence-level label for fine-tuning"
74
+ )
75
+ },
76
+ )
77
+ label_dir: Optional[str] = field(
78
+ default=None,
79
+ metadata={
80
+ "help": "if set, looks for labels in this directory instead",
81
+ },
82
+ )
83
+ label_rate: int = field(
84
+ default=-1,
85
+ metadata={"help": "label frame rate. -1 for sequence label"},
86
+ )
87
+
88
+ sample_rate: int = field(
89
+ default=16_000,
90
+ metadata={
91
+ "help": "target sample rate. audio files will be up/down "
92
+ "sampled to this rate"
93
+ },
94
+ )
95
+ normalize: bool = field(
96
+ default=False,
97
+ metadata={
98
+ "help": "if set, normalizes input to have 0 mean and unit variance"
99
+ },
100
+ )
101
+ enable_padding: bool = field(
102
+ default=False,
103
+ metadata={"help": "pad shorter samples instead of cropping"},
104
+ )
105
+ max_sample_size: Optional[int] = field(
106
+ default=None,
107
+ metadata={"help": "max sample size to keep in training"},
108
+ )
109
+ min_sample_size: Optional[int] = field(
110
+ default=None,
111
+ metadata={"help": "min sample size to keep in training"},
112
+ )
113
+ max_trim_sample_size: Optional[int] = field(
114
+ default=II("task.max_sample_size"),
115
+ metadata={"help": "max sample size to trim to for batching"},
116
+ )
117
+ single_target: Optional[bool] = field(
118
+ default=False,
119
+ metadata={
120
+ "help": "if set, AddTargetDatasets outputs same keys "
121
+ "as AddTargetDataset"
122
+ },
123
+ )
124
+ random_crop: Optional[bool] = field(
125
+ default=True,
126
+ metadata={"help": "always crop from the beginning if false"},
127
+ )
128
+ pad_audio: Optional[bool] = field(
129
+ default=False,
130
+ metadata={"help": "pad audio to the longest one in the batch if true"},
131
+ )
132
+ pdb: Optional[bool] = field(
133
+ default=False,
134
+ metadata={"help": "pdb"},
135
+ )
136
+ stack_order_audio: int = field(
137
+ default=1,
138
+ metadata={"help": "concatenate n consecutive audio frames for one step"},
139
+ )
140
+ skip_verify: Optional[bool] = field(
141
+ default=False,
142
+ metadata={"help": "skip verifying label-audio alignment"},
143
+ )
144
+ image_aug: bool = field(default=False, metadata={'help': 'image data augmentation'})
145
+ image_crop_size: int = field(
146
+ default=88, metadata={"help": "image ROI size"})
147
+ image_mean: float = field(
148
+ default=0.421, metadata={"help": "image mean"})
149
+ image_std: float = field(
150
+ default=0.165, metadata={"help": "image std"})
151
+ modalities: Optional[List[str]] = field(default_factory=lambda: ["audio", "video"], metadata={'help': 'modalities to load'})
152
+ is_s2s: bool=field(default=False, metadata={'help': 'seq2seq fine-tuning only'})
153
+ tokenizer_bpe_name: Optional[str] = field(default=None, metadata={'help': 'tokenizer model name'})
154
+ tokenizer_bpe_model: Optional[str] = field(default=None, metadata={'help': 'tokenizer model path'})
155
+ noise_wav: Optional[str] = field(default=None, metadata={'help': 'manifest of noise wav files (one wav file path per line)'})
156
+ noise_prob: float = field(default=0, metadata={'help': 'noise probability'})
157
+ noise_snr: Optional[str] = field(default='0', metadata={'help': 'noise SNR in audio'})
158
+ noise_num: int = field(default=1, metadata={'help': 'number of noise wav files to mix'})
159
+ fine_tuning: bool = field(default=False, metadata={"help": "set to true if fine-tuning AV-Hubert"})
160
+
161
+ @register_task("av_hubert_pretraining", dataclass=AVHubertPretrainingConfig)
162
+ class AVHubertPretrainingTask(FairseqTask):
163
+
164
+ cfg: AVHubertPretrainingConfig
165
+
166
+ def __init__(
167
+ self,
168
+ cfg: AVHubertPretrainingConfig,
169
+ ) -> None:
170
+ super().__init__(cfg)
171
+
172
+ logger.info(f"current directory is {os.getcwd()}")
173
+ logger.info(f"AVHubertPretrainingTask Config {cfg}")
174
+
175
+ self.fine_tuning = cfg.fine_tuning
176
+ if cfg.fine_tuning:
177
+ self.state.add_factory("target_dictionary", self.load_dictionaries)
178
+ if cfg.is_s2s:
179
+ self.state.add_factory("s2s_tokenizer", self.load_tokenizer)
180
+ else:
181
+ self.state.add_factory("dictionaries", self.load_dictionaries)
182
+
183
+ self.blank_symbol = "<s>"
184
+
185
+ @property
186
+ def source_dictionary(self) -> Optional[Dictionary]:
187
+ return None # self._source_dictionary
188
+
189
+ @property
190
+ def target_dictionary(self) -> Optional[Dictionary]:
191
+ return self.state.target_dictionary # self._target_dictionary
192
+
193
+ @property
194
+ def dictionaries(self) -> List[Dictionary]:
195
+ return self.state.dictionaries
196
+
197
+ def load_dictionaries(self):
198
+ label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
199
+ dictionaries = [
200
+ Dictionary.load(f"{label_dir}/dict.{label}.txt")
201
+ for label in self.cfg.labels
202
+ ]
203
+ return dictionaries[0] if self.cfg.fine_tuning else dictionaries
204
+
205
+ def load_tokenizer(self):
206
+ bpe_args = Namespace(**{'bpe': self.cfg.tokenizer_bpe_name, f"{self.cfg.tokenizer_bpe_name}_model": self.cfg.tokenizer_bpe_model})
207
+ bpe_tokenizer = encoders.build_bpe(bpe_args)
208
+ return bpe_tokenizer
209
+
210
+ @property
211
+ def s2s_tokenizer(self):
212
+ return self.state.s2s_tokenizer
213
+
214
+ @classmethod
215
+ def setup_task(
216
+ cls, cfg: AVHubertPretrainingConfig, **kwargs
217
+ ) -> "AVHubertPretrainingTask":
218
+ if cfg.pdb:
219
+ import pdb
220
+ pdb.set_trace()
221
+ return cls(cfg)
222
+
223
+ def get_label_dir(self) -> str:
224
+ if self.cfg.label_dir is None:
225
+ return self.cfg.data
226
+ return self.cfg.label_dir
227
+
228
+ def load_dataset(self, split: str, **kwargs) -> None:
229
+ manifest = f"{self.cfg.data}/{split}.tsv"
230
+ dictionaries = [self.target_dictionary] if self.fine_tuning else self.dictionaries
231
+ pad_list = [dictionary.pad() for dictionary in dictionaries]
232
+ eos_list = [dictionary.eos() for dictionary in dictionaries]
233
+ if not self.cfg.is_s2s:
234
+ procs = [LabelEncoder(dictionary) for dictionary in dictionaries]
235
+ else:
236
+ logger.info(f"Using tokenizer")
237
+ bpe_tokenizer = self.s2s_tokenizer
238
+ procs = [LabelEncoderS2SToken(dictionary, bpe_tokenizer) for dictionary in dictionaries]
239
+ paths = [
240
+ f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels
241
+ ]
242
+ image_aug = self.cfg.image_aug if split == 'train' else False
243
+ noise_fn, noise_snr = f"{self.cfg.noise_wav}/{split}.tsv" if self.cfg.noise_wav is not None else None, eval(self.cfg.noise_snr)
244
+ noise_num = self.cfg.noise_num #
245
+ self.datasets[split] = AVHubertDataset(
246
+ manifest,
247
+ sample_rate=self.cfg.sample_rate,
248
+ label_paths=paths,
249
+ label_rates=self.cfg.label_rate,
250
+ pad_list=pad_list,
251
+ eos_list=eos_list,
252
+ label_processors=procs,
253
+ max_keep_sample_size=self.cfg.max_sample_size,
254
+ min_keep_sample_size=self.cfg.min_sample_size,
255
+ max_sample_size=self.cfg.max_trim_sample_size,
256
+ pad_audio=self.cfg.pad_audio,
257
+ normalize=self.cfg.normalize,
258
+ store_labels=False,
259
+ random_crop=self.cfg.random_crop,
260
+ single_target=self.cfg.single_target,
261
+ stack_order_audio=self.cfg.stack_order_audio,
262
+ skip_verify=self.cfg.skip_verify,
263
+ image_mean=self.cfg.image_mean,
264
+ image_std=self.cfg.image_std,
265
+ image_crop_size=self.cfg.image_crop_size,
266
+ image_aug=image_aug,
267
+ modalities=self.cfg.modalities,
268
+ is_s2s=self.cfg.is_s2s,
269
+ noise_fn=noise_fn,
270
+ noise_prob=self.cfg.noise_prob,
271
+ noise_snr=noise_snr,
272
+ noise_num=noise_num
273
+ )
274
+
275
+ def max_positions(self) -> Tuple[int, int]:
276
+ return (sys.maxsize, sys.maxsize)
277
+
278
+ def filter_indices_by_size(
279
+ self, indices: np.array, *args, **kwargs
280
+ ) -> np.array:
281
+ return indices
282
+
283
+ def build_generator(
284
+ self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
285
+ ):
286
+ """
287
+ Build a :class:`~fairseq.SequenceGenerator` instance for this
288
+ task.
289
+ Args:
290
+ models (List[~fairseq.models.FairseqModel]): ensemble of models
291
+ args (fairseq.dataclass.configs.GenerationConfig):
292
+ configuration object (dataclass) for generation
293
+ extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
294
+ through to SequenceGenerator
295
+ prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
296
+ If provided, this function constrains the beam search to
297
+ allowed tokens only at each step. The provided function
298
+ should take 2 arguments: the batch ID (`batch_id: int`)
299
+ and a unidimensional tensor of token ids (`inputs_ids:
300
+ torch.Tensor`). It has to return a `List[int]` with the
301
+ allowed tokens for the next generation step conditioned
302
+ on the previously generated tokens (`inputs_ids`) and
303
+ the batch ID (`batch_id`). This argument is useful for
304
+ constrained generation conditioned on the prefix, as
305
+ described in "Autoregressive Entity Retrieval"
306
+ (https://arxiv.org/abs/2010.00904) and
307
+ https://github.com/facebookresearch/GENRE.
308
+ """
309
+ if getattr(args, "score_reference", False):
310
+ from fairseq.sequence_scorer import SequenceScorer
311
+
312
+ return SequenceScorer(
313
+ self.target_dictionary,
314
+ compute_alignment=getattr(args, "print_alignment", False),
315
+ )
316
+
317
+ # Choose search strategy. Defaults to Beam Search.
318
+ sampling = getattr(args, "sampling", False)
319
+ sampling_topk = getattr(args, "sampling_topk", -1)
320
+ sampling_topp = getattr(args, "sampling_topp", -1.0)
321
+ diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
322
+ diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
323
+ match_source_len = getattr(args, "match_source_len", False)
324
+ diversity_rate = getattr(args, "diversity_rate", -1)
325
+ constrained = getattr(args, "constraints", False)
326
+ if prefix_allowed_tokens_fn is None:
327
+ prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
328
+ if (
329
+ sum(
330
+ int(cond)
331
+ for cond in [
332
+ sampling,
333
+ diverse_beam_groups > 0,
334
+ match_source_len,
335
+ diversity_rate > 0,
336
+ ]
337
+ )
338
+ > 1
339
+ ):
340
+ raise ValueError("Provided Search parameters are mutually exclusive.")
341
+ assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
342
+ assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"
343
+
344
+ if sampling:
345
+ search_strategy = search.Sampling(
346
+ self.target_dictionary, sampling_topk, sampling_topp
347
+ )
348
+ elif diverse_beam_groups > 0:
349
+ search_strategy = search.DiverseBeamSearch(
350
+ self.target_dictionary, diverse_beam_groups, diverse_beam_strength
351
+ )
352
+ elif match_source_len:
353
+ # this is useful for tagging applications where the output
354
+ # length should match the input length, so we hardcode the
355
+ # length constraints for simplicity
356
+ search_strategy = search.LengthConstrainedBeamSearch(
357
+ self.target_dictionary,
358
+ min_len_a=1,
359
+ min_len_b=0,
360
+ max_len_a=1,
361
+ max_len_b=0,
362
+ )
363
+ elif diversity_rate > -1:
364
+ search_strategy = search.DiverseSiblingsSearch(
365
+ self.target_dictionary, diversity_rate
366
+ )
367
+ elif constrained:
368
+ search_strategy = search.LexicallyConstrainedBeamSearch(
369
+ self.target_dictionary, args.constraints
370
+ )
371
+ elif prefix_allowed_tokens_fn:
372
+ search_strategy = search.PrefixConstrainedBeamSearch(
373
+ self.target_dictionary, prefix_allowed_tokens_fn
374
+ )
375
+ else:
376
+ search_strategy = search.BeamSearch(self.target_dictionary)
377
+
378
+ extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
379
+ if seq_gen_cls is None:
380
+ if getattr(args, "print_alignment", False):
381
+ seq_gen_cls = SequenceGeneratorWithAlignment
382
+ extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
383
+ else:
384
+ seq_gen_cls = SequenceGenerator
385
+
386
+ return seq_gen_cls(
387
+ models,
388
+ self.target_dictionary,
389
+ beam_size=getattr(args, "beam", 5),
390
+ max_len_a=getattr(args, "max_len_a", 0),
391
+ max_len_b=getattr(args, "max_len_b", 200),
392
+ min_len=getattr(args, "min_len", 1),
393
+ normalize_scores=(not getattr(args, "unnormalized", False)),
394
+ len_penalty=getattr(args, "lenpen", 1),
395
+ unk_penalty=getattr(args, "unkpen", 0),
396
+ temperature=getattr(args, "temperature", 1.0),
397
+ match_source_len=getattr(args, "match_source_len", False),
398
+ no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
399
+ search_strategy=search_strategy,
400
+ **extra_gen_cls_kwargs,
401
+ )
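Note that load_dataset above passes cfg.noise_snr through eval(), so the config string can encode either a fixed SNR or an inclusive range that AVHubertDataset.add_noise samples from. The snippet below is a hedged sketch of that behavior, mirroring the type checks in add_noise; the helper name pick_snr and the example values are illustrative, not part of the codebase.

import numpy as np

def pick_snr(noise_snr):
    # noise_snr comes from eval(cfg.noise_snr): "0" -> 0, "(0, 10)" -> (0, 10)
    if isinstance(noise_snr, (int, float)):
        return noise_snr
    if isinstance(noise_snr, tuple):
        return np.random.randint(noise_snr[0], noise_snr[1] + 1)
    raise ValueError(f"unsupported noise_snr: {noise_snr!r}")

print(pick_snr(eval("0")))        # fixed 0 dB SNR
print(pick_snr(eval("(0, 10)")))  # random integer SNR drawn from [0, 10] dB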
slam_llm/models/avhubert/infer_s2s.py ADDED
@@ -0,0 +1,318 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import ast
8
+ from itertools import chain
9
+ import logging
10
+ import math
11
+ import os
12
+ import sys
13
+ import json
14
+ import hashlib
15
+ import editdistance
16
+ from argparse import Namespace
17
+
18
+ import numpy as np
19
+ import torch
20
+ from fairseq import checkpoint_utils, options, tasks, utils, distributed_utils
21
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
22
+ from fairseq.logging import progress_bar
23
+ from fairseq.logging.meters import StopwatchMeter, TimeMeter
24
+ from fairseq.models import FairseqLanguageModel
25
+ from omegaconf import DictConfig
26
+
27
+ from pathlib import Path
28
+ import hydra
29
+ from hydra.core.config_store import ConfigStore
30
+ from fairseq.dataclass.configs import (
31
+ CheckpointConfig,
32
+ CommonConfig,
33
+ CommonEvalConfig,
34
+ DatasetConfig,
35
+ DistributedTrainingConfig,
36
+ GenerationConfig,
37
+ FairseqDataclass,
38
+ )
39
+ from dataclasses import dataclass, field, is_dataclass
40
+ from typing import Any, Dict, List, Optional, Tuple, Union
41
+ from omegaconf import OmegaConf
42
+
43
+ logging.root.setLevel(logging.INFO)
44
+ logging.basicConfig(level=logging.INFO)
45
+ logger = logging.getLogger(__name__)
46
+
47
+ config_path = Path(__file__).resolve().parent / "conf"
48
+
49
+ @dataclass
50
+ class OverrideConfig(FairseqDataclass):
51
+ noise_wav: Optional[str] = field(default=None, metadata={'help': 'noise wav file'})
52
+ noise_prob: float = field(default=0, metadata={'help': 'noise probability'})
53
+ noise_snr: float = field(default=0, metadata={'help': 'noise SNR in audio'})
54
+ modalities: List[str] = field(default_factory=lambda: [""], metadata={'help': 'which modality to use'})
55
+ data: Optional[str] = field(default=None, metadata={'help': 'path to test data directory'})
56
+ label_dir: Optional[str] = field(default=None, metadata={'help': 'path to test label directory'})
57
+
58
+ @dataclass
59
+ class InferConfig(FairseqDataclass):
60
+ task: Any = None
61
+ generation: GenerationConfig = GenerationConfig()
62
+ common: CommonConfig = CommonConfig()
63
+ common_eval: CommonEvalConfig = CommonEvalConfig()
64
+ checkpoint: CheckpointConfig = CheckpointConfig()
65
+ distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
66
+ dataset: DatasetConfig = DatasetConfig()
67
+ override: OverrideConfig = OverrideConfig()
68
+ is_ax: bool = field(
69
+ default=False,
70
+ metadata={
71
+ "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
72
+ },
73
+ )
74
+
75
+
76
+ def main(cfg: DictConfig):
77
+
78
+ if isinstance(cfg, Namespace):
79
+ cfg = convert_namespace_to_omegaconf(cfg)
80
+
81
+ assert cfg.common_eval.path is not None, "--path required for recognition!"
82
+ assert (
83
+ not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
84
+ ), "--sampling requires --nbest to be equal to --beam"
85
+
86
+ if cfg.common_eval.results_path is not None:
87
+ os.makedirs(cfg.common_eval.results_path, exist_ok=True)
88
+ output_path = os.path.join(cfg.common_eval.results_path, "decode.log")
89
+ with open(output_path, "w", buffering=1, encoding="utf-8") as h:
90
+ return _main(cfg, h)
91
+ return _main(cfg, sys.stdout)
92
+
93
+
94
+ def get_symbols_to_strip_from_output(generator):
95
+ if hasattr(generator, "symbols_to_strip_from_output"):
96
+ return generator.symbols_to_strip_from_output
97
+ else:
98
+ return {generator.eos, generator.pad}
99
+
100
+ def _main(cfg, output_file):
101
+ logging.basicConfig(
102
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
103
+ datefmt="%Y-%m-%d %H:%M:%S",
104
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
105
+ stream=output_file,
106
+ )
107
+ logger = logging.getLogger("hybrid.speech_recognize")
108
+ if output_file is not sys.stdout: # also print to stdout
109
+ logger.addHandler(logging.StreamHandler(sys.stdout))
110
+
111
+ utils.import_user_module(cfg.common)
112
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([cfg.common_eval.path])
113
+ models = [model.eval().cuda() for model in models]
114
+ saved_cfg.task.modalities = cfg.override.modalities
115
+ task = tasks.setup_task(saved_cfg.task)
116
+
117
+ task.build_tokenizer(saved_cfg.tokenizer)
118
+ task.build_bpe(saved_cfg.bpe)
119
+
120
+ logger.info(cfg)
121
+
122
+ # Fix seed for stochastic decoding
123
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
124
+ np.random.seed(cfg.common.seed)
125
+ utils.set_torch_seed(cfg.common.seed)
126
+
127
+ use_cuda = torch.cuda.is_available()
128
+
129
+ # Set dictionary
130
+ dictionary = task.target_dictionary
131
+
132
+ # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
133
+ task.cfg.noise_prob = cfg.override.noise_prob
134
+ task.cfg.noise_snr = cfg.override.noise_snr
135
+ task.cfg.noise_wav = cfg.override.noise_wav
136
+ if cfg.override.data is not None:
137
+ task.cfg.data = cfg.override.data
138
+ if cfg.override.label_dir is not None:
139
+ task.cfg.label_dir = cfg.override.label_dir
140
+ task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
141
+
142
+ lms = [None]
143
+
144
+ # Optimize ensemble for generation
145
+ for model in chain(models, lms):
146
+ if model is None:
147
+ continue
148
+ if cfg.common.fp16:
149
+ model.half()
150
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
151
+ model.cuda()
152
+ model.prepare_for_inference_(cfg)
153
+
154
+ # Load dataset (possibly sharded)
155
+ itr = task.get_batch_iterator(
156
+ dataset=task.dataset(cfg.dataset.gen_subset),
157
+ max_tokens=cfg.dataset.max_tokens,
158
+ max_sentences=cfg.dataset.batch_size,
159
+ max_positions=utils.resolve_max_positions(
160
+ task.max_positions(), *[m.max_positions() for m in models]
161
+ ),
162
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
163
+ required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
164
+ seed=cfg.common.seed,
165
+ num_shards=cfg.distributed_training.distributed_world_size,
166
+ shard_id=cfg.distributed_training.distributed_rank,
167
+ num_workers=cfg.dataset.num_workers,
168
+ data_buffer_size=cfg.dataset.data_buffer_size,
169
+ ).next_epoch_itr(shuffle=False)
170
+ progress = progress_bar.progress_bar(
171
+ itr,
172
+ log_format=cfg.common.log_format,
173
+ log_interval=cfg.common.log_interval,
174
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
175
+ )
176
+
177
+ # Initialize generator
178
+ if cfg.generation.match_source_len:
179
+ logger.warning(
180
+ "The option match_source_len is not applicable to speech recognition. Ignoring it."
181
+ )
182
+ gen_timer = StopwatchMeter()
183
+ extra_gen_cls_kwargs = {
184
+ "lm_model": lms[0],
185
+ "lm_weight": cfg.generation.lm_weight,
186
+ }
187
+ cfg.generation.score_reference = False  # force regular decoding rather than reference scoring
188
+ save_attention_plot = cfg.generation.print_alignment is not None
189
+ cfg.generation.print_alignment = None  # disable alignment printing inside the generator
190
+ generator = task.build_generator(
191
+ models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
192
+ )
193
+
194
+ def decode_fn(x):
195
+ symbols_ignore = get_symbols_to_strip_from_output(generator)
196
+ symbols_ignore.add(dictionary.pad())
197
+ if hasattr(task.datasets[cfg.dataset.gen_subset].label_processors[0], 'decode'):
198
+ return task.datasets[cfg.dataset.gen_subset].label_processors[0].decode(x, symbols_ignore)
199
+ chars = dictionary.string(x, extra_symbols_to_ignore=symbols_ignore)
200
+ words = " ".join("".join(chars.split()).replace('|', ' ').split())
201
+ return words
202
+
203
+ num_sentences = 0
204
+ has_target = True
205
+ wps_meter = TimeMeter()
206
+ result_dict = {'utt_id': [], 'ref': [], 'hypo': []}
207
+ for sample in progress:
208
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
209
+ if "net_input" not in sample:
210
+ continue
211
+
212
+ prefix_tokens = None
213
+ if cfg.generation.prefix_size > 0:
214
+ prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
215
+
216
+ constraints = None
217
+ if "constraints" in sample:
218
+ constraints = sample["constraints"]
219
+
220
+ gen_timer.start()
221
+ hypos = task.inference_step(
222
+ generator,
223
+ models,
224
+ sample,
225
+ prefix_tokens=prefix_tokens,
226
+ constraints=constraints,
227
+ )
228
+ num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
229
+ gen_timer.stop(num_generated_tokens)
230
+
231
+ for i in range(len(sample["id"])):
232
+ result_dict['utt_id'].append(sample['utt_id'][i])
233
+ ref_sent = decode_fn(sample['target'][i].int().cpu())
234
+ result_dict['ref'].append(ref_sent)
235
+ best_hypo = hypos[i][0]['tokens'].int().cpu()
236
+ hypo_str = decode_fn(best_hypo)
237
+ result_dict['hypo'].append(hypo_str)
238
+ logger.info(f"\nREF:{ref_sent}\nHYP:{hypo_str}\n")
239
+ wps_meter.update(num_generated_tokens)
240
+ progress.log({"wps": round(wps_meter.avg)})
241
+ num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
242
+
243
+ logger.info("NOTE: hypothesis and token scores are output in base 2")
244
+ logger.info("Recognized {:,} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
245
+ num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
246
+
247
+ yaml_str = OmegaConf.to_yaml(cfg.generation)
248
+ fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16)
249
+ fid = fid % 1000000
250
+ result_fn = f"{cfg.common_eval.results_path}/hypo-{fid}.json"
251
+ json.dump(result_dict, open(result_fn, 'w'), indent=4)
252
+ n_err, n_total = 0, 0
253
+ assert len(result_dict['hypo']) == len(result_dict['ref'])
254
+ for hypo, ref in zip(result_dict['hypo'], result_dict['ref']):
255
+ hypo, ref = hypo.strip().split(), ref.strip().split()
256
+ n_err += editdistance.eval(hypo, ref)
257
+ n_total += len(ref)
258
+ wer = 100 * n_err / n_total
259
+ wer_fn = f"{cfg.common_eval.results_path}/wer.{fid}"
260
+ with open(wer_fn, "w") as fo:
261
+ fo.write(f"WER: {wer}\n")
262
+ fo.write(f"err / num_ref_words = {n_err} / {n_total}\n\n")
263
+ fo.write(f"{yaml_str}")
264
+ logger.info(f"WER: {wer}%")
265
+ return
266
+
267
+
268
+ @hydra.main(config_path=config_path, config_name="infer")
269
+ def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]:
270
+ container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
271
+ cfg = OmegaConf.create(container)
272
+ OmegaConf.set_struct(cfg, True)
273
+
274
+ if cfg.common.reset_logging:
275
+ reset_logging()
276
+
277
+ wer = float("inf")
278
+
279
+ try:
280
+ if cfg.common.profile:
281
+ with torch.cuda.profiler.profile():
282
+ with torch.autograd.profiler.emit_nvtx():
283
+ distributed_utils.call_main(cfg, main)
284
+ else:
285
+ distributed_utils.call_main(cfg, main)
286
+
287
+ except BaseException as e: # pylint: disable=broad-except
288
+ if not cfg.common.suppress_crashes:
289
+ raise
290
+ else:
291
+ logger.error("Crashed! %s", str(e))
292
+ return
293
+
294
+
295
+ def cli_main() -> None:
296
+ try:
297
+ from hydra._internal.utils import (
298
+ get_args,
299
+ ) # pylint: disable=import-outside-toplevel
300
+
301
+ cfg_name = get_args().config_name or "infer"
302
+ except ImportError:
303
+ logger.warning("Failed to get config name from hydra args")
304
+ cfg_name = "infer"
305
+
306
+ cs = ConfigStore.instance()
307
+ cs.store(name=cfg_name, node=InferConfig)
308
+
309
+ for k in InferConfig.__dataclass_fields__:
310
+ if is_dataclass(InferConfig.__dataclass_fields__[k].type):
311
+ v = InferConfig.__dataclass_fields__[k].default
312
+ cs.store(name=k, node=v)
313
+
314
+ hydra_main() # pylint: disable=no-value-for-parameter
315
+
316
+
317
+ if __name__ == "__main__":
318
+ cli_main()
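
Note on the decode script above: the reported WER is an edit distance over whitespace-split words, accumulated across utterances and divided by the total number of reference words. A minimal standalone sketch of that metric (the example strings below are illustrative only):

```python
# Minimal sketch of the corpus-level WER computed at the end of _main():
# word-level edit distance summed over utterances / total reference words.
import editdistance

def corpus_wer(hypos, refs):
    assert len(hypos) == len(refs)
    n_err, n_total = 0, 0
    for hypo, ref in zip(hypos, refs):
        hypo_words, ref_words = hypo.strip().split(), ref.strip().split()
        n_err += editdistance.eval(hypo_words, ref_words)
        n_total += len(ref_words)
    return 100 * n_err / n_total

# Illustrative values only:
print(corpus_wer(["a b c"], ["a b d"]))  # ~33.33
```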
slam_llm/models/avhubert/resnet.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import math
9
+ import torch.nn as nn
10
+ import torch
+ from collections import OrderedDict
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ def conv3x3(in_planes, out_planes, stride=1):
16
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17
+ padding=1, bias=False)
18
+
19
+
20
+ def downsample_basic_block( inplanes, outplanes, stride ):
21
+ return nn.Sequential(
22
+ nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False),
23
+ nn.BatchNorm2d(outplanes),
24
+ )
25
+
26
+ def downsample_basic_block_v2( inplanes, outplanes, stride ):
27
+ return nn.Sequential(
28
+ nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False),
29
+ nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False),
30
+ nn.BatchNorm2d(outplanes),
31
+ )
32
+
33
+
34
+
35
+ class BasicBlock(nn.Module):
36
+ expansion = 1
37
+
38
+ def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type = 'relu' ):
39
+ super(BasicBlock, self).__init__()
40
+
41
+ assert relu_type in ['relu','prelu']
42
+
43
+ self.conv1 = conv3x3(inplanes, planes, stride)
44
+ self.bn1 = nn.BatchNorm2d(planes)
45
+
46
+ if relu_type == 'relu':
47
+ self.relu1 = nn.ReLU(inplace=True)
48
+ self.relu2 = nn.ReLU(inplace=True)
49
+ elif relu_type == 'prelu':
50
+ self.relu1 = nn.PReLU(num_parameters=planes)
51
+ self.relu2 = nn.PReLU(num_parameters=planes)
52
+ else:
53
+ raise Exception('relu type not implemented')
54
+
55
+ self.conv2 = conv3x3(planes, planes)
56
+ self.bn2 = nn.BatchNorm2d(planes)
57
+
58
+ self.downsample = downsample
59
+ self.stride = stride
60
+
61
+ def forward(self, x):
62
+ residual = x
63
+ out = self.conv1(x)
64
+ out = self.bn1(out)
65
+ out = self.relu1(out)
66
+ out = self.conv2(out)
67
+ out = self.bn2(out)
68
+ if self.downsample is not None:
69
+ residual = self.downsample(x)
70
+
71
+ out += residual
72
+ out = self.relu2(out)
73
+
74
+ return out
75
+
76
+
77
+ class ResNet(nn.Module):
78
+
79
+ def __init__(self, block, layers, num_classes=1000, relu_type = 'relu', gamma_zero = False, avg_pool_downsample = False):
80
+ self.inplanes = 64
81
+ self.relu_type = relu_type
82
+ self.gamma_zero = gamma_zero
83
+ self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block
84
+
85
+ super(ResNet, self).__init__()
86
+ self.layer1 = self._make_layer(block, 64, layers[0])
87
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
88
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
89
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
90
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
91
+
92
+ for m in self.modules():
93
+ if isinstance(m, nn.Conv2d):
94
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
95
+ m.weight.data.normal_(0, math.sqrt(2. / n))
96
+ elif isinstance(m, nn.BatchNorm2d):
97
+ m.weight.data.fill_(1)
98
+ m.bias.data.zero_()
99
+
100
+ if self.gamma_zero:
101
+ for m in self.modules():
102
+ if isinstance(m, BasicBlock ):
103
+ m.bn2.weight.data.zero_()
104
+
105
+ def _make_layer(self, block, planes, blocks, stride=1):
106
+
107
+
108
+ downsample = None
109
+ if stride != 1 or self.inplanes != planes * block.expansion:
110
+ downsample = self.downsample_block( inplanes = self.inplanes,
111
+ outplanes = planes * block.expansion,
112
+ stride = stride )
113
+
114
+ layers = []
115
+ layers.append(block(self.inplanes, planes, stride, downsample, relu_type = self.relu_type))
116
+ self.inplanes = planes * block.expansion
117
+ for i in range(1, blocks):
118
+ layers.append(block(self.inplanes, planes, relu_type = self.relu_type))
119
+
120
+ return nn.Sequential(*layers)
121
+
122
+ def forward(self, x):
123
+ x = self.layer1(x)
124
+ x = self.layer2(x)
125
+ x = self.layer3(x)
126
+ x = self.layer4(x)
127
+ x = self.avgpool(x)
128
+ x = x.view(x.size(0), -1)
129
+ return x
130
+
131
+ class ResEncoder(nn.Module):
132
+ def __init__(self, relu_type, weights):
133
+ super(ResEncoder, self).__init__()
134
+ self.frontend_nout = 64
135
+ self.backend_out = 512
136
+ frontend_relu = nn.PReLU(num_parameters=self.frontend_nout) if relu_type == 'prelu' else nn.ReLU()
137
+ self.frontend3D = nn.Sequential(
138
+ nn.Conv3d(1, self.frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False),
139
+ nn.BatchNorm3d(self.frontend_nout),
140
+ frontend_relu,
141
+ nn.MaxPool3d( kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))
142
+ self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
143
+ if weights is not None:
144
+ logger.info(f"Load {weights} for resnet")
145
+ std = torch.load(weights, map_location=torch.device('cpu'))['model_state_dict']
146
+ frontend_std, trunk_std = OrderedDict(), OrderedDict()
147
+ for key, val in std.items():
148
+ new_key = '.'.join(key.split('.')[1:])
149
+ if 'frontend3D' in key:
150
+ frontend_std[new_key] = val
151
+ if 'trunk' in key:
152
+ trunk_std[new_key] = val
153
+ self.frontend3D.load_state_dict(frontend_std)
154
+ self.trunk.load_state_dict(trunk_std)
155
+
156
+ def forward(self, x):
157
+ B, C, T, H, W = x.size()
158
+ x = self.frontend3D(x)
159
+ Tnew = x.shape[2]
160
+ x = self.threeD_to_2D_tensor(x)
161
+ x = self.trunk(x)
162
+ x = x.view(B, Tnew, x.size(1))
163
+ x = x.transpose(1, 2).contiguous()
164
+ return x
165
+
166
+ def threeD_to_2D_tensor(self, x):
167
+ n_batch, n_channels, s_time, sx, sy = x.shape
168
+ x = x.transpose(1, 2).contiguous()
169
+ return x.reshape(n_batch*s_time, n_channels, sx, sy)
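
For orientation, a hedged shape check for the ResEncoder defined in this file; the 88x88 crop, batch size, and frame count are illustrative assumptions rather than values fixed by the code:

```python
# Sketch only: feed a dummy grayscale video clip through ResEncoder and inspect shapes.
# Assumes this diff is installed so the module path below resolves.
import torch
from slam_llm.models.avhubert.resnet import ResEncoder

enc = ResEncoder(relu_type="prelu", weights=None)  # weights=None skips checkpoint loading
video = torch.randn(2, 1, 16, 88, 88)              # (B, C=1, T, H, W) grayscale frames
feats = enc(video)
print(feats.shape)                                 # expected: (2, 512, 16) = (B, backend_out, T)
```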
slam_llm/models/avhubert/sequence_generator.py ADDED
@@ -0,0 +1,985 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Dict, List, Optional
9
+ import sys
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from fairseq import search, utils
14
+ from fairseq.data import data_utils
15
+ from fairseq.models import FairseqIncrementalDecoder
16
+ from torch import Tensor
17
+ from fairseq.ngram_repeat_block import NGramRepeatBlock
18
+
19
+
20
+ class SequenceGenerator(nn.Module):
21
+ def __init__(
22
+ self,
23
+ models,
24
+ tgt_dict,
25
+ beam_size=1,
26
+ max_len_a=0,
27
+ max_len_b=200,
28
+ max_len=0,
29
+ min_len=1,
30
+ normalize_scores=True,
31
+ len_penalty=1.0,
32
+ unk_penalty=0.0,
33
+ temperature=1.0,
34
+ match_source_len=False,
35
+ no_repeat_ngram_size=0,
36
+ search_strategy=None,
37
+ eos=None,
38
+ symbols_to_strip_from_output=None,
39
+ lm_model=None,
40
+ lm_weight=1.0,
41
+ ):
42
+ """Generates translations of a given source sentence.
43
+
44
+ Args:
45
+ models (List[~fairseq.models.FairseqModel]): ensemble of models,
46
+ currently support fairseq.models.TransformerModel for scripting
47
+ beam_size (int, optional): beam width (default: 1)
48
+ max_len_a/b (int, optional): generate sequences of maximum length
49
+ ax + b, where x is the source length
50
+ max_len (int, optional): the maximum length of the generated output
51
+ (not including end-of-sentence)
52
+ min_len (int, optional): the minimum length of the generated output
53
+ (not including end-of-sentence)
54
+ normalize_scores (bool, optional): normalize scores by the length
55
+ of the output (default: True)
56
+ len_penalty (float, optional): length penalty, where <1.0 favors
57
+ shorter, >1.0 favors longer sentences (default: 1.0)
58
+ unk_penalty (float, optional): unknown word penalty, where <0
59
+ produces more unks, >0 produces fewer (default: 0.0)
60
+ temperature (float, optional): temperature, where values
61
+ >1.0 produce more uniform samples and values <1.0 produce
62
+ sharper samples (default: 1.0)
63
+ match_source_len (bool, optional): outputs should match the source
64
+ length (default: False)
65
+ """
66
+ super().__init__()
67
+ if isinstance(models, EnsembleModel):
68
+ self.model = models
69
+ else:
70
+ self.model = EnsembleModel(models)
71
+ self.tgt_dict = tgt_dict
72
+ self.pad = tgt_dict.pad()
73
+ self.unk = tgt_dict.unk()
74
+ self.eos = tgt_dict.eos() if eos is None else eos
75
+ self.symbols_to_strip_from_output = (
76
+ symbols_to_strip_from_output.union({self.eos})
77
+ if symbols_to_strip_from_output is not None
78
+ else {self.eos}
79
+ )
80
+ self.vocab_size = len(tgt_dict)
81
+ self.beam_size = beam_size
82
+ # the max beam size is the dictionary size - 1, since we never select pad
83
+ self.beam_size = min(beam_size, self.vocab_size - 1)
84
+ self.max_len_a = max_len_a
85
+ self.max_len_b = max_len_b
86
+ self.min_len = min_len
87
+ self.max_len = max_len or self.model.max_decoder_positions()
88
+
89
+ self.normalize_scores = normalize_scores
90
+ self.len_penalty = len_penalty
91
+ self.unk_penalty = unk_penalty
92
+ self.temperature = temperature
93
+ self.match_source_len = match_source_len
94
+
95
+ if no_repeat_ngram_size > 0:
96
+ self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
97
+ else:
98
+ self.repeat_ngram_blocker = None
99
+
100
+ assert temperature > 0, "--temperature must be greater than 0"
101
+
102
+ self.search = (
103
+ search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
104
+ )
105
+ # We only need to set src_lengths in LengthConstrainedBeamSearch.
106
+ # As a module attribute, setting it would break in multithread
107
+ # settings when the model is shared.
108
+ self.should_set_src_lengths = (
109
+ hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
110
+ )
111
+
112
+ self.model.eval()
113
+
114
+ self.lm_model = lm_model
115
+ self.lm_weight = lm_weight
116
+ if self.lm_model is not None:
117
+ self.lm_model.eval()
118
+
119
+ def cuda(self):
120
+ self.model.cuda()
121
+ return self
122
+
123
+ @torch.no_grad()
124
+ def forward(
125
+ self,
126
+ sample: Dict[str, Dict[str, Tensor]],
127
+ prefix_tokens: Optional[Tensor] = None,
128
+ bos_token: Optional[int] = None,
129
+ ):
130
+ """Generate a batch of translations.
131
+
132
+ Args:
133
+ sample (dict): batch
134
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
135
+ with these tokens
136
+ bos_token (int, optional): beginning of sentence token
137
+ (default: self.eos)
138
+ """
139
+ return self._generate(sample, prefix_tokens, bos_token=bos_token)
140
+
141
+ # TODO(myleott): unused, deprecate after pytorch-translate migration
142
+ def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
143
+ """Iterate over a batched dataset and yield individual translations.
144
+ Args:
145
+ cuda (bool, optional): use GPU for generation
146
+ timer (StopwatchMeter, optional): time generations
147
+ """
148
+ for sample in data_itr:
149
+ s = utils.move_to_cuda(sample) if cuda else sample
150
+ if "net_input" not in s:
151
+ continue
152
+ input = s["net_input"]
153
+ # model.forward normally channels prev_output_tokens into the decoder
154
+ # separately, but SequenceGenerator directly calls model.encoder
155
+ encoder_input = {
156
+ k: v for k, v in input.items() if k != "prev_output_tokens"
157
+ }
158
+ if timer is not None:
159
+ timer.start()
160
+ with torch.no_grad():
161
+ hypos = self.generate(encoder_input)
162
+ if timer is not None:
163
+ timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
164
+ for i, id in enumerate(s["id"].data):
165
+ # remove padding
166
+ src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
167
+ ref = (
168
+ utils.strip_pad(s["target"].data[i, :], self.pad)
169
+ if s["target"] is not None
170
+ else None
171
+ )
172
+ yield id, src, ref, hypos[i]
173
+
174
+ @torch.no_grad()
175
+ def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
176
+ """Generate translations. Match the api of other fairseq generators.
177
+
178
+ Args:
179
+ models (List[~fairseq.models.FairseqModel]): ensemble of models
180
+ sample (dict): batch
181
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
182
+ with these tokens
183
+ constraints (torch.LongTensor, optional): force decoder to include
184
+ the list of constraints
185
+ bos_token (int, optional): beginning of sentence token
186
+ (default: self.eos)
187
+ """
188
+ return self._generate(sample, **kwargs)
189
+
190
+ def _generate(
191
+ self,
192
+ sample: Dict[str, Dict[str, Tensor]],
193
+ prefix_tokens: Optional[Tensor] = None,
194
+ constraints: Optional[Tensor] = None,
195
+ bos_token: Optional[int] = None,
196
+ ):
197
+ incremental_states = torch.jit.annotate(
198
+ List[Dict[str, Dict[str, Optional[Tensor]]]],
199
+ [
200
+ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
201
+ for i in range(self.model.models_size)
202
+ ],
203
+ )
204
+ net_input = sample["net_input"]
205
+
206
+ if "src_tokens" in net_input:
207
+ src_tokens = net_input["src_tokens"]
208
+ # length of the source text being the character length except EndOfSentence and pad
209
+ src_lengths = (
210
+ (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
211
+ )
212
+ elif "source" in net_input:
213
+ src_tokens = net_input["source"]
214
+ src_lengths = (
215
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
216
+ if net_input["padding_mask"] is not None
217
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
218
+ )
219
+ elif "features" in net_input:
220
+ src_tokens = net_input["features"]
221
+ src_lengths = (
222
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
223
+ if net_input["padding_mask"] is not None
224
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
225
+ )
226
+ else:
227
+ raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
228
+
229
+ # bsz: total number of sentences in beam
230
+ # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
231
+ if src_tokens['audio'] is not None:
232
+ bsz, src_len = src_tokens['audio'].size()[:2]
233
+ src_device = src_tokens['audio'].device
234
+ else:
235
+ bsz, src_len = net_input['padding_mask'].size()
236
+ src_device = src_tokens['video'].device
237
+ beam_size = self.beam_size
238
+ if constraints is not None and not self.search.supports_constraints:
239
+ raise NotImplementedError(
240
+ "Target-side constraints were provided, but search method doesn't support them"
241
+ )
242
+
243
+ # Initialize constraints, when active
244
+ self.search.init_constraints(constraints, beam_size)
245
+
246
+ max_len: int = -1
247
+ if self.match_source_len:
248
+ max_len = src_lengths.max().item()
249
+ else:
250
+ max_len = min(
251
+ int(self.max_len_a * src_len + self.max_len_b),
252
+ self.max_len - 1,
253
+ )
254
+ assert (
255
+ self.min_len <= max_len
256
+ ), "min_len cannot be larger than max_len, please adjust these!"
257
+ # compute the encoder output for each beam
258
+ encoder_outs = self.model.forward_encoder(net_input)
259
+
260
+ # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
261
+ new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
262
+ new_order = new_order.to(src_device).long()
263
+ encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
264
+ # ensure encoder_outs is a List.
265
+ assert encoder_outs is not None
266
+
267
+ # initialize buffers
268
+ scores = (
269
+ torch.zeros(bsz * beam_size, max_len + 1).to(src_device).float()
270
+ ) # +1 for eos; pad is never chosen for scoring
271
+ tokens = (
272
+ torch.zeros(bsz * beam_size, max_len + 2)
273
+ .to(src_device)
274
+ .long()
275
+ .fill_(self.pad)
276
+ ) # +2 for eos and pad
277
+ tokens[:, 0] = self.eos if bos_token is None else bos_token
278
+ attn: Optional[Tensor] = None
279
+
280
+ # A list that indicates candidates that should be ignored.
281
+ # For example, suppose we're sampling and have already finalized 2/5
282
+ # samples. Then cands_to_ignore would mark 2 positions as being ignored,
283
+ # so that we only finalize the remaining 3 samples.
284
+ cands_to_ignore = (
285
+ torch.zeros(bsz, beam_size).to(src_device).eq(-1)
286
+ ) # forward and backward-compatible False mask
287
+
288
+ # list of completed sentences
289
+ finalized = torch.jit.annotate(
290
+ List[List[Dict[str, Tensor]]],
291
+ [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
292
+ ) # contains lists of dictionaries of information about the hypothesis being finalized at each step
293
+
294
+ # a boolean array indicating if the sentence at the index is finished or not
295
+ finished = [False for i in range(bsz)]
296
+ num_remaining_sent = bsz # number of sentences remaining
297
+
298
+ # number of candidate hypos per step
299
+ cand_size = 2 * beam_size # 2 x beam size in case half are EOS
300
+
301
+ # offset arrays for converting between different indexing schemes
302
+ bbsz_offsets = (
303
+ (torch.arange(0, bsz) * beam_size)
304
+ .unsqueeze(1)
305
+ .type_as(tokens)
306
+ .to(src_device)
307
+ )
308
+ cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_device)
309
+
310
+ reorder_state: Optional[Tensor] = None
311
+ batch_idxs: Optional[Tensor] = None
312
+
313
+ original_batch_idxs: Optional[Tensor] = None
314
+ if "id" in sample and isinstance(sample["id"], Tensor):
315
+ original_batch_idxs = sample["id"]
316
+ else:
317
+ original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
318
+
319
+ for step in range(max_len + 1): # one extra step for EOS marker
320
+ # reorder decoder internal states based on the prev choice of beams
321
+ if reorder_state is not None:
322
+ if batch_idxs is not None:
323
+ # update beam indices to take into account removed sentences
324
+ corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
325
+ batch_idxs
326
+ )
327
+ reorder_state.view(-1, beam_size).add_(
328
+ corr.unsqueeze(-1) * beam_size
329
+ )
330
+ original_batch_idxs = original_batch_idxs[batch_idxs]
331
+ self.model.reorder_incremental_state(incremental_states, reorder_state)
332
+ encoder_outs = self.model.reorder_encoder_out(
333
+ encoder_outs, reorder_state
334
+ )
335
+
336
+ lprobs, avg_attn_scores = self.model.forward_decoder(
337
+ tokens[:, : step + 1],
338
+ encoder_outs,
339
+ incremental_states,
340
+ self.temperature,
341
+ )
342
+
343
+ if self.lm_model is not None:
344
+ lm_out = self.lm_model(tokens[:, : step + 1])
345
+ probs = self.lm_model.get_normalized_probs(
346
+ lm_out, log_probs=True, sample=None
347
+ )
348
+ probs = probs[:, -1, :] * self.lm_weight
349
+ lprobs += probs
350
+
351
+ lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
352
+
353
+ lprobs[:, self.pad] = -math.inf # never select pad
354
+ lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
355
+
356
+ # handle max length constraint
357
+ if step >= max_len:
358
+ lprobs[:, : self.eos] = -math.inf
359
+ lprobs[:, self.eos + 1 :] = -math.inf
360
+
361
+ # handle prefix tokens (possibly with different lengths)
362
+ if (
363
+ prefix_tokens is not None
364
+ and step < prefix_tokens.size(1)
365
+ and step < max_len
366
+ ):
367
+ lprobs, tokens, scores = self._prefix_tokens(
368
+ step, lprobs, scores, tokens, prefix_tokens, beam_size
369
+ )
370
+ elif step < self.min_len:
371
+ # minimum length constraint (does not apply if using prefix_tokens)
372
+ lprobs[:, self.eos] = -math.inf
373
+
374
+ # Record attention scores, only support avg_attn_scores is a Tensor
375
+ if avg_attn_scores is not None:
376
+ if attn is None:
377
+ attn = torch.empty(
378
+ bsz * beam_size, avg_attn_scores.size(1), max_len + 2
379
+ ).to(scores)
380
+ attn[:, :, step + 1].copy_(avg_attn_scores)
381
+
382
+ scores = scores.type_as(lprobs)
383
+ eos_bbsz_idx = torch.empty(0).to(
384
+ tokens
385
+ ) # indices of hypothesis ending with eos (finished sentences)
386
+ eos_scores = torch.empty(0).to(
387
+ scores
388
+ ) # scores of hypothesis ending with eos (finished sentences)
389
+
390
+ if self.should_set_src_lengths:
391
+ self.search.set_src_lengths(src_lengths)
392
+
393
+ if self.repeat_ngram_blocker is not None:
394
+ lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
395
+
396
+ # Shape: (batch, cand_size)
397
+ cand_scores, cand_indices, cand_beams = self.search.step(
398
+ step,
399
+ lprobs.view(bsz, -1, self.vocab_size),
400
+ scores.view(bsz, beam_size, -1)[:, :, :step],
401
+ tokens[:, : step + 1],
402
+ original_batch_idxs,
403
+ )
404
+
405
+ # cand_bbsz_idx contains beam indices for the top candidate
406
+ # hypotheses, with a range of values: [0, bsz*beam_size),
407
+ # and dimensions: [bsz, cand_size]
408
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
409
+
410
+ # finalize hypotheses that end in eos
411
+ # Shape of eos_mask: (batch size, beam size)
412
+ eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
413
+ eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
414
+
415
+ # only consider eos when it's among the top beam_size indices
416
+ # Now we know what beam item(s) to finish
417
+ # Shape: 1d list of absolute-numbered
418
+ eos_bbsz_idx = torch.masked_select(
419
+ cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
420
+ )
421
+
422
+ finalized_sents: List[int] = []
423
+ if eos_bbsz_idx.numel() > 0:
424
+ eos_scores = torch.masked_select(
425
+ cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
426
+ )
427
+
428
+ finalized_sents = self.finalize_hypos(
429
+ step,
430
+ eos_bbsz_idx,
431
+ eos_scores,
432
+ tokens,
433
+ scores,
434
+ finalized,
435
+ finished,
436
+ beam_size,
437
+ attn,
438
+ src_lengths,
439
+ max_len,
440
+ )
441
+ num_remaining_sent -= len(finalized_sents)
442
+
443
+ assert num_remaining_sent >= 0
444
+ if num_remaining_sent == 0:
445
+ break
446
+ if self.search.stop_on_max_len and step >= max_len:
447
+ break
448
+ assert step < max_len, f"{step} < {max_len}"
449
+
450
+ # Remove finalized sentences (ones for which {beam_size}
451
+ # finished hypotheses have been generated) from the batch.
452
+ if len(finalized_sents) > 0:
453
+ new_bsz = bsz - len(finalized_sents)
454
+
455
+ # construct batch_idxs which holds indices of batches to keep for the next pass
456
+ batch_mask = torch.ones(
457
+ bsz, dtype=torch.bool, device=cand_indices.device
458
+ )
459
+ batch_mask[finalized_sents] = False
460
+ # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
461
+ batch_idxs = torch.arange(
462
+ bsz, device=cand_indices.device
463
+ ).masked_select(batch_mask)
464
+
465
+ # Choose the subset of the hypothesized constraints that will continue
466
+ self.search.prune_sentences(batch_idxs)
467
+
468
+ eos_mask = eos_mask[batch_idxs]
469
+ cand_beams = cand_beams[batch_idxs]
470
+ bbsz_offsets.resize_(new_bsz, 1)
471
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
472
+ cand_scores = cand_scores[batch_idxs]
473
+ cand_indices = cand_indices[batch_idxs]
474
+
475
+ if prefix_tokens is not None:
476
+ prefix_tokens = prefix_tokens[batch_idxs]
477
+ src_lengths = src_lengths[batch_idxs]
478
+ cands_to_ignore = cands_to_ignore[batch_idxs]
479
+
480
+ scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
481
+ tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
482
+ if attn is not None:
483
+ attn = attn.view(bsz, -1)[batch_idxs].view(
484
+ new_bsz * beam_size, attn.size(1), -1
485
+ )
486
+ bsz = new_bsz
487
+ else:
488
+ batch_idxs = None
489
+
490
+ # Set active_mask so that values > cand_size indicate eos hypos
491
+ # and values < cand_size indicate candidate active hypos.
492
+ # After, the min values per row are the top candidate active hypos
493
+
494
+ # Rewrite the operator since the element wise or is not supported in torchscript.
495
+
496
+ eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
497
+ active_mask = torch.add(
498
+ eos_mask.type_as(cand_offsets) * cand_size,
499
+ cand_offsets[: eos_mask.size(1)],
500
+ )
501
+
502
+ # get the top beam_size active hypotheses, which are just
503
+ # the hypos with the smallest values in active_mask.
504
+ # {active_hypos} indicates which {beam_size} hypotheses
505
+ # from the list of {2 * beam_size} candidates were
506
+ # selected. Shapes: (batch size, beam size)
507
+ new_cands_to_ignore, active_hypos = torch.topk(
508
+ active_mask, k=beam_size, dim=1, largest=False
509
+ )
510
+
511
+ # update cands_to_ignore to ignore any finalized hypos.
512
+ cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
513
+ # Make sure there is at least one active item for each sentence in the batch.
514
+ assert (~cands_to_ignore).any(dim=1).all()
515
+
516
+ # update cands_to_ignore to ignore any finalized hypos
517
+
518
+ # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
519
+ # can be selected more than once).
520
+ active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
521
+ active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
522
+
523
+ active_bbsz_idx = active_bbsz_idx.view(-1)
524
+ active_scores = active_scores.view(-1)
525
+
526
+ # copy tokens and scores for active hypotheses
527
+
528
+ # Set the tokens for each beam (can select the same row more than once)
529
+ tokens[:, : step + 1] = torch.index_select(
530
+ tokens[:, : step + 1], dim=0, index=active_bbsz_idx
531
+ )
532
+ # Select the next token for each of them
533
+ tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
534
+ cand_indices, dim=1, index=active_hypos
535
+ )
536
+ if step > 0:
537
+ scores[:, :step] = torch.index_select(
538
+ scores[:, :step], dim=0, index=active_bbsz_idx
539
+ )
540
+ scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
541
+ cand_scores, dim=1, index=active_hypos
542
+ )
543
+
544
+ # Update constraints based on which candidates were selected for the next beam
545
+ self.search.update_constraints(active_hypos)
546
+
547
+ # copy attention for active hypotheses
548
+ if attn is not None:
549
+ attn[:, :, : step + 2] = torch.index_select(
550
+ attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
551
+ )
552
+
553
+ # reorder incremental state in decoder
554
+ reorder_state = active_bbsz_idx
555
+
556
+ # sort by score descending
557
+ for sent in range(len(finalized)):
558
+ scores = torch.tensor(
559
+ [float(elem["score"].item()) for elem in finalized[sent]]
560
+ )
561
+ _, sorted_scores_indices = torch.sort(scores, descending=True)
562
+ finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
563
+ finalized[sent] = torch.jit.annotate(
564
+ List[Dict[str, Tensor]], finalized[sent]
565
+ )
566
+ return finalized
567
+
568
+ def _prefix_tokens(
569
+ self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
570
+ ):
571
+ """Handle prefix tokens"""
572
+ prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
573
+ prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
574
+ prefix_mask = prefix_toks.ne(self.pad)
575
+ lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)
576
+ lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
577
+ -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
578
+ )
579
+ # if prefix includes eos, then we should make sure tokens and
580
+ # scores are the same across all beams
581
+ eos_mask = prefix_toks.eq(self.eos)
582
+ if eos_mask.any():
583
+ # validate that the first beam matches the prefix
584
+ first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
585
+ :, 0, 1 : step + 1
586
+ ]
587
+ eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
588
+ target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
589
+ assert (first_beam == target_prefix).all()
590
+
591
+ # copy tokens, scores and lprobs from the first beam to all beams
592
+ tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
593
+ scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
594
+ lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
595
+ return lprobs, tokens, scores
596
+
597
+ def replicate_first_beam(self, tensor, mask, beam_size: int):
598
+ tensor = tensor.view(-1, beam_size, tensor.size(-1))
599
+ tensor[mask] = tensor[mask][:, :1, :]
600
+ return tensor.view(-1, tensor.size(-1))
601
+
602
+ def finalize_hypos(
603
+ self,
604
+ step: int,
605
+ bbsz_idx,
606
+ eos_scores,
607
+ tokens,
608
+ scores,
609
+ finalized: List[List[Dict[str, Tensor]]],
610
+ finished: List[bool],
611
+ beam_size: int,
612
+ attn: Optional[Tensor],
613
+ src_lengths,
614
+ max_len: int,
615
+ ):
616
+ """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
617
+ A sentence is finalized when {beam_size} finished items have been collected for it.
618
+
619
+ Returns number of sentences (not beam items) being finalized.
620
+ These will be removed from the batch and not processed further.
621
+ Args:
622
+ bbsz_idx (Tensor):
623
+ """
624
+ assert bbsz_idx.numel() == eos_scores.numel()
625
+
626
+ # clone relevant token and attention tensors.
627
+ # tokens is (batch * beam, max_len). So the index_select
628
+ # gets the newly EOS rows, then selects cols 1..{step + 2}
629
+ tokens_clone = tokens.index_select(0, bbsz_idx)[
630
+ :, 1 : step + 2
631
+ ] # skip the first index, which is EOS
632
+
633
+ tokens_clone[:, step] = self.eos
634
+ attn_clone = (
635
+ attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
636
+ if attn is not None
637
+ else None
638
+ )
639
+
640
+ # compute scores per token position
641
+ pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
642
+ pos_scores[:, step] = eos_scores
643
+ # convert from cumulative to per-position scores
644
+ pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
645
+
646
+ # normalize sentence-level scores
647
+ if self.normalize_scores:
648
+ eos_scores /= (step + 1) ** self.len_penalty
649
+
650
+ # cum_unfin records which sentences in the batch are finished.
651
+ # It helps match indexing between (a) the original sentences
652
+ # in the batch and (b) the current, possibly-reduced set of
653
+ # sentences.
654
+ cum_unfin: List[int] = []
655
+ prev = 0
656
+ for f in finished:
657
+ if f:
658
+ prev += 1
659
+ else:
660
+ cum_unfin.append(prev)
661
+
662
+ # The keys here are of the form "{sent}_{unfin_idx}", where
663
+ # "unfin_idx" is the index in the current (possibly reduced)
664
+ # list of sentences, and "sent" is the index in the original,
665
+ # unreduced batch
666
+ # set() is not supported in script export
667
+ sents_seen: Dict[str, Optional[Tensor]] = {}
668
+
669
+ # For every finished beam item
670
+ for i in range(bbsz_idx.size()[0]):
671
+ idx = bbsz_idx[i]
672
+ score = eos_scores[i]
673
+ # sentence index in the current (possibly reduced) batch
674
+ unfin_idx = idx // beam_size
675
+ # sentence index in the original (unreduced) batch
676
+ sent = unfin_idx + cum_unfin[unfin_idx]
677
+ # Cannot create dict for key type '(int, int)' in torchscript.
678
+ # The workaround is to cast int to string
679
+ seen = str(sent.item()) + "_" + str(unfin_idx.item())
680
+ if seen not in sents_seen:
681
+ sents_seen[seen] = None
682
+
683
+ if self.match_source_len and step > src_lengths[unfin_idx]:
684
+ score = torch.tensor(-math.inf).to(score)
685
+
686
+ # An input sentence (among those in a batch) is finished when
687
+ # beam_size hypotheses have been collected for it
688
+ if len(finalized[sent]) < beam_size:
689
+ if attn_clone is not None:
690
+ # remove padding tokens from attn scores
691
+ hypo_attn = attn_clone[i]
692
+ else:
693
+ hypo_attn = torch.empty(0)
694
+
695
+ finalized[sent].append(
696
+ {
697
+ "tokens": tokens_clone[i],
698
+ "score": score,
699
+ "attention": hypo_attn, # src_len x tgt_len
700
+ "alignment": torch.empty(0),
701
+ "positional_scores": pos_scores[i],
702
+ }
703
+ )
704
+
705
+ newly_finished: List[int] = []
706
+
707
+ for seen in sents_seen.keys():
708
+ # check termination conditions for this sentence
709
+ sent: int = int(float(seen.split("_")[0]))
710
+ unfin_idx: int = int(float(seen.split("_")[1]))
711
+
712
+ if not finished[sent] and self.is_finished(
713
+ step, unfin_idx, max_len, len(finalized[sent]), beam_size
714
+ ):
715
+ finished[sent] = True
716
+ newly_finished.append(unfin_idx)
717
+
718
+ return newly_finished
719
+
720
+ def is_finished(
721
+ self,
722
+ step: int,
723
+ unfin_idx: int,
724
+ max_len: int,
725
+ finalized_sent_len: int,
726
+ beam_size: int,
727
+ ):
728
+ """
729
+ Check whether decoding for a sentence is finished, which
730
+ occurs when the list of finalized sentences has reached the
731
+ beam size, or when we reach the maximum length.
732
+ """
733
+ assert finalized_sent_len <= beam_size
734
+ if finalized_sent_len == beam_size or step == max_len:
735
+ return True
736
+ return False
737
+
738
+
739
+ class EnsembleModel(nn.Module):
740
+ """A wrapper around an ensemble of models."""
741
+
742
+ def __init__(self, models):
743
+ super().__init__()
744
+ self.models_size = len(models)
745
+ # method '__len__' is not supported in ModuleList for torch script
746
+ self.single_model = models[0]
747
+ self.models = nn.ModuleList(models)
748
+
749
+ self.has_incremental: bool = False
750
+ if all(
751
+ hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
752
+ for m in models
753
+ ):
754
+ self.has_incremental = True
755
+
756
+ def forward(self):
757
+ pass
758
+
759
+ def has_encoder(self):
760
+ return hasattr(self.single_model, "encoder")
761
+
762
+ def has_incremental_states(self):
763
+ return self.has_incremental
764
+
765
+ def max_decoder_positions(self):
766
+ return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
767
+
768
+ @torch.jit.export
769
+ def forward_encoder(self, net_input: Dict[str, Tensor]):
770
+ if not self.has_encoder():
771
+ return None
772
+ return [model.encoder.forward_torchscript(net_input) for model in self.models]
773
+
774
+ @torch.jit.export
775
+ def forward_decoder(
776
+ self,
777
+ tokens,
778
+ encoder_outs: List[Dict[str, List[Tensor]]],
779
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
780
+ temperature: float = 1.0,
781
+ ):
782
+ log_probs = []
783
+ avg_attn: Optional[Tensor] = None
784
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None
785
+ for i, model in enumerate(self.models):
786
+ if self.has_encoder():
787
+ encoder_out = encoder_outs[i]
788
+ # decode each model
789
+ if self.has_incremental_states():
790
+ decoder_out = model.decoder.forward(
791
+ tokens,
792
+ encoder_out=encoder_out,
793
+ incremental_state=incremental_states[i],
794
+ )
795
+ else:
796
+ if hasattr(model, "decoder"):
797
+ decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
798
+ else:
799
+ decoder_out = model.forward(tokens)
800
+
801
+ attn: Optional[Tensor] = None
802
+ decoder_len = len(decoder_out)
803
+ if decoder_len > 1 and decoder_out[1] is not None:
804
+ if isinstance(decoder_out[1], Tensor):
805
+ attn = decoder_out[1]
806
+ else:
807
+ attn_holder = decoder_out[1]["attn"]
808
+ if isinstance(attn_holder, Tensor):
809
+ attn = attn_holder
810
+ elif attn_holder is not None:
811
+ attn = attn_holder[0]
812
+ if attn is not None:
813
+ attn = attn[:, -1, :]
814
+
815
+ decoder_out_tuple = (
816
+ decoder_out[0][:, -1:, :].div_(temperature),
817
+ None if decoder_len <= 1 else decoder_out[1],
818
+ )
819
+ probs = model.get_normalized_probs(
820
+ decoder_out_tuple, log_probs=True, sample=None
821
+ )
822
+ probs = probs[:, -1, :]
823
+ if self.models_size == 1:
824
+ return probs, attn
825
+
826
+ log_probs.append(probs)
827
+ if attn is not None:
828
+ if avg_attn is None:
829
+ avg_attn = attn
830
+ else:
831
+ avg_attn.add_(attn)
832
+
833
+ avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
834
+ self.models_size
835
+ )
836
+
837
+ if avg_attn is not None:
838
+ avg_attn.div_(self.models_size)
839
+ return avg_probs, avg_attn
840
+
841
+ @torch.jit.export
842
+ def reorder_encoder_out(
843
+ self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
844
+ ):
845
+ """
846
+ Reorder encoder output according to *new_order*.
847
+
848
+ Args:
849
+ encoder_out: output from the ``forward()`` method
850
+ new_order (LongTensor): desired order
851
+
852
+ Returns:
853
+ *encoder_out* rearranged according to *new_order*
854
+ """
855
+ new_outs: List[Dict[str, List[Tensor]]] = []
856
+ if not self.has_encoder():
857
+ return new_outs
858
+ for i, model in enumerate(self.models):
859
+ assert encoder_outs is not None
860
+ new_outs.append(
861
+ model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
862
+ )
863
+ return new_outs
864
+
865
+ @torch.jit.export
866
+ def reorder_incremental_state(
867
+ self,
868
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
869
+ new_order,
870
+ ):
871
+ if not self.has_incremental_states():
872
+ return
873
+ for i, model in enumerate(self.models):
874
+ model.decoder.reorder_incremental_state_scripting(
875
+ incremental_states[i], new_order
876
+ )
877
+
878
+
879
+ class SequenceGeneratorWithAlignment(SequenceGenerator):
880
+ def __init__(
881
+ self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
882
+ ):
883
+ """Generates translations of a given source sentence.
884
+
885
+ Produces alignments following "Jointly Learning to Align and
886
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
887
+
888
+ Args:
889
+ left_pad_target (bool, optional): Whether or not the
890
+ hypothesis should be left padded or not when they are
891
+ teacher forced for generating alignments.
892
+ """
893
+ super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
894
+ self.left_pad_target = left_pad_target
895
+
896
+ if print_alignment == "hard":
897
+ self.extract_alignment = utils.extract_hard_alignment
898
+ elif print_alignment == "soft":
899
+ self.extract_alignment = utils.extract_soft_alignment
900
+
901
+ @torch.no_grad()
902
+ def generate(self, models, sample, **kwargs):
903
+ finalized = super()._generate(sample, **kwargs)
904
+
905
+ src_tokens = sample["net_input"]["src_tokens"]
906
+ bsz = src_tokens.shape[0]
907
+ beam_size = self.beam_size
908
+ (
909
+ src_tokens,
910
+ src_lengths,
911
+ prev_output_tokens,
912
+ tgt_tokens,
913
+ ) = self._prepare_batch_for_alignment(sample, finalized)
914
+ if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
915
+ attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
916
+ else:
917
+ attn = [
918
+ finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
919
+ for i in range(bsz * beam_size)
920
+ ]
921
+
922
+ if src_tokens.device != "cpu":
923
+ src_tokens = src_tokens.to("cpu")
924
+ tgt_tokens = tgt_tokens.to("cpu")
925
+ attn = [i.to("cpu") for i in attn]
926
+
927
+ # Process the attn matrix to extract hard alignments.
928
+ for i in range(bsz * beam_size):
929
+ alignment = self.extract_alignment(
930
+ attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
931
+ )
932
+ finalized[i // beam_size][i % beam_size]["alignment"] = alignment
933
+ return finalized
934
+
935
+ def _prepare_batch_for_alignment(self, sample, hypothesis):
936
+ src_tokens = sample["net_input"]["src_tokens"]
937
+ bsz = src_tokens.shape[0]
938
+ src_tokens = (
939
+ src_tokens[:, None, :]
940
+ .expand(-1, self.beam_size, -1)
941
+ .contiguous()
942
+ .view(bsz * self.beam_size, -1)
943
+ )
944
+ src_lengths = sample["net_input"]["src_lengths"]
945
+ src_lengths = (
946
+ src_lengths[:, None]
947
+ .expand(-1, self.beam_size)
948
+ .contiguous()
949
+ .view(bsz * self.beam_size)
950
+ )
951
+ prev_output_tokens = data_utils.collate_tokens(
952
+ [beam["tokens"] for example in hypothesis for beam in example],
953
+ self.pad,
954
+ self.eos,
955
+ self.left_pad_target,
956
+ move_eos_to_beginning=True,
957
+ )
958
+ tgt_tokens = data_utils.collate_tokens(
959
+ [beam["tokens"] for example in hypothesis for beam in example],
960
+ self.pad,
961
+ self.eos,
962
+ self.left_pad_target,
963
+ move_eos_to_beginning=False,
964
+ )
965
+ return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
966
+
967
+
968
+ class EnsembleModelWithAlignment(EnsembleModel):
969
+ """A wrapper around an ensemble of models."""
970
+
971
+ def __init__(self, models):
972
+ super().__init__(models)
973
+
974
+ def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
975
+ avg_attn = None
976
+ for model in self.models:
977
+ decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
978
+ attn = decoder_out[1]["attn"][0]
979
+ if avg_attn is None:
980
+ avg_attn = attn
981
+ else:
982
+ avg_attn.add_(attn)
983
+ if len(self.models) > 1:
984
+ avg_attn.div_(len(self.models))
985
+ return avg_attn
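
A small numeric sketch of the length normalization applied in finalize_hypos above (cumulative log-probability divided by length ** len_penalty); the probabilities are made up for illustration:

```python
# Sketch of sentence-level score normalization: cumulative log-prob / len ** len_penalty.
# As the constructor docstring notes, len_penalty > 1.0 favors longer hypotheses, < 1.0 shorter.
import math

def normalized_score(cum_logprob, length, len_penalty=1.0):
    return cum_logprob / (length ** len_penalty)

short = normalized_score(5 * math.log(0.5), length=5)   # 5 tokens, average prob 0.5
long = normalized_score(8 * math.log(0.6), length=8)    # 8 tokens, average prob 0.6
print(short, long)  # the longer, higher-average-probability hypothesis wins here
```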
slam_llm/models/avhubert/utils.py ADDED
@@ -0,0 +1,298 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import cv2
8
+ import torch
9
+ import random
10
+ import numpy as np
11
+ from typing import Dict, List, Optional, Tuple
12
+
13
+ def load_video(path):
14
+ for i in range(3):
15
+ try:
16
+ cap = cv2.VideoCapture(path)
17
+ frames = []
18
+ while True:
19
+ ret, frame = cap.read()
20
+ if ret:
21
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
22
+ frames.append(frame)
23
+ else:
24
+ break
25
+ frames = np.stack(frames)
26
+ return frames
27
+ except Exception:
28
+ print(f"failed loading {path} ({i} / 3)")
29
+ if i == 2:
30
+ raise ValueError(f"Unable to load {path}")
31
+
32
+
33
+ class Compose(object):
34
+ """Compose several preprocess together.
35
+ Args:
36
+ preprocess (list of ``Preprocess`` objects): list of preprocess to compose.
37
+ """
38
+
39
+ def __init__(self, preprocess):
40
+ self.preprocess = preprocess
41
+
42
+ def __call__(self, sample):
43
+ for t in self.preprocess:
44
+ sample = t(sample)
45
+ return sample
46
+
47
+ def __repr__(self):
48
+ format_string = self.__class__.__name__ + '('
49
+ for t in self.preprocess:
50
+ format_string += '\n'
51
+ format_string += ' {0}'.format(t)
52
+ format_string += '\n)'
53
+ return format_string
54
+
55
+
56
+ class Normalize(object):
57
+ """Normalize a ndarray image with mean and standard deviation.
58
+ """
59
+
60
+ def __init__(self, mean, std):
61
+ self.mean = mean
62
+ self.std = std
63
+
64
+ def __call__(self, frames):
65
+ """
66
+ Args:
67
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
68
+ Returns:
69
+ Tensor: Normalized Tensor image.
70
+ """
71
+ frames = (frames - self.mean) / self.std
72
+ return frames
73
+
74
+ def __repr__(self):
75
+ return self.__class__.__name__+'(mean={0}, std={1})'.format(self.mean, self.std)
76
+
77
+ class CenterCrop(object):
78
+ """Crop the given image at the center
79
+ """
80
+ def __init__(self, size):
81
+ self.size = size
82
+
83
+ def __call__(self, frames):
84
+ """
85
+ Args:
86
+ img (numpy.ndarray): Images to be cropped.
87
+ Returns:
88
+ numpy.ndarray: Cropped image.
89
+ """
90
+ t, h, w = frames.shape
91
+ th, tw = self.size
92
+ delta_w = int(round((w - tw))/2.)
93
+ delta_h = int(round((h - th))/2.)
94
+ frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw]
95
+ return frames
96
+
97
+
98
+ class RandomCrop(object):
100
+ """Crop the given image at a random location
100
+ """
101
+
102
+ def __init__(self, size):
103
+ self.size = size
104
+
105
+ def __call__(self, frames):
106
+ """
107
+ Args:
108
+ img (numpy.ndarray): Images to be cropped.
109
+ Returns:
110
+ numpy.ndarray: Cropped image.
111
+ """
112
+ t, h, w = frames.shape
113
+ th, tw = self.size
114
+ delta_w = random.randint(0, w-tw)
115
+ delta_h = random.randint(0, h-th)
116
+ frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw]
117
+ return frames
118
+
119
+ def __repr__(self):
120
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
121
+
122
+ class HorizontalFlip(object):
123
+ """Flip image horizontally.
124
+ """
125
+
126
+ def __init__(self, flip_ratio):
127
+ self.flip_ratio = flip_ratio
128
+
129
+ def __call__(self, frames):
130
+ """
131
+ Args:
132
+ img (numpy.ndarray): Images to be flipped with a probability flip_ratio
133
+ Returns:
134
+ numpy.ndarray: Flipped image.
135
+ """
136
+ t, h, w = frames.shape
137
+ if random.random() < self.flip_ratio:
138
+ for index in range(t):
139
+ frames[index] = cv2.flip(frames[index], 1)
140
+ return frames
141
+
142
+ def compute_mask_indices(
143
+ shape: Tuple[int, int],
144
+ padding_mask: Optional[torch.Tensor],
145
+ mask_prob: float,
146
+ mask_length: int,
147
+ mask_type: str = "static",
148
+ mask_other: float = 0.0,
149
+ min_masks: int = 0,
150
+ no_overlap: bool = False,
151
+ min_space: int = 0,
152
+ ) -> np.ndarray:
153
+ """
154
+ Computes random mask spans for a given shape
155
+ Args:
156
+ shape: the the shape for which to compute masks.
157
+ should be of size 2 where first element is batch size and 2nd is timesteps
158
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
159
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
160
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
161
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
162
+ mask_type: how to compute mask lengths
163
+ static = fixed size
164
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
165
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
166
+ poisson = sample from possion distribution with lambda = mask length
167
+ min_masks: minimum number of masked spans
168
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
169
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
170
+ """
171
+
172
+ bsz, all_sz = shape
173
+ mask = np.full((bsz, all_sz), False)
174
+
175
+ all_num_mask = int(
176
+ # add a random number for probabilistic rounding
177
+ mask_prob * all_sz / float(mask_length)
178
+ + np.random.rand()
179
+ )
180
+
181
+ all_num_mask = max(min_masks, all_num_mask)
182
+
183
+ mask_idcs = []
184
+ for i in range(bsz):
185
+ if padding_mask is not None:
186
+ sz = all_sz - padding_mask[i].long().sum().item()
187
+ num_mask = int(
188
+ # add a random number for probabilistic rounding
189
+ mask_prob * sz / float(mask_length)
190
+ + np.random.rand()
191
+ )
192
+ num_mask = max(min_masks, num_mask)
193
+ else:
194
+ sz = all_sz
195
+ num_mask = all_num_mask
196
+
197
+ if mask_type == "static":
198
+ lengths = np.full(num_mask, mask_length)
199
+ elif mask_type == "uniform":
200
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
201
+ elif mask_type == "normal":
202
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
203
+ lengths = [max(1, int(round(x))) for x in lengths]
204
+ elif mask_type == "poisson":
205
+ lengths = np.random.poisson(mask_length, size=num_mask)
206
+ lengths = [int(round(x)) for x in lengths]
207
+ else:
208
+ raise Exception("unknown mask selection " + mask_type)
209
+
210
+ if sum(lengths) == 0:
211
+ lengths[0] = min(mask_length, sz - 1)
212
+
213
+ if no_overlap:
214
+ mask_idc = []
215
+
216
+ def arrange(s, e, length, keep_length):
217
+ span_start = np.random.randint(s, e - length)
218
+ mask_idc.extend(span_start + i for i in range(length))
219
+
220
+ new_parts = []
221
+ if span_start - s - min_space >= keep_length:
222
+ new_parts.append((s, span_start - min_space + 1))
223
+ if e - span_start - keep_length - min_space > keep_length:
224
+ new_parts.append((span_start + length + min_space, e))
225
+ return new_parts
226
+
227
+ parts = [(0, sz)]
228
+ min_length = min(lengths)
229
+ for length in sorted(lengths, reverse=True):
230
+ lens = np.fromiter(
231
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
232
+ int,  # np.int was removed from NumPy; the builtin int keeps the original behaviour
233
+ )
234
+ l_sum = np.sum(lens)
235
+ if l_sum == 0:
236
+ break
237
+ probs = lens / np.sum(lens)
238
+ c = np.random.choice(len(parts), p=probs)
239
+ s, e = parts.pop(c)
240
+ parts.extend(arrange(s, e, length, min_length))
241
+ mask_idc = np.asarray(mask_idc)
242
+ else:
243
+ min_len = min(lengths)
244
+ if sz - min_len <= num_mask:
245
+ min_len = sz - num_mask - 1
246
+
247
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
248
+
249
+ mask_idc = np.asarray(
250
+ [
251
+ mask_idc[j] + offset
252
+ for j in range(len(mask_idc))
253
+ for offset in range(lengths[j])
254
+ ]
255
+ )
256
+
257
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
258
+
259
+ min_len = min([len(m) for m in mask_idcs])
260
+ batch_indexes, starts, ends = [], [], []
261
+ for i, mask_idc in enumerate(mask_idcs):
262
+ if len(mask_idc) > min_len:
263
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
264
+ mask[i, mask_idc] = True
265
+ vals, run_starts, run_lengths = find_runs(mask[i])
266
+ start_indices, lengths = run_starts[vals == True], run_lengths[vals == True]
267
+ starts.append(start_indices)
268
+ ends.append(start_indices+lengths)
269
+ batch_indexes.append(np.zeros([len(start_indices)])+i)
270
+ return mask, np.concatenate(starts).astype(np.int64), np.concatenate(ends).astype(np.int64), np.concatenate(batch_indexes).astype(np.int64)
271
+
272
+ def find_runs(x):
273
+ """Find runs of consecutive items in an array."""
274
+
275
+ # ensure array
276
+ x = np.asanyarray(x)
277
+ if x.ndim != 1:
278
+ raise ValueError('only 1D array supported')
279
+ n = x.shape[0]
280
+
281
+ # handle empty array
282
+ if n == 0:
283
+ return np.array([]), np.array([]), np.array([])
284
+
285
+ else:
286
+ # find run starts
287
+ loc_run_start = np.empty(n, dtype=bool)
288
+ loc_run_start[0] = True
289
+ np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
290
+ run_starts = np.nonzero(loc_run_start)[0]
291
+
292
+ # find run values
293
+ run_values = x[loc_run_start]
294
+
295
+ # find run lengths
296
+ run_lengths = np.diff(np.append(run_starts, n))
297
+
298
+ return run_values, run_starts, run_lengths
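For reference, a minimal usage sketch of the span-masking helpers above (an editorial addition, not part of the commit). It assumes the module is importable as slam_llm.models.avhubert.utils and that its own dependencies (numpy, torch, cv2) are installed:

# Sketch: exercising compute_mask_indices() and find_runs() defined above.
import numpy as np
from slam_llm.models.avhubert.utils import compute_mask_indices, find_runs

mask, starts, ends, batch_idx = compute_mask_indices(
    shape=(2, 100),        # (batch, timesteps)
    padding_mask=None,
    mask_prob=0.3,
    mask_length=10,
)
print(mask.shape)                    # (2, 100) boolean mask
print(list(zip(starts, ends)))       # start/end of each masked span, batch given by batch_idx

vals, run_starts, run_lens = find_runs(mask[0])
print(run_starts[vals], run_lens[vals])   # starts and lengths of the True runs in row 0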
slam_llm/models/encoder.py ADDED
@@ -0,0 +1,158 @@
1
+ import types
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from dataclasses import dataclass
6
+
7
+ class WhisperWrappedEncoder:
8
+
9
+ @classmethod
10
+ def load(cls, model_config):
11
+
12
+ def extract_variable_length_features(self, x: torch.Tensor):
13
+ """
14
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
15
+ the mel spectrogram of the audio
16
+ """
17
+ x = F.gelu(self.conv1(x))
18
+ x = F.gelu(self.conv2(x))
19
+ x = x.permute(0, 2, 1)
20
+
21
+ # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
22
+ # x = (x + self.positional_embedding).to(x.dtype)
23
+ x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
24
+
25
+ for block in self.blocks:
26
+ x = block(x)
27
+
28
+ x = self.ln_post(x)
29
+ return x
30
+
31
+ import whisper
32
+ encoder = whisper.load_model(name=model_config.encoder_path, device='cpu').encoder
33
+ encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder)
34
+ return encoder
35
+
36
+
37
+ class BEATsEncoder:
38
+
39
+ @classmethod
40
+ def load(cls, model_config):
41
+ from .BEATs.BEATs import BEATs, BEATsConfig
42
+ checkpoint = torch.load(model_config.encoder_path)
43
+ cfg = BEATsConfig(checkpoint['cfg'])
44
+ BEATs_model = BEATs(cfg)
45
+ BEATs_model.load_state_dict(checkpoint['model'])
46
+
47
+ return BEATs_model
48
+
49
+
50
+ @dataclass
51
+ class UserDirModule:
52
+ user_dir: str
53
+
54
+ class EATEncoder:
55
+
56
+ @classmethod
57
+ def load(cls, model_config):
58
+ import fairseq
59
+ model_path = UserDirModule(model_config.encoder_fairseq_dir)
60
+ fairseq.utils.import_user_module(model_path)
61
+ EATEncoder, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
62
+ EATEncoder = EATEncoder[0]
63
+
64
+ return EATEncoder
65
+
66
+ def extract_features(self, source, padding_mask):
67
+ return self.model.extract_features(source, padding_mask = padding_mask, mask=False, remove_extra_tokens = False)['x']
68
+
69
+ class SpatialASTEncoder:
70
+ @classmethod
71
+ def load(cls, model_config):
72
+ from functools import partial
73
+ from .SpatialAST import SpatialAST
74
+ binaural_encoder = SpatialAST.BinauralEncoder(
75
+ num_classes=355, drop_path_rate=0.1, num_cls_tokens=3,
76
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
77
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
78
+ )
79
+
80
+ checkpoint = torch.load(model_config.encoder_ckpt, map_location='cpu')
81
+ binaural_encoder.load_state_dict(checkpoint['model'], strict=False)
82
+ return binaural_encoder
83
+
84
+ class WavLMEncoder(nn.Module):
85
+ def __init__(self, config, model):
86
+ super().__init__()
87
+ self.config = config
88
+ self.model = model
89
+
90
+ @classmethod
91
+ def load(cls, model_config):
92
+ from .wavlm.WavLM import WavLM, WavLMConfig
93
+ checkpoint = torch.load(model_config.encoder_path)
94
+ cfg = WavLMConfig(checkpoint['cfg'])
95
+ WavLM_model = WavLM(cfg)
96
+ WavLM_model.load_state_dict(checkpoint['model'])
97
+ assert model_config.normalize == cfg.normalize, "normalize flag in config and model checkpoint do not match"
98
+
99
+ return cls(cfg, WavLM_model)
100
+
101
+ def extract_features(self, source, padding_mask):
102
+ return self.model.extract_features(source, padding_mask)[0]
103
+
104
+ class AVHubertEncoder:
105
+
106
+ @classmethod
107
+ def load(cls, model_config):
108
+ import fairseq
109
+ from .avhubert import hubert_pretraining, hubert, hubert_asr
110
+ models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
111
+ model = models[0]
112
+ return model
113
+
114
+ class HubertEncoder:
115
+
116
+ @classmethod
117
+ def load(cls, model_config):
118
+ import fairseq
119
+ models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
120
+ model = models[0]
121
+ if model_config.encoder_type == "pretrain":
122
+ pass
123
+ elif model_config.encoder_type == "finetune":
124
+ model.w2v_encoder.proj = None
125
+ model.w2v_encoder.apply_mask = False
126
+ else:
127
+ assert model_config.encoder_type in ["pretrain", "finetune"], "encoder_type must be one of [pretrain, finetune]"
128
+ return model
129
+
130
+
131
+ class HfTextEncoder:
132
+
133
+ @classmethod
134
+ def load(cls, model_config):
135
+ from transformers import AutoModel
136
+ model = AutoModel.from_pretrained(model_config.encoder_path)
137
+ return model
138
+
139
+ class MusicFMEncoder(nn.Module):
140
+ def __init__(self, config, model):
141
+ super().__init__()
142
+ self.config = config
143
+ self.model = model
144
+
145
+ @classmethod
146
+ def load(cls, model_config):
147
+ from .musicfm.model.musicfm_25hz import MusicFM25Hz
148
+ model = MusicFM25Hz(
149
+ stat_path = model_config.encoder_stat_path,
150
+ model_path = model_config.encoder_path,
151
+ w2v2_config_path = model_config.get('encoder_config_path', "facebook/wav2vec2-conformer-rope-large-960h-ft")
152
+ )
153
+ return cls(model_config, model)
154
+
155
+ def extract_features(self, source, padding_mask=None):
156
+ _, hidden_states = self.model.get_predictions(source)
157
+ out = hidden_states[self.config.encoder_layer_idx]
158
+ return out
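The encoder wrappers above all expose a load(model_config) classmethod that reads a handful of config attributes (encoder_path, and for some wrappers encoder_fairseq_dir, encoder_type, normalize). A hedged sketch of loading the Whisper wrapper, with a stand-in namespace config and a placeholder checkpoint name:

# Sketch only: the config object and checkpoint name are illustrative, not prescribed by the repo.
from types import SimpleNamespace
from slam_llm.models.encoder import WhisperWrappedEncoder

model_config = SimpleNamespace(encoder_path="large-v3")   # a whisper model name or a path to a .pt checkpoint
encoder = WhisperWrappedEncoder.load(model_config)

# The patched method accepts mel inputs of arbitrary length, because the positional
# embedding is sliced to x.shape[1] instead of being asserted against a fixed n_ctx:
# feats = encoder.extract_variable_length_features(mel)   # mel: (batch, n_mels, n_frames)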
slam_llm/models/musicfm/model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+
2
+
slam_llm/models/musicfm/model/musicfm_25hz.py ADDED
@@ -0,0 +1,253 @@
1
+ # MIT License
2
+ #
3
+ # Copyright 2023 ByteDance Inc.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
6
+ # to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
10
+ #
11
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
13
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
14
+ # IN THE SOFTWARE.
15
+
16
+ import json
17
+ import random
18
+ import torch
19
+ from torch import nn
20
+ from einops import rearrange
21
+
22
+ from ..modules.random_quantizer import RandomProjectionQuantizer
23
+ from ..modules.features import MelSTFT
24
+ from ..modules.conv import Conv2dSubsampling
25
+
26
+
27
+ class MusicFM25Hz(nn.Module):
28
+ """
29
+ MusicFM
30
+
31
+ Input: 128-band mel spectrogram
32
+ Frontend: 2-layer Residual convolution
33
+ Backend: 12-layer Conformer
34
+ Quantizer: a codebook for mel spectrogram
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ num_codebooks=1,
40
+ codebook_dim=16,
41
+ codebook_size=4096,
42
+ features=["melspec_2048"],
43
+ hop_length=240,
44
+ n_mels=128,
45
+ conv_dim=512,
46
+ encoder_dim=1024,
47
+ encoder_depth=12,
48
+ mask_hop=0.4,
49
+ mask_prob=0.6,
50
+ is_flash=False,
51
+ stat_path="./data/fma_stats.json",
52
+ model_path="./data/pretrained_fma.pt",
53
+ w2v2_config_path="facebook/wav2vec2-conformer-rope-large-960h-ft",
54
+ ):
55
+ super(MusicFM25Hz, self).__init__()
56
+
57
+ # global variables
58
+ self.hop_length = hop_length
59
+ self.mask_hop = mask_hop
60
+ self.mask_prob = mask_prob
61
+ self.num_codebooks = num_codebooks
62
+ self.codebook_size = codebook_size
63
+ self.features = features
64
+
65
+ # load feature mean / std stats
66
+ with open(stat_path, "r") as f:
67
+ self.stat = json.load(f)
68
+
69
+ # feature extractor
70
+ self.preprocessor_melspec_2048 = MelSTFT(
71
+ n_fft=2048, hop_length=hop_length, is_db=True
72
+ )
73
+
74
+ # random quantizer
75
+ seed = 142
76
+ for feature in self.features:
77
+ for i in range(num_codebooks):
78
+ setattr(
79
+ self,
80
+ f"quantizer_{feature}_{i}",
81
+ RandomProjectionQuantizer(
82
+ n_mels * 4, codebook_dim, codebook_size, seed=seed + i
83
+ ),
84
+ )
85
+
86
+ # two residual convolution layers + one projection layer
87
+ self.conv = Conv2dSubsampling(
88
+ 1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels
89
+ )
90
+
91
+ # Conformer
92
+ if is_flash:
93
+ from modules.flash_conformer import (
94
+ Wav2Vec2ConformerEncoder,
95
+ Wav2Vec2ConformerConfig,
96
+ )
97
+ else:
98
+ from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
99
+ Wav2Vec2ConformerEncoder,
100
+ Wav2Vec2ConformerConfig,
101
+ )
102
+ config = Wav2Vec2ConformerConfig.from_pretrained(
103
+ w2v2_config_path
104
+ )
105
+ config.num_hidden_layers = encoder_depth
106
+ config.hidden_size = encoder_dim
107
+
108
+ self.conformer = Wav2Vec2ConformerEncoder(config)
109
+
110
+ # projection
111
+ self.linear = nn.Linear(encoder_dim, codebook_size)
112
+
113
+ # loss function
114
+ self.loss = nn.CrossEntropyLoss()
115
+
116
+ # cls token (used for sequence classification)
117
+ random.seed(seed)
118
+ self.cls_token = nn.Parameter(torch.randn(encoder_dim))
119
+
120
+ # load model
121
+ if model_path:
122
+ S = torch.load(model_path)["state_dict"]
123
+ SS = {k[6:]: v for k, v in S.items()}
124
+ self.load_state_dict(SS, strict=True)
125
+
126
+ def masking(self, x):
127
+ """random masking of 400ms with given probability"""
128
+ mx = x.clone()
129
+ b, t = mx.shape
130
+ len_masking_raw = int(24000 * self.mask_hop)
131
+ len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)
132
+
133
+ # get random mask indices
134
+ start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
135
+ time_domain_masked_indices = torch.nonzero(
136
+ start_indices.repeat_interleave(len_masking_raw, dim=1)
137
+ )
138
+ token_domain_masked_indices = torch.nonzero(
139
+ start_indices.repeat_interleave(len_masking_token, dim=1)
140
+ )
141
+
142
+ # mask with random values
143
+ masking_noise = (
144
+ torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
145
+ ) # 0 mean 0.1 std
146
+ mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
147
+
148
+ return mx, token_domain_masked_indices
149
+
150
+ @torch.no_grad()
151
+ def preprocessing(self, x, features):
152
+ """extract classic audio features"""
153
+ # check precision
154
+ if x.dtype == torch.float16:
155
+ precision = 16
156
+ else:
157
+ precision = 32
158
+
159
+ out = {}
160
+ for key in features:
161
+ layer = getattr(self, "preprocessor_%s" % key)
162
+ out[key] = layer.float()(x.float())[..., :-1]
163
+ if precision == 16:
164
+ out[key] = out[key].half()
165
+ return out
166
+
167
+ def encoder(self, x):
168
+ """2-layer conv + w2v-conformer"""
169
+ x = self.conv(x)
170
+ out = self.conformer(x, output_hidden_states=True)
171
+ hidden_emb = out["hidden_states"]
172
+ last_emb = out["last_hidden_state"]
173
+ logits = self.linear(last_emb)
174
+ logits = {
175
+ key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
176
+ for i, key in enumerate(self.features)
177
+ }
178
+ return logits, hidden_emb
179
+
180
+ @torch.no_grad()
181
+ def normalize(self, x):
182
+ """normalize the input audio to have zero mean unit variance"""
183
+ for key in x.keys():
184
+ x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
185
+ return x
186
+
187
+ @torch.no_grad()
188
+ def rearrange(self, x):
189
+ """rearrange the batch to flatten every 4 steps"""
190
+ for key in x.keys():
191
+ if key == "chromagram":
192
+ x[key] = rearrange(x[key], "b f t -> b t f")
193
+ else:
194
+ x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)
195
+ return x
196
+
197
+ @torch.no_grad()
198
+ def tokenize(self, x):
199
+ out = {}
200
+ for key in x.keys():
201
+ layer = getattr(self, "quantizer_%s" % key)
202
+ out[key] = layer(x[key])
203
+ return out
204
+
205
+ def get_targets(self, x):
206
+ x = self.preprocessing(x, features=self.features)
207
+ x = self.normalize(x)
208
+ x = self.rearrange(x)
209
+ target_tokens = self.tokenize(x)
210
+ return target_tokens
211
+
212
+ def get_predictions(self, x):
213
+ # preprocessing
214
+ x = self.preprocessing(x, features=["melspec_2048"])
215
+ x = self.normalize(x)
216
+
217
+ # encoding
218
+ logits, hidden_emb = self.encoder(x["melspec_2048"])
219
+
220
+ return logits, hidden_emb
221
+
222
+ def get_latent(self, x, layer_ix=12):
223
+ _, hidden_states = self.get_predictions(x)
224
+ emb = hidden_states[layer_ix]
225
+ return emb
226
+
227
+ def get_loss(self, logits, target_tokens, masked_indices):
228
+ losses = {}
229
+ accuracies = {}
230
+ for key in logits.keys():
231
+ masked_logits = logits[key][tuple(masked_indices.t())]
232
+ masked_tokens = target_tokens[key][tuple(masked_indices.t())]
233
+ losses[key] = self.loss(masked_logits, masked_tokens)
234
+ accuracies[key] = (
235
+ torch.sum(masked_logits.argmax(-1) == masked_tokens)
236
+ / masked_tokens.numel()
237
+ )
238
+ return losses, accuracies
239
+
240
+ def forward(self, x):
241
+ # get target feature tokens
242
+ target_tokens = self.get_targets(x)
243
+
244
+ # masking
245
+ x, masked_indices = self.masking(x)
246
+
247
+ # forward
248
+ logits, hidden_emb = self.get_predictions(x)
249
+
250
+ # get loss
251
+ losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
252
+
253
+ return logits, hidden_emb, losses, accuracies
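MusicFM25Hz above chains a mel front end, a 4x convolutional subsampler, and a Conformer; get_predictions() returns per-codebook logits plus all hidden states, which is what MusicFMEncoder in encoder.py consumes. A rough inference sketch (editorial), assuming the default stats/checkpoint files exist at the paths shown and 24 kHz input:

import torch
from slam_llm.models.musicfm.model.musicfm_25hz import MusicFM25Hz

model = MusicFM25Hz(
    stat_path="./data/fma_stats.json",      # constructor defaults above; the files must exist locally
    model_path="./data/pretrained_fma.pt",
).eval()

wav = torch.randn(1, 24000 * 10)            # 10 s of 24 kHz audio
with torch.no_grad():
    logits, hidden_states = model.get_predictions(wav)

emb = hidden_states[12]                     # get_latent() above defaults to layer_ix=12
print(emb.shape)                            # roughly (1, 250, 1024): 25 Hz frames, encoder_dim 1024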
slam_llm/models/musicfm/modules/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+
2
+
slam_llm/models/musicfm/modules/conv.py ADDED
@@ -0,0 +1,82 @@
1
+ # MIT License
2
+ #
3
+ # Copyright 2023 ByteDance Inc.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
6
+ # to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
10
+ #
11
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
13
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
14
+ # IN THE SOFTWARE.
15
+
16
+ from torch import nn
17
+ from einops import rearrange
18
+
19
+
20
+ class Res2dModule(nn.Module):
21
+ def __init__(self, idim, odim, stride=(2, 2)):
22
+ super(Res2dModule, self).__init__()
23
+ self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
24
+ self.bn1 = nn.BatchNorm2d(odim)
25
+ self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
26
+ self.bn2 = nn.BatchNorm2d(odim)
27
+ self.relu = nn.ReLU()
28
+
29
+ # residual
30
+ self.diff = False
31
+ if (idim != odim) or (stride[0] > 1):
32
+ self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
33
+ self.bn3 = nn.BatchNorm2d(odim)
34
+ self.diff = True
35
+
36
+ def forward(self, x):
37
+ out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
38
+ if self.diff:
39
+ x = self.bn3(self.conv3(x))
40
+ out = x + out
41
+ out = self.relu(out)
42
+ return out
43
+
44
+
45
+ class Conv2dSubsampling(nn.Module):
46
+ """Convolutional 2D subsampling (to 1/4 length).
47
+
48
+ Args:
49
+ idim (int): Input dimension.
50
+ hdim (int): Hidden dimension.
51
+ odim (int): Output dimension.
52
+ strides (list): Sizes of strides.
53
+ n_bands (int): Number of frequency bands.
54
+ """
55
+
56
+ def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
57
+ """Construct a Conv2dSubsampling object."""
58
+ super(Conv2dSubsampling, self).__init__()
59
+
60
+ self.conv = nn.Sequential(
61
+ Res2dModule(idim, hdim, (2, strides[0])),
62
+ Res2dModule(hdim, hdim, (2, strides[1])),
63
+ )
64
+ self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
65
+
66
+ def forward(self, x):
67
+ """Subsample x.
68
+
69
+ Args:
70
+ x (torch.Tensor): Input tensor (#batch, idim, time).
71
+
72
+ Returns:
73
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
74
+ where time' = time // 4.
75
+ """
76
+
77
+ if x.dim() == 3:
78
+ x = x.unsqueeze(1) # (b, c, f, t)
79
+ x = self.conv(x)
80
+ x = rearrange(x, "b c f t -> b t (c f)")
81
+ x = self.linear(x)
82
+ return x
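Conv2dSubsampling above halves both time and frequency twice via the two strided Res2dModule blocks, then folds channels and frequency into the encoder dimension. A quick shape check (editorial sketch) under the MusicFM defaults of 128 mel bands, hdim 512, odim 1024:

import torch
from slam_llm.models.musicfm.modules.conv import Conv2dSubsampling

sub = Conv2dSubsampling(idim=1, hdim=512, odim=1024, strides=[2, 2], n_bands=128)
mel = torch.randn(2, 128, 1000)     # (batch, n_mels, time)
out = sub(mel)
print(out.shape)                    # torch.Size([2, 250, 1024]): time reduced 4x, 512*32 folded into 1024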
slam_llm/models/musicfm/modules/features.py ADDED
@@ -0,0 +1,45 @@
1
+ # MIT License
2
+ #
3
+ # Copyright 2023 ByteDance Inc.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
6
+ # to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
10
+ #
11
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
13
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
14
+ # IN THE SOFTWARE.
15
+
16
+ import torchaudio
17
+ from torch import nn
18
+
19
+
20
+ class MelSTFT(nn.Module):
21
+ def __init__(
22
+ self,
23
+ sample_rate=24000,
24
+ n_fft=2048,
25
+ hop_length=240,
26
+ n_mels=128,
27
+ is_db=False,
28
+ ):
29
+ super(MelSTFT, self).__init__()
30
+
31
+ # spectrogram
32
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
33
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
34
+ )
35
+
36
+ # amplitude to decibel
37
+ self.is_db = is_db
38
+ if is_db:
39
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
40
+
41
+ def forward(self, waveform):
42
+ if self.is_db:
43
+ return self.amplitude_to_db(self.mel_stft(waveform))
44
+ else:
45
+ return self.mel_stft(waveform)
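MelSTFT is a thin wrapper around torchaudio's MelSpectrogram with an optional dB conversion; with hop_length=240 at 24 kHz it yields 100 frames per second, which the 4x subsampler above turns into the model's 25 Hz rate. A small sketch (note that MusicFM25Hz.preprocessing() drops the trailing frame with [..., :-1]):

import torch
from slam_llm.models.musicfm.modules.features import MelSTFT

mel = MelSTFT(sample_rate=24000, n_fft=2048, hop_length=240, n_mels=128, is_db=True)
wav = torch.randn(1, 24000)          # 1 s of audio
spec = mel(wav)
print(spec.shape)                    # (1, 128, 101): one frame per 240-sample hop, plus the centered extra frame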
slam_llm/models/musicfm/modules/flash_conformer.py ADDED
@@ -0,0 +1,2114 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Wav2Vec2-Conformer model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+ from torch.nn import functional as F
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutput,
32
+ CausalLMOutput,
33
+ SequenceClassifierOutput,
34
+ TokenClassifierOutput,
35
+ Wav2Vec2BaseModelOutput,
36
+ XVectorOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.utils import (
40
+ ModelOutput,
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ _HIDDEN_STATES_START_POSITION = 2
54
+
55
+ # General docstring
56
+ _CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
57
+
58
+ # Base docstring
59
+ _CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
60
+ _EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
61
+
62
+ # CTC docstring
63
+ _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
64
+ _CTC_EXPECTED_LOSS = 64.21
65
+
66
+
67
+ WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
68
+ "facebook/wav2vec2-conformer-rel-pos-large",
69
+ # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
70
+ ]
71
+
72
+
73
+ @dataclass
74
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
75
+ class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
76
+ """
77
+ Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
78
+
79
+ Args:
80
+ loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
81
+ Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
82
+ paper](https://arxiv.org/pdf/2006.11477.pdf).
83
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
84
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
85
+ projected quantized states.
86
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
87
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
88
+ target vectors for contrastive loss.
89
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
90
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
91
+ shape `(batch_size, sequence_length, hidden_size)`.
92
+
93
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
94
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
95
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
96
+ sequence_length)`.
97
+
98
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
99
+ heads.
100
+ contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
101
+ The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
102
+ diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
103
+ The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
104
+ """
105
+
106
+ loss: Optional[torch.FloatTensor] = None
107
+ projected_states: torch.FloatTensor = None
108
+ projected_quantized_states: torch.FloatTensor = None
109
+ codevector_perplexity: torch.FloatTensor = None
110
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
111
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
112
+ contrastive_loss: Optional[torch.FloatTensor] = None
113
+ diversity_loss: Optional[torch.FloatTensor] = None
114
+
115
+
116
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
117
+ def _compute_mask_indices(
118
+ shape: Tuple[int, int],
119
+ mask_prob: float,
120
+ mask_length: int,
121
+ attention_mask: Optional[torch.LongTensor] = None,
122
+ min_masks: int = 0,
123
+ ) -> np.ndarray:
124
+ """
125
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
126
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
127
+ CPU as part of the preprocessing during training.
128
+
129
+ Args:
130
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
131
+ the first element is the batch size and the second element is the length of the axis to span.
132
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
133
+ independently generated mask spans of length `mask_length` is computed by
134
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
135
+ actual percentage will be smaller.
136
+ mask_length: size of the mask
137
+ min_masks: minimum number of masked spans
138
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
139
+ each batch dimension.
140
+ """
141
+ batch_size, sequence_length = shape
142
+
143
+ if mask_length < 1:
144
+ raise ValueError("`mask_length` has to be bigger than 0.")
145
+
146
+ if mask_length > sequence_length:
147
+ raise ValueError(
148
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
149
+ f" and `sequence_length`: {sequence_length}`"
150
+ )
151
+
152
+ # epsilon is used for probabilistic rounding
153
+ epsilon = np.random.rand(1).item()
154
+
155
+ def compute_num_masked_span(input_length):
156
+ """Given input length, compute how many spans should be masked"""
157
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
158
+ num_masked_span = max(num_masked_span, min_masks)
159
+
160
+ # make sure num masked span <= sequence_length
161
+ if num_masked_span * mask_length > sequence_length:
162
+ num_masked_span = sequence_length // mask_length
163
+
164
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
165
+ if input_length - (mask_length - 1) < num_masked_span:
166
+ num_masked_span = max(input_length - (mask_length - 1), 0)
167
+
168
+ return num_masked_span
169
+
170
+ # compute number of masked spans in batch
171
+ input_lengths = (
172
+ attention_mask.sum(-1).detach().tolist()
173
+ if attention_mask is not None
174
+ else [sequence_length for _ in range(batch_size)]
175
+ )
176
+
177
+ # SpecAugment mask to fill
178
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
179
+ spec_aug_mask_idxs = []
180
+
181
+ max_num_masked_span = compute_num_masked_span(sequence_length)
182
+
183
+ if max_num_masked_span == 0:
184
+ return spec_aug_mask
185
+
186
+ for input_length in input_lengths:
187
+ # compute num of masked spans for this input
188
+ num_masked_span = compute_num_masked_span(input_length)
189
+
190
+ # get random indices to mask
191
+ spec_aug_mask_idx = np.random.choice(
192
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
193
+ )
194
+
195
+ # pick first sampled index that will serve as a dummy index to pad vector
196
+ # to ensure same dimension for all batches due to probabilistic rounding
197
+ # Picking first sample just pads those vectors twice.
198
+ if len(spec_aug_mask_idx) == 0:
199
+ # this case can only happen if `input_length` is strictly smaller then
200
+ # `sequence_length` in which case the last token has to be a padding
201
+ # token which we can use as a dummy mask id
202
+ dummy_mask_idx = sequence_length - 1
203
+ else:
204
+ dummy_mask_idx = spec_aug_mask_idx[0]
205
+
206
+ spec_aug_mask_idx = np.concatenate(
207
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
208
+ )
209
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
210
+
211
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
212
+
213
+ # expand masked indices to masked spans
214
+ spec_aug_mask_idxs = np.broadcast_to(
215
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
216
+ )
217
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
218
+
219
+ # add offset to the starting indexes so that indexes now create a span
220
+ offsets = np.arange(mask_length)[None, None, :]
221
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
222
+ batch_size, max_num_masked_span * mask_length
223
+ )
224
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
225
+
226
+ # ensure that we cannot have indices larger than sequence_length
227
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
228
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
229
+
230
+ # scatter indices to mask
231
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
232
+
233
+ return spec_aug_mask
234
+
235
+
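Editorial aside: _compute_mask_indices above builds the SpecAugment-style boolean mask with NumPy on CPU. A standalone sketch of its output, assuming the module's transformers imports resolve in your environment:

import numpy as np
from slam_llm.models.musicfm.modules.flash_conformer import _compute_mask_indices

mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.5, mask_length=10, min_masks=2)
print(mask.shape, mask.dtype)        # (2, 100) bool
print(mask.sum(axis=-1))             # roughly mask_prob * 100 masked steps per row (overlaps allowed)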
236
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
237
+ def _sample_negative_indices(
238
+ features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
239
+ ):
240
+ """
241
+ Sample `num_negatives` vectors from feature vectors.
242
+ """
243
+ batch_size, sequence_length = features_shape
244
+
245
+ # generate indices of the positive vectors themselves, repeat them `num_negatives` times
246
+ sequence_length_range = np.arange(sequence_length)
247
+
248
+ # get `num_negatives` random vector indices from the same utterance
249
+ sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
250
+
251
+ mask_time_indices = (
252
+ mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
253
+ )
254
+
255
+ for batch_idx in range(batch_size):
256
+ high = mask_time_indices[batch_idx].sum() - 1
257
+ mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
258
+
259
+ feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
260
+ sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
261
+ # avoid sampling the same positive vector, but keep the distribution uniform
262
+ sampled_indices[sampled_indices >= feature_indices] += 1
263
+
264
+ # remap to actual indices
265
+ sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
266
+
267
+ # correct for batch size
268
+ sampled_negative_indices[batch_idx] += batch_idx * sequence_length
269
+
270
+ return sampled_negative_indices
271
+
272
+
273
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
274
+ class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
275
+ def __init__(self, config, layer_id=0):
276
+ super().__init__()
277
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
278
+ self.out_conv_dim = config.conv_dim[layer_id]
279
+
280
+ self.conv = nn.Conv1d(
281
+ self.in_conv_dim,
282
+ self.out_conv_dim,
283
+ kernel_size=config.conv_kernel[layer_id],
284
+ stride=config.conv_stride[layer_id],
285
+ bias=config.conv_bias,
286
+ )
287
+ self.activation = ACT2FN[config.feat_extract_activation]
288
+
289
+ def forward(self, hidden_states):
290
+ hidden_states = self.conv(hidden_states)
291
+ hidden_states = self.activation(hidden_states)
292
+ return hidden_states
293
+
294
+
295
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
296
+ class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
297
+ def __init__(self, config, layer_id=0):
298
+ super().__init__()
299
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
300
+ self.out_conv_dim = config.conv_dim[layer_id]
301
+
302
+ self.conv = nn.Conv1d(
303
+ self.in_conv_dim,
304
+ self.out_conv_dim,
305
+ kernel_size=config.conv_kernel[layer_id],
306
+ stride=config.conv_stride[layer_id],
307
+ bias=config.conv_bias,
308
+ )
309
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
310
+ self.activation = ACT2FN[config.feat_extract_activation]
311
+
312
+ def forward(self, hidden_states):
313
+ hidden_states = self.conv(hidden_states)
314
+
315
+ hidden_states = hidden_states.transpose(-2, -1)
316
+ hidden_states = self.layer_norm(hidden_states)
317
+ hidden_states = hidden_states.transpose(-2, -1)
318
+
319
+ hidden_states = self.activation(hidden_states)
320
+ return hidden_states
321
+
322
+
323
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
324
+ class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
325
+ def __init__(self, config, layer_id=0):
326
+ super().__init__()
327
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
328
+ self.out_conv_dim = config.conv_dim[layer_id]
329
+
330
+ self.conv = nn.Conv1d(
331
+ self.in_conv_dim,
332
+ self.out_conv_dim,
333
+ kernel_size=config.conv_kernel[layer_id],
334
+ stride=config.conv_stride[layer_id],
335
+ bias=config.conv_bias,
336
+ )
337
+ self.activation = ACT2FN[config.feat_extract_activation]
338
+
339
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
340
+
341
+ def forward(self, hidden_states):
342
+ hidden_states = self.conv(hidden_states)
343
+ hidden_states = self.layer_norm(hidden_states)
344
+ hidden_states = self.activation(hidden_states)
345
+ return hidden_states
346
+
347
+
348
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
349
+ class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.conv = nn.Conv1d(
353
+ config.hidden_size,
354
+ config.hidden_size,
355
+ kernel_size=config.num_conv_pos_embeddings,
356
+ padding=config.num_conv_pos_embeddings // 2,
357
+ groups=config.num_conv_pos_embedding_groups,
358
+ )
359
+
360
+ if is_deepspeed_zero3_enabled():
361
+ import deepspeed
362
+
363
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
364
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
365
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
366
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
367
+ else:
368
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
369
+
370
+ self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
371
+ self.activation = ACT2FN[config.feat_extract_activation]
372
+
373
+ def forward(self, hidden_states):
374
+ hidden_states = hidden_states.transpose(1, 2)
375
+
376
+ hidden_states = self.conv(hidden_states)
377
+ hidden_states = self.padding(hidden_states)
378
+ hidden_states = self.activation(hidden_states)
379
+
380
+ hidden_states = hidden_states.transpose(1, 2)
381
+ return hidden_states
382
+
383
+
384
+ class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
385
+ """Rotary positional embedding
386
+ Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
387
+ """
388
+
389
+ def __init__(self, config):
390
+ super().__init__()
391
+ dim = config.hidden_size // config.num_attention_heads
392
+ base = config.rotary_embedding_base
393
+
394
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
395
+ self.register_buffer("inv_freq", inv_freq)
396
+ self.cached_sequence_length = None
397
+ self.cached_rotary_positional_embedding = None
398
+
399
+ def forward(self, hidden_states):
400
+ sequence_length = hidden_states.shape[1]
401
+
402
+ if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
403
+ return self.cached_rotary_positional_embedding
404
+
405
+ self.cached_sequence_length = sequence_length
406
+ time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
407
+ freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
408
+ embeddings = torch.cat((freqs, freqs), dim=-1)
409
+
410
+ cos_embeddings = embeddings.cos()[:, None, None, :]
411
+ sin_embeddings = embeddings.sin()[:, None, None, :]
412
+ self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
413
+ return self.cached_rotary_positional_embedding
414
+
415
+
416
+ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
417
+ """Relative positional encoding module."""
418
+
419
+ def __init__(self, config):
420
+ super().__init__()
421
+ self.max_len = config.max_source_positions
422
+ self.d_model = config.hidden_size
423
+ self.pe = None
424
+ self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
425
+
426
+ def extend_pe(self, x):
427
+ # Reset the positional encodings
428
+ if self.pe is not None:
429
+ # self.pe contains both positive and negative parts
430
+ # the length of self.pe is 2 * input_len - 1
431
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
432
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
433
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
434
+ return
435
+ # Suppose `i` is the position of query vector and `j` is the
436
+ # position of key vector. We use positive relative positions when keys
437
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
438
+ pe_positive = torch.zeros(x.size(1), self.d_model)
439
+ pe_negative = torch.zeros(x.size(1), self.d_model)
440
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
441
+ div_term = torch.exp(
442
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
443
+ )
444
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
445
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
446
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
447
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
448
+
449
+ # Reverse the order of positive indices and concat both positive and
450
+ # negative indices. This is used to support the shifting trick
451
+ # as in https://arxiv.org/abs/1901.02860
452
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
453
+ pe_negative = pe_negative[1:].unsqueeze(0)
454
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
455
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
456
+
457
+ def forward(self, hidden_states: torch.Tensor):
458
+ self.extend_pe(hidden_states)
459
+ start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
460
+ end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
461
+ relative_position_embeddings = self.pe[:, start_idx:end_idx]
462
+
463
+ return relative_position_embeddings
464
+
465
+
466
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
467
+ class Wav2Vec2ConformerSamePadLayer(nn.Module):
468
+ def __init__(self, num_conv_pos_embeddings):
469
+ super().__init__()
470
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
471
+
472
+ def forward(self, hidden_states):
473
+ if self.num_pad_remove > 0:
474
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
475
+ return hidden_states
476
+
477
+
478
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
479
+ class Wav2Vec2ConformerFeatureEncoder(nn.Module):
480
+ """Construct the features from raw audio waveform"""
481
+
482
+ def __init__(self, config):
483
+ super().__init__()
484
+
485
+ if config.feat_extract_norm == "group":
486
+ conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
487
+ Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
488
+ for i in range(config.num_feat_extract_layers - 1)
489
+ ]
490
+ elif config.feat_extract_norm == "layer":
491
+ conv_layers = [
492
+ Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
493
+ ]
494
+ else:
495
+ raise ValueError(
496
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
497
+ )
498
+ self.conv_layers = nn.ModuleList(conv_layers)
499
+ self.gradient_checkpointing = False
500
+ self._requires_grad = True
501
+
502
+ def _freeze_parameters(self):
503
+ for param in self.parameters():
504
+ param.requires_grad = False
505
+ self._requires_grad = False
506
+
507
+ def forward(self, input_values):
508
+ hidden_states = input_values[:, None]
509
+
510
+ # make sure hidden_states require grad for gradient_checkpointing
511
+ if self._requires_grad and self.training:
512
+ hidden_states.requires_grad = True
513
+
514
+ for conv_layer in self.conv_layers:
515
+ if self._requires_grad and self.gradient_checkpointing and self.training:
516
+
517
+ def create_custom_forward(module):
518
+ def custom_forward(*inputs):
519
+ return module(*inputs)
520
+
521
+ return custom_forward
522
+
523
+ hidden_states = torch.utils.checkpoint.checkpoint(
524
+ create_custom_forward(conv_layer),
525
+ hidden_states,
526
+ )
527
+ else:
528
+ hidden_states = conv_layer(hidden_states)
529
+
530
+ return hidden_states
531
+
532
+
533
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
534
+ class Wav2Vec2ConformerFeatureProjection(nn.Module):
535
+ def __init__(self, config):
536
+ super().__init__()
537
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
538
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
539
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
540
+
541
+ def forward(self, hidden_states):
542
+ # non-projected hidden states are needed for quantization
543
+ norm_hidden_states = self.layer_norm(hidden_states)
544
+ hidden_states = self.projection(norm_hidden_states)
545
+ hidden_states = self.dropout(hidden_states)
546
+ return hidden_states, norm_hidden_states
547
+
548
+
549
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
550
+ class Wav2Vec2ConformerFeedForward(nn.Module):
551
+ def __init__(self, config):
552
+ super().__init__()
553
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
554
+
555
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
556
+ if isinstance(config.hidden_act, str):
557
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
558
+ else:
559
+ self.intermediate_act_fn = config.hidden_act
560
+
561
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
562
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
563
+
564
+ def forward(self, hidden_states):
565
+ hidden_states = self.intermediate_dense(hidden_states)
566
+ hidden_states = self.intermediate_act_fn(hidden_states)
567
+ hidden_states = self.intermediate_dropout(hidden_states)
568
+
569
+ hidden_states = self.output_dense(hidden_states)
570
+ hidden_states = self.output_dropout(hidden_states)
571
+ return hidden_states
572
+
573
+
574
+ class Wav2Vec2ConformerConvolutionModule(nn.Module):
575
+ """Convolution block used in the conformer block"""
576
+
577
+ def __init__(self, config):
578
+ super().__init__()
579
+ if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
580
+ raise ValueError("`config.conv_depthwise_kernel_size` should be an odd number for 'SAME' padding")
581
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
582
+ self.pointwise_conv1 = torch.nn.Conv1d(
583
+ config.hidden_size,
584
+ 2 * config.hidden_size,
585
+ kernel_size=1,
586
+ stride=1,
587
+ padding=0,
588
+ bias=False,
589
+ )
590
+ self.glu = torch.nn.GLU(dim=1)
591
+ self.depthwise_conv = torch.nn.Conv1d(
592
+ config.hidden_size,
593
+ config.hidden_size,
594
+ config.conv_depthwise_kernel_size,
595
+ stride=1,
596
+ padding=(config.conv_depthwise_kernel_size - 1) // 2,
597
+ groups=config.hidden_size,
598
+ bias=False,
599
+ )
600
+ self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
601
+ self.activation = ACT2FN[config.hidden_act]
602
+ self.pointwise_conv2 = torch.nn.Conv1d(
603
+ config.hidden_size,
604
+ config.hidden_size,
605
+ kernel_size=1,
606
+ stride=1,
607
+ padding=0,
608
+ bias=False,
609
+ )
610
+ self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
611
+
612
+ def forward(self, hidden_states):
613
+ hidden_states = self.layer_norm(hidden_states)
614
+ # exchange the temporal dimension and the feature dimension
615
+ hidden_states = hidden_states.transpose(1, 2)
616
+
617
+ # GLU mechanism
618
+ # => (batch, 2*channel, dim)
619
+ hidden_states = self.pointwise_conv1(hidden_states)
620
+ # => (batch, channel, dim)
621
+ hidden_states = self.glu(hidden_states)
622
+
623
+ # 1D Depthwise Conv
624
+ hidden_states = self.depthwise_conv(hidden_states)
625
+ hidden_states = self.batch_norm(hidden_states)
626
+ hidden_states = self.activation(hidden_states)
627
+
628
+ hidden_states = self.pointwise_conv2(hidden_states)
629
+ hidden_states = self.dropout(hidden_states)
630
+ hidden_states = hidden_states.transpose(1, 2)
631
+ return hidden_states
632
+
633
+
634
+ class Wav2Vec2ConformerSelfAttention(nn.Module):
635
+ """Construct a Wav2Vec2ConformerSelfAttention object.
636
+ Can be enhanced with rotary or relative position embeddings.
637
+ """
638
+
639
+ def __init__(self, config):
640
+ super().__init__()
641
+
642
+ self.head_size = config.hidden_size // config.num_attention_heads
643
+ self.num_heads = config.num_attention_heads
644
+ self.position_embeddings_type = config.position_embeddings_type
645
+
646
+ self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
647
+ self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
648
+ self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
649
+ self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
650
+
651
+ self.dropout = nn.Dropout(p=config.attention_dropout)
652
+ self.dropout_p = config.attention_dropout
653
+
654
+ self.is_causal = config.is_causal
655
+
656
+ if self.position_embeddings_type == "relative":
657
+ # linear transformation for positional encoding
658
+ self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
659
+ # these two learnable bias are used in matrix c and matrix d
660
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
661
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
662
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
663
+
664
+ def forward(
665
+ self,
666
+ hidden_states: torch.Tensor,
667
+ attention_mask: Optional[torch.Tensor] = None,
668
+ relative_position_embeddings: Optional[torch.Tensor] = None,
669
+ output_attentions: bool = False,
670
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
671
+ # self-attention mechanism
672
+ batch_size, sequence_length, hidden_size = hidden_states.size()
673
+
674
+ # make sure query/key states can be != value states
675
+ query_key_states = hidden_states
676
+ value_states = hidden_states
677
+
678
+ if self.position_embeddings_type == "rotary":
679
+ if relative_position_embeddings is None:
680
+ raise ValueError(
681
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
682
+ )
683
+ query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
684
+
685
+ # project query_key_states and value_states
686
+ query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
687
+ key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
688
+ value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
689
+
690
+ # => (batch, head, time1, d_k)
691
+ query = query.transpose(1, 2)
692
+ key = key.transpose(1, 2)
693
+ value = value.transpose(1, 2)
694
+
695
+ with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
696
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal)
697
+ probs = None
698
+
699
+ # # apply attention_mask if necessary
700
+ # if attention_mask is not None:
701
+ # scores = scores + attention_mask
702
+
703
+ # # => (batch, head, time1, time2)
704
+ # probs = torch.softmax(scores, dim=-1)
705
+ # probs = self.dropout(probs)
706
+
707
+ # # => (batch, head, time1, d_k)
708
+ # hidden_states = torch.matmul(probs, value)
709
+
710
+ # => (batch, time1, hidden_size)
711
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
712
+ hidden_states = self.linear_out(hidden_states)
713
+
714
+ return hidden_states, probs
715
+
716
+ def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
717
+ batch_size, sequence_length, hidden_size = hidden_states.size()
718
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
719
+
720
+ cos = relative_position_embeddings[0, :sequence_length, ...]
721
+ sin = relative_position_embeddings[1, :sequence_length, ...]
722
+
723
+ # rotate hidden_states with rotary embeddings
724
+ hidden_states = hidden_states.transpose(0, 1)
725
+ rotated_states_begin = hidden_states[..., : self.head_size // 2]
726
+ rotated_states_end = hidden_states[..., self.head_size // 2 :]
727
+ rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
728
+ hidden_states = (hidden_states * cos) + (rotated_states * sin)
729
+ hidden_states = hidden_states.transpose(0, 1)
730
+
731
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
732
+
733
+ return hidden_states
734
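+ # Editor's sketch of the rotate-half trick used above (toy head_size=4, assumed values):
+ # >>> import torch
+ # >>> x = torch.arange(4, dtype=torch.float)         # one head at one time step
+ # >>> rotated = torch.cat((-x[4 // 2 :], x[: 4 // 2]))
+ # >>> rotated
+ # tensor([-2., -3.,  0.,  1.])
+ # The rotary update is then x * cos + rotated * sin, applied per position and per head.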
+
735
+ def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
736
+ # 1. project positional embeddings
737
+ # => (batch, head, 2*time1-1, d_k)
738
+ proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
739
+ proj_relative_position_embeddings = proj_relative_position_embeddings.view(
740
+ relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
741
+ )
742
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
743
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
744
+
745
+ # 2. Add bias to query
746
+ # => (batch, head, time1, d_k)
747
+ query = query.transpose(1, 2)
748
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
749
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
750
+
751
+ # 3. attention score: first compute matrix a and matrix c
752
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
753
+ # => (batch, head, time1, time2)
754
+ scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
755
+
756
+ # 4. then compute matrix b and matrix d
757
+ # => (batch, head, time1, 2*time1-1)
758
+ scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
759
+
760
+ # 5. shift matrix b and matrix d
761
+ zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
762
+ scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
763
+ scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
764
+ scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
765
+ scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
766
+ scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
767
+
768
+ # 6. sum matrices
769
+ # => (batch, head, time1, time2)
770
+ scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
771
+
772
+ return scores
773
+
774
+
775
+ class Wav2Vec2ConformerEncoderLayer(nn.Module):
776
+ """Conformer block based on https://arxiv.org/abs/2005.08100."""
777
+
778
+ def __init__(self, config):
779
+ super().__init__()
780
+ embed_dim = config.hidden_size
781
+ dropout = config.attention_dropout
782
+
783
+ # Feed-forward 1
784
+ self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
785
+ self.ffn1 = Wav2Vec2ConformerFeedForward(config)
786
+
787
+ # Self-Attention
788
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
789
+ self.self_attn_dropout = torch.nn.Dropout(dropout)
790
+ self.self_attn = Wav2Vec2ConformerSelfAttention(config)
791
+
792
+ # Conformer Convolution
793
+ self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
794
+
795
+ # Feed-forward 2
796
+ self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
797
+ self.ffn2 = Wav2Vec2ConformerFeedForward(config)
798
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
799
+
800
+ def forward(
801
+ self,
802
+ hidden_states,
803
+ attention_mask: Optional[torch.Tensor] = None,
804
+ relative_position_embeddings: Optional[torch.Tensor] = None,
805
+ output_attentions: bool = False,
806
+ ):
807
+ hidden_states = hidden_states
808
+
809
+ # 1. Feed-Forward 1 layer
810
+ residual = hidden_states
811
+ hidden_states = self.ffn1_layer_norm(hidden_states)
812
+ hidden_states = self.ffn1(hidden_states)
813
+ hidden_states = hidden_states * 0.5 + residual
814
+ residual = hidden_states
815
+
816
+ # 2. Self-Attention layer
817
+ hidden_states = self.self_attn_layer_norm(hidden_states)
818
+ hidden_states, attn_weights = self.self_attn(
819
+ hidden_states=hidden_states,
820
+ attention_mask=attention_mask,
821
+ relative_position_embeddings=relative_position_embeddings,
822
+ output_attentions=output_attentions,
823
+ )
824
+ hidden_states = self.self_attn_dropout(hidden_states)
825
+ hidden_states = hidden_states + residual
826
+
827
+ # 3. Convolutional Layer
828
+ residual = hidden_states
829
+ hidden_states = self.conv_module(hidden_states)
830
+ hidden_states = residual + hidden_states
831
+
832
+ # 4. Feed-Forward 2 Layer
833
+ residual = hidden_states
834
+ hidden_states = self.ffn2_layer_norm(hidden_states)
835
+ hidden_states = self.ffn2(hidden_states)
836
+ hidden_states = hidden_states * 0.5 + residual
837
+ hidden_states = self.final_layer_norm(hidden_states)
838
+
839
+ return hidden_states, attn_weights
840
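+ # Editor's note, a compact restatement of the block above (https://arxiv.org/abs/2005.08100),
+ # with LN denoting layer norm:
+ #     x = x + 0.5 * FFN1(LN(x))        # first macaron feed-forward, half-step residual
+ #     x = x + Dropout(MHSA(LN(x)))     # self-attention (rotary/relative positions optional)
+ #     x = x + ConvModule(x)            # convolution module applies its own layer norm
+ #     x = LN(x + 0.5 * FFN2(LN(x)))    # second half-step feed-forward, then final norm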
+
841
+
842
+ class Wav2Vec2ConformerEncoder(nn.Module):
843
+ def __init__(self, config, is_causal=False):
844
+ super().__init__()
845
+ config.is_causal = is_causal
846
+ self.config = config
847
+
848
+ if config.position_embeddings_type == "relative":
849
+ self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
850
+ elif config.position_embeddings_type == "rotary":
851
+ self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
852
+ else:
853
+ self.embed_positions = None
854
+
855
+ self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
856
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
857
+ self.dropout = nn.Dropout(config.hidden_dropout)
858
+ self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
859
+ self.gradient_checkpointing = False
860
+
861
+ def forward(
862
+ self,
863
+ hidden_states,
864
+ attention_mask=None,
865
+ output_attentions=False,
866
+ output_hidden_states=False,
867
+ return_dict=True,
868
+ ):
869
+ all_hidden_states = () if output_hidden_states else None
870
+ all_self_attentions = () if output_attentions else None
871
+
872
+ if attention_mask is not None:
873
+ # make sure padded tokens output 0
874
+ hidden_states[~attention_mask] = 0.0
875
+
876
+ # extend attention_mask
877
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
878
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
879
+ attention_mask = attention_mask.expand(
880
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
881
+ )
882
+
883
+ hidden_states = self.dropout(hidden_states)
884
+
885
+ if self.embed_positions is not None:
886
+ relative_position_embeddings = self.embed_positions(hidden_states)
887
+ else:
888
+ relative_position_embeddings = None
889
+
890
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
891
+
892
+ for i, layer in enumerate(self.layers):
893
+ if output_hidden_states:
894
+ all_hidden_states = all_hidden_states + (hidden_states,)
895
+
896
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
897
+ dropout_probability = np.random.uniform(0, 1)
898
+
899
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
900
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
901
+ # under deepspeed zero3 all gpus must run in sync
902
+ if self.gradient_checkpointing and self.training:
903
+ # create gradient checkpointing function
904
+ def create_custom_forward(module):
905
+ def custom_forward(*inputs):
906
+ return module(*inputs, output_attentions)
907
+
908
+ return custom_forward
909
+
910
+ layer_outputs = torch.utils.checkpoint.checkpoint(
911
+ create_custom_forward(layer),
912
+ hidden_states,
913
+ attention_mask,
914
+ relative_position_embeddings,
915
+ )
916
+ else:
917
+ layer_outputs = layer(
918
+ hidden_states,
919
+ attention_mask=attention_mask,
920
+ relative_position_embeddings=relative_position_embeddings,
921
+ output_attentions=output_attentions,
922
+ )
923
+ hidden_states = layer_outputs[0]
924
+
925
+ if skip_the_layer:
926
+ layer_outputs = (None, None)
927
+
928
+ if output_attentions:
929
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
930
+
931
+ hidden_states = self.layer_norm(hidden_states)
932
+ if output_hidden_states:
933
+ all_hidden_states = all_hidden_states + (hidden_states,)
934
+
935
+ if not return_dict:
936
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
937
+ return BaseModelOutput(
938
+ last_hidden_state=hidden_states,
939
+ hidden_states=all_hidden_states,
940
+ attentions=all_self_attentions,
941
+ )
942
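+ # Editor's sketch of the LayerDrop rule in the loop above (assuming config.layerdrop=0.1):
+ # >>> import numpy as np
+ # >>> layerdrop, training = 0.1, True
+ # >>> skip_the_layer = training and (np.random.uniform(0, 1) < layerdrop)
+ # During training each encoder layer is skipped with probability `layerdrop`; at inference
+ # every layer runs, and under DeepSpeed ZeRO-3 the layer is executed anyway so that all
+ # ranks stay in sync.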
+
943
+
944
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
945
+ class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
946
+ """
947
+ Vector quantization using Gumbel softmax. See [Categorical Reparameterization with
948
+ Gumbel-Softmax](https://arxiv.org/pdf/1611.01144.pdf) for more information.
949
+ """
950
+
951
+ def __init__(self, config):
952
+ super().__init__()
953
+ self.num_groups = config.num_codevector_groups
954
+ self.num_vars = config.num_codevectors_per_group
955
+
956
+ if config.codevector_dim % self.num_groups != 0:
957
+ raise ValueError(
958
+ f"`config.codevector_dim {config.codevector_dim} must be divisible "
959
+ f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
960
+ )
961
+
962
+ # storage for codebook variables (codewords)
963
+ self.codevectors = nn.Parameter(
964
+ torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
965
+ )
966
+ self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
967
+
968
+ # can be decayed for training
969
+ self.temperature = 2
970
+
971
+ @staticmethod
972
+ def _compute_perplexity(probs, mask=None):
973
+ if mask is not None:
974
+ mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
975
+ probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
976
+ marginal_probs = probs.sum(dim=0) / mask.sum()
977
+ else:
978
+ marginal_probs = probs.mean(dim=0)
979
+
980
+ perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
981
+ return perplexity
982
+
983
+ def forward(self, hidden_states, mask_time_indices=None):
984
+ batch_size, sequence_length, hidden_size = hidden_states.shape
985
+
986
+ # project to codevector dim
987
+ hidden_states = self.weight_proj(hidden_states)
988
+ hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
989
+
990
+ if self.training:
991
+ # sample code vector probs via gumbel in differentiable way
992
+ codevector_probs = nn.functional.gumbel_softmax(
993
+ hidden_states.float(), tau=self.temperature, hard=True
994
+ ).type_as(hidden_states)
995
+
996
+ # compute perplexity
997
+ codevector_soft_dist = torch.softmax(
998
+ hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
999
+ )
1000
+ perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
1001
+ else:
1002
+ # take argmax in non-differentiable way
1003
+ # compute hard codevector distribution (one hot)
1004
+ codevector_idx = hidden_states.argmax(dim=-1)
1005
+ codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
1006
+ -1, codevector_idx.view(-1, 1), 1.0
1007
+ )
1008
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
1009
+
1010
+ perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
1011
+
1012
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
1013
+ # use probs to retrieve codevectors
1014
+ codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
1015
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
1016
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
1017
+
1018
+ return codevectors, perplexity
1019
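+ # Editor's sketch (hypothetical sizes, not tied to any checkpoint): during training the
+ # straight-through Gumbel softmax returns a hard one-hot choice per (time step, group),
+ # so exactly one codevector is selected from each group.
+ # >>> import torch
+ # >>> import torch.nn.functional as F
+ # >>> logits = torch.randn(6, 4)                     # (batch*time*groups, codevectors)
+ # >>> hard = F.gumbel_softmax(logits, tau=2.0, hard=True)
+ # >>> hard.sum(dim=-1)
+ # tensor([1., 1., 1., 1., 1., 1.])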
+
1020
+
1021
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
1022
+ class Wav2Vec2ConformerAdapter(nn.Module):
1023
+ def __init__(self, config):
1024
+ super().__init__()
1025
+
1026
+ # feature dim might need to be down-projected
1027
+ if config.output_hidden_size != config.hidden_size:
1028
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
1029
+ self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
1030
+ else:
1031
+ self.proj = self.proj_layer_norm = None
1032
+
1033
+ self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
1034
+ self.layerdrop = config.layerdrop
1035
+
1036
+ def forward(self, hidden_states):
1037
+ # down project hidden_states if necessary
1038
+ if self.proj is not None and self.proj_layer_norm is not None:
1039
+ hidden_states = self.proj(hidden_states)
1040
+ hidden_states = self.proj_layer_norm(hidden_states)
1041
+
1042
+ hidden_states = hidden_states.transpose(1, 2)
1043
+
1044
+ for layer in self.layers:
1045
+ layerdrop_prob = np.random.random()
1046
+ if not self.training or (layerdrop_prob > self.layerdrop):
1047
+ hidden_states = layer(hidden_states)
1048
+
1049
+ hidden_states = hidden_states.transpose(1, 2)
1050
+ return hidden_states
1051
+
1052
+
1053
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
1054
+ class Wav2Vec2ConformerAdapterLayer(nn.Module):
1055
+ def __init__(self, config):
1056
+ super().__init__()
1057
+ self.conv = nn.Conv1d(
1058
+ config.output_hidden_size,
1059
+ 2 * config.output_hidden_size,
1060
+ config.adapter_kernel_size,
1061
+ stride=config.adapter_stride,
1062
+ padding=1,
1063
+ )
1064
+
1065
+ def forward(self, hidden_states):
1066
+ hidden_states = self.conv(hidden_states)
1067
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
1068
+
1069
+ return hidden_states
1070
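+ # Editor's sketch (assumed sizes): the strided Conv1d doubles the channels and the GLU
+ # halves them again, so each adapter layer downsamples time at constant width.
+ # >>> import torch
+ # >>> import torch.nn as nn
+ # >>> import torch.nn.functional as F
+ # >>> conv = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
+ # >>> x = torch.randn(1, 16, 100)                    # (batch, channels, time)
+ # >>> tuple(F.glu(conv(x), dim=1).shape)
+ # (1, 16, 50)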
+
1071
+
1072
+ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
1073
+ """
1074
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1075
+ models.
1076
+ """
1077
+
1078
+ config_class = Wav2Vec2ConformerConfig
1079
+ base_model_prefix = "wav2vec2_conformer"
1080
+ main_input_name = "input_values"
1081
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1082
+ supports_gradient_checkpointing = True
1083
+
1084
+ def _init_weights(self, module):
1085
+ """Initialize the weights"""
1086
+ # Wav2Vec2ConformerForPreTraining last 2 linear layers need standard Linear init.
1087
+ if isinstance(module, Wav2Vec2ConformerForPreTraining):
1088
+ module.project_hid.reset_parameters()
1089
+ module.project_q.reset_parameters()
1090
+ module.project_hid._is_hf_initialized = True
1091
+ module.project_q._is_hf_initialized = True
1092
+ # gumbel softmax requires special init
1093
+ elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
1094
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
1095
+ module.weight_proj.bias.data.zero_()
1096
+ nn.init.uniform_(module.codevectors)
1097
+ elif isinstance(module, Wav2Vec2ConformerSelfAttention):
1098
+ if hasattr(module, "pos_bias_u"):
1099
+ nn.init.xavier_uniform_(module.pos_bias_u)
1100
+ if hasattr(module, "pos_bias_v"):
1101
+ nn.init.xavier_uniform_(module.pos_bias_v)
1102
+ elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
1103
+ nn.init.normal_(
1104
+ module.conv.weight,
1105
+ mean=0,
1106
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
1107
+ )
1108
+ nn.init.constant_(module.conv.bias, 0)
1109
+ elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
1110
+ k = math.sqrt(1 / module.projection.in_features)
1111
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
1112
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
1113
+ elif isinstance(module, nn.Linear):
1114
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1115
+
1116
+ if module.bias is not None:
1117
+ module.bias.data.zero_()
1118
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
1119
+ module.bias.data.zero_()
1120
+ module.weight.data.fill_(1.0)
1121
+ elif isinstance(module, nn.Conv1d):
1122
+ nn.init.kaiming_normal_(module.weight)
1123
+
1124
+ if module.bias is not None:
1125
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
1126
+ nn.init.uniform_(module.bias, a=-k, b=k)
1127
+
1128
+ def _get_feat_extract_output_lengths(
1129
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
1130
+ ):
1131
+ """
1132
+ Computes the output length of the convolutional layers
1133
+ """
1134
+
1135
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
1136
+
1137
+ def _conv_out_length(input_length, kernel_size, stride):
1138
+ # 1D convolutional layer output length formula taken
1139
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
1140
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
1141
+
1142
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
1143
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
1144
+
1145
+ if add_adapter:
1146
+ for _ in range(self.config.num_adapter_layers):
1147
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
1148
+
1149
+ return input_lengths
1150
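+ # Editor's sketch, assuming the usual wav2vec 2.0 feature-encoder kernels (10,3,3,3,3,2,2)
+ # and strides (5,2,2,2,2,2,2): one second of 16 kHz audio maps to 49 frames.
+ # >>> length = 16000
+ # >>> for kernel, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
+ # ...     length = (length - kernel) // stride + 1
+ # >>> length
+ # 49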
+
1151
+ def _get_feature_vector_attention_mask(
1152
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
1153
+ ):
1154
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
1155
+ # on inference mode.
1156
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
1157
+
1158
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
1159
+ output_lengths = output_lengths.to(torch.long)
1160
+
1161
+ batch_size = attention_mask.shape[0]
1162
+
1163
+ attention_mask = torch.zeros(
1164
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
1165
+ )
1166
+ # these two operations make sure that all values before the output length indices are attended to
1167
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
1168
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
1169
+ return attention_mask
1170
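+ # Editor's sketch of the flip/cumsum/flip trick above: placing a single 1 at index
+ # (output_length - 1) and accumulating from the right marks all earlier frames as valid.
+ # >>> import torch
+ # >>> mask = torch.zeros(6, dtype=torch.long)
+ # >>> mask[3] = 1                                    # output_length == 4
+ # >>> mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ # tensor([ True,  True,  True,  True, False, False])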
+
1171
+ def _set_gradient_checkpointing(self, module, value=False):
1172
+ if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
1173
+ module.gradient_checkpointing = value
1174
+
1175
+
1176
+ WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
1177
+ Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
1178
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
1179
+ Auli.
1180
+
1181
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1182
+ library implements for all its models (such as downloading or saving, etc.).
1183
+
1184
+ This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
1185
+ regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
1186
+
1187
+ Parameters:
1188
+ config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
1189
+ Initializing with a config file does not load the weights associated with the model, only the
1190
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1191
+ """
1192
+
1193
+
1194
+ WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
1195
+ Args:
1196
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
1197
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
1198
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
1199
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
1200
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
1201
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1202
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1203
+ 1]`:
1204
+
1205
+ - 1 for tokens that are **not masked**,
1206
+ - 0 for tokens that are **masked**.
1207
+
1208
+ [What are attention masks?](../glossary#attention-mask)
1209
+
1210
+ <Tip warning={true}>
1211
+
1212
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
1213
+ True`. For all models whose processor has `config.return_attention_mask == False`, such as
1214
+ [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
1215
+ `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
1216
+ such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
1217
+ that these models also yield slightly different results depending on whether `input_values` is padded or
1218
+ not.
1219
+
1220
+ </Tip>
1221
+
1222
+ output_attentions (`bool`, *optional*):
1223
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1224
+ tensors for more detail.
1225
+ output_hidden_states (`bool`, *optional*):
1226
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1227
+ more detail.
1228
+ return_dict (`bool`, *optional*):
1229
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1230
+ """
1231
+
1232
+
1233
+ @add_start_docstrings(
1234
+ "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
1235
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1236
+ )
1237
+ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
1238
+ def __init__(self, config: Wav2Vec2ConformerConfig):
1239
+ super().__init__(config)
1240
+ self.config = config
1241
+ self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
1242
+ self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
1243
+
1244
+ # model only needs masking vector if mask prob is > 0.0
1245
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
1246
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
1247
+
1248
+ self.encoder = Wav2Vec2ConformerEncoder(config)
1249
+
1250
+ self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
1251
+
1252
+ # Initialize weights and apply final processing
1253
+ self.post_init()
1254
+
1255
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
1256
+ def freeze_feature_encoder(self):
1257
+ """
1258
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1259
+ not be updated during training.
1260
+ """
1261
+ self.feature_extractor._freeze_parameters()
1262
+
1263
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
1264
+ def _mask_hidden_states(
1265
+ self,
1266
+ hidden_states: torch.FloatTensor,
1267
+ mask_time_indices: Optional[torch.FloatTensor] = None,
1268
+ attention_mask: Optional[torch.LongTensor] = None,
1269
+ ):
1270
+ """
1271
+ Masks extracted features along time axis and/or along feature axis according to
1272
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
1273
+ """
1274
+
1275
+ # `config.apply_spec_augment` can set masking to False
1276
+ if not getattr(self.config, "apply_spec_augment", True):
1277
+ return hidden_states
1278
+
1279
+ # generate indices & apply SpecAugment along time axis
1280
+ batch_size, sequence_length, hidden_size = hidden_states.size()
1281
+
1282
+ if mask_time_indices is not None:
1283
+ # apply SpecAugment along time axis with given mask_time_indices
1284
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1285
+ elif self.config.mask_time_prob > 0 and self.training:
1286
+ mask_time_indices = _compute_mask_indices(
1287
+ (batch_size, sequence_length),
1288
+ mask_prob=self.config.mask_time_prob,
1289
+ mask_length=self.config.mask_time_length,
1290
+ attention_mask=attention_mask,
1291
+ min_masks=self.config.mask_time_min_masks,
1292
+ )
1293
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
1294
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1295
+
1296
+ if self.config.mask_feature_prob > 0 and self.training:
1297
+ # generate indices & apply SpecAugment along feature axis
1298
+ mask_feature_indices = _compute_mask_indices(
1299
+ (batch_size, hidden_size),
1300
+ mask_prob=self.config.mask_feature_prob,
1301
+ mask_length=self.config.mask_feature_length,
1302
+ min_masks=self.config.mask_feature_min_masks,
1303
+ )
1304
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
1305
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
1306
+ hidden_states[mask_feature_indices] = 0
1307
+
1308
+ return hidden_states
1309
+
1310
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1311
+ @add_code_sample_docstrings(
1312
+ checkpoint=_CHECKPOINT_FOR_DOC,
1313
+ output_type=Wav2Vec2BaseModelOutput,
1314
+ config_class=_CONFIG_FOR_DOC,
1315
+ modality="audio",
1316
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1317
+ )
1318
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
1319
+ def forward(
1320
+ self,
1321
+ input_values: Optional[torch.Tensor],
1322
+ attention_mask: Optional[torch.Tensor] = None,
1323
+ mask_time_indices: Optional[torch.FloatTensor] = None,
1324
+ output_attentions: Optional[bool] = None,
1325
+ output_hidden_states: Optional[bool] = None,
1326
+ return_dict: Optional[bool] = None,
1327
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
1328
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1329
+ output_hidden_states = (
1330
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1331
+ )
1332
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1333
+
1334
+ extract_features = self.feature_extractor(input_values)
1335
+ extract_features = extract_features.transpose(1, 2)
1336
+
1337
+ if attention_mask is not None:
1338
+ # compute reduced attention_mask corresponding to feature vectors
1339
+ attention_mask = self._get_feature_vector_attention_mask(
1340
+ extract_features.shape[1], attention_mask, add_adapter=False
1341
+ )
1342
+
1343
+ hidden_states, extract_features = self.feature_projection(extract_features)
1344
+ hidden_states = self._mask_hidden_states(
1345
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
1346
+ )
1347
+
1348
+ encoder_outputs = self.encoder(
1349
+ hidden_states,
1350
+ attention_mask=attention_mask,
1351
+ output_attentions=output_attentions,
1352
+ output_hidden_states=output_hidden_states,
1353
+ return_dict=return_dict,
1354
+ )
1355
+
1356
+ hidden_states = encoder_outputs[0]
1357
+
1358
+ if self.adapter is not None:
1359
+ hidden_states = self.adapter(hidden_states)
1360
+
1361
+ if not return_dict:
1362
+ return (hidden_states, extract_features) + encoder_outputs[1:]
1363
+
1364
+ return Wav2Vec2BaseModelOutput(
1365
+ last_hidden_state=hidden_states,
1366
+ extract_features=extract_features,
1367
+ hidden_states=encoder_outputs.hidden_states,
1368
+ attentions=encoder_outputs.attentions,
1369
+ )
1370
+
1371
+
1372
+ @add_start_docstrings(
1373
+ """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
1374
+ )
1375
+ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
1376
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1377
+ def __init__(self, config: Wav2Vec2ConformerConfig):
1378
+ super().__init__(config)
1379
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1380
+ self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
1381
+
1382
+ self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
1383
+
1384
+ self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
1385
+ self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
1386
+
1387
+ # Initialize weights and apply final processing
1388
+ self.post_init()
1389
+
1390
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
1391
+ def set_gumbel_temperature(self, temperature: int):
1392
+ """
1393
+ Set the Gumbel softmax temperature to a given value. Only necessary for training
1394
+ """
1395
+ self.quantizer.temperature = temperature
1396
+
1397
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1398
+ def freeze_feature_encoder(self):
1399
+ """
1400
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1401
+ not be updated during training.
1402
+ """
1403
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1404
+
1405
+ @staticmethod
1406
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
1407
+ def compute_contrastive_logits(
1408
+ target_features: torch.FloatTensor,
1409
+ negative_features: torch.FloatTensor,
1410
+ predicted_features: torch.FloatTensor,
1411
+ temperature: int = 0.1,
1412
+ ):
1413
+ """
1414
+ Compute logits for contrastive loss using cosine similarity as the distance measure between
1415
+ `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
1416
+ """
1417
+ target_features = torch.cat([target_features, negative_features], dim=0)
1418
+
1419
+ logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
1420
+ target_features
1421
+ )
1422
+
1423
+ # apply temperature
1424
+ logits = logits / temperature
1425
+ return logits
1426
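+ # Editor's sketch (toy shapes): row 0 holds the positive, the remaining rows hold the
+ # sampled negatives; cosine similarity broadcasts over the leading dimension.
+ # >>> import torch
+ # >>> predicted = torch.randn(1, 3, 5, 8)            # (1, batch, time, dim)
+ # >>> candidates = torch.randn(11, 3, 5, 8)          # positive + 10 negatives
+ # >>> logits = torch.cosine_similarity(predicted, candidates, dim=-1) / 0.1
+ # >>> tuple(logits.shape)
+ # (11, 3, 5)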
+
1427
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1428
+ @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1429
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
1430
+ def forward(
1431
+ self,
1432
+ input_values: Optional[torch.Tensor],
1433
+ attention_mask: Optional[torch.Tensor] = None,
1434
+ mask_time_indices: Optional[torch.BoolTensor] = None,
1435
+ sampled_negative_indices: Optional[torch.BoolTensor] = None,
1436
+ output_attentions: Optional[bool] = None,
1437
+ output_hidden_states: Optional[bool] = None,
1438
+ return_dict: Optional[bool] = None,
1439
+ ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
1440
+ r"""
1441
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
1442
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
1443
+ masked extracted features in *config.proj_codevector_dim* space.
1444
+ sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
1445
+ Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
1446
+ Required input for pre-training.
1447
+
1448
+ Returns:
1449
+
1450
+ Example:
1451
+
1452
+ ```python
1453
+ >>> import torch
1454
+ >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
1455
+ >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
1456
+ ... _compute_mask_indices,
1457
+ ... _sample_negative_indices,
1458
+ ... )
1459
+ >>> from datasets import load_dataset
1460
+
1461
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
1462
+ >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
1463
+
1464
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1465
+ >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
1466
+
1467
+ >>> # compute masked indices
1468
+ >>> batch_size, raw_sequence_length = input_values.shape
1469
+ >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
1470
+ >>> mask_time_indices = _compute_mask_indices(
1471
+ ... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
1472
+ ... )
1473
+ >>> sampled_negative_indices = _sample_negative_indices(
1474
+ ... features_shape=(batch_size, sequence_length),
1475
+ ... num_negatives=model.config.num_negatives,
1476
+ ... mask_time_indices=mask_time_indices,
1477
+ ... )
1478
+ >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
1479
+ >>> sampled_negative_indices = torch.tensor(
1480
+ ... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
1481
+ ... )
1482
+
1483
+ >>> with torch.no_grad():
1484
+ ... outputs = model(input_values, mask_time_indices=mask_time_indices)
1485
+
1486
+ >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
1487
+ >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
1488
+
1489
+ >>> # show that cosine similarity is much higher than random
1490
+ >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
1491
+ tensor(True)
1492
+
1493
+ >>> # for contrastive loss training model should be put into train mode
1494
+ >>> model = model.train()
1495
+ >>> loss = model(
1496
+ ... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
1497
+ ... ).loss
1498
+ ```"""
1499
+
1500
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1501
+
1502
+ if mask_time_indices is not None:
1503
+ mask_time_indices = mask_time_indices.to(torch.bool)
1504
+
1505
+ outputs = self.wav2vec2_conformer(
1506
+ input_values,
1507
+ attention_mask=attention_mask,
1508
+ output_attentions=output_attentions,
1509
+ output_hidden_states=output_hidden_states,
1510
+ mask_time_indices=mask_time_indices,
1511
+ return_dict=return_dict,
1512
+ )
1513
+
1514
+ # 1. project all transformed features (including masked) to final vq dim
1515
+ transformer_features = self.project_hid(outputs[0])
1516
+
1517
+ # 2. quantize all (unmasked) extracted features and project to final vq dim
1518
+ extract_features = self.dropout_features(outputs[1])
1519
+
1520
+ if attention_mask is not None:
1521
+ # compute reduced attention_mask corresponding to feature vectors
1522
+ attention_mask = self._get_feature_vector_attention_mask(
1523
+ extract_features.shape[1], attention_mask, add_adapter=False
1524
+ )
1525
+
1526
+ quantized_features, codevector_perplexity = self.quantizer(
1527
+ extract_features, mask_time_indices=mask_time_indices
1528
+ )
1529
+ quantized_features = self.project_q(quantized_features)
1530
+
1531
+ loss = contrastive_loss = diversity_loss = None
1532
+ if sampled_negative_indices is not None:
1533
+ batch_size, sequence_length, hidden_size = quantized_features.shape
1534
+
1535
+ # for training, we sample negatives
1536
+ # 3. sample K negatives (distractors) quantized states for contrastive loss
1537
+ # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
1538
+ # sample negative quantized vectors BTC => (BxT)C
1539
+ negative_quantized_features = quantized_features.view(-1, hidden_size)[
1540
+ sampled_negative_indices.long().view(-1)
1541
+ ]
1542
+ negative_quantized_features = negative_quantized_features.view(
1543
+ batch_size, sequence_length, -1, hidden_size
1544
+ ).permute(2, 0, 1, 3)
1545
+
1546
+ # 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
1547
+ # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
1548
+ logits = self.compute_contrastive_logits(
1549
+ quantized_features[None, :],
1550
+ negative_quantized_features,
1551
+ transformer_features,
1552
+ self.config.contrastive_logits_temperature,
1553
+ )
1554
+
1555
+ # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
1556
+ # its cosine similarity will be masked
1557
+ neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
1558
+
1559
+ if neg_is_pos.any():
1560
+ logits[1:][neg_is_pos] = float("-inf")
1561
+
1562
+ # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
1563
+ # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
1564
+ logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
1565
+ target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
1566
+
1567
+ contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
1568
+ # 7. compute diversity loss: \mathbf{L}_d
1569
+ num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
1570
+ diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
1571
+
1572
+ # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
1573
+ loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
1574
+
1575
+ if not return_dict:
1576
+ if loss is not None:
1577
+ return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
1578
+ return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
1579
+
1580
+ return Wav2Vec2ConformerForPreTrainingOutput(
1581
+ loss=loss,
1582
+ projected_states=transformer_features,
1583
+ projected_quantized_states=quantized_features,
1584
+ codevector_perplexity=codevector_perplexity,
1585
+ hidden_states=outputs.hidden_states,
1586
+ attentions=outputs.attentions,
1587
+ contrastive_loss=contrastive_loss,
1588
+ diversity_loss=diversity_loss,
1589
+ )
1590
+
1591
+
1592
+ @add_start_docstrings(
1593
+ """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
1594
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1595
+ )
1596
+ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
1597
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1598
+ def __init__(self, config):
1599
+ super().__init__(config)
1600
+
1601
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1602
+ self.dropout = nn.Dropout(config.final_dropout)
1603
+
1604
+ if config.vocab_size is None:
1605
+ raise ValueError(
1606
+ f"You are trying to instantiate {self.__class__} with a configuration that "
1607
+ "does not define the vocabulary size of the language model head. Please "
1608
+ "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
1609
+ "or define `vocab_size` of your model's configuration."
1610
+ )
1611
+ output_hidden_size = (
1612
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
1613
+ )
1614
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
1615
+
1616
+ # Initialize weights and apply final processing
1617
+ self.post_init()
1618
+
1619
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1620
+ def freeze_feature_encoder(self):
1621
+ """
1622
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1623
+ not be updated during training.
1624
+ """
1625
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1626
+
1627
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1628
+ @add_code_sample_docstrings(
1629
+ checkpoint=_CHECKPOINT_FOR_DOC,
1630
+ output_type=CausalLMOutput,
1631
+ config_class=_CONFIG_FOR_DOC,
1632
+ expected_output=_CTC_EXPECTED_OUTPUT,
1633
+ expected_loss=_CTC_EXPECTED_LOSS,
1634
+ )
1635
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1636
+ def forward(
1637
+ self,
1638
+ input_values: Optional[torch.Tensor],
1639
+ attention_mask: Optional[torch.Tensor] = None,
1640
+ output_attentions: Optional[bool] = None,
1641
+ output_hidden_states: Optional[bool] = None,
1642
+ return_dict: Optional[bool] = None,
1643
+ labels: Optional[torch.Tensor] = None,
1644
+ ) -> Union[Tuple, CausalLMOutput]:
1645
+ r"""
1646
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
1647
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller than or equal to
1648
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
1649
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
1650
+ config.vocab_size - 1]`.
1651
+ """
1652
+
1653
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1654
+
1655
+ outputs = self.wav2vec2_conformer(
1656
+ input_values,
1657
+ attention_mask=attention_mask,
1658
+ output_attentions=output_attentions,
1659
+ output_hidden_states=output_hidden_states,
1660
+ return_dict=return_dict,
1661
+ )
1662
+
1663
+ hidden_states = outputs[0]
1664
+ hidden_states = self.dropout(hidden_states)
1665
+
1666
+ logits = self.lm_head(hidden_states)
1667
+
1668
+ loss = None
1669
+ if labels is not None:
1670
+ if labels.max() >= self.config.vocab_size:
1671
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
1672
+
1673
+ # retrieve loss input_lengths from attention_mask
1674
+ attention_mask = (
1675
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
1676
+ )
1677
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
1678
+
1679
+ # assuming that padded tokens are filled with -100
1680
+ # when not being attended to
1681
+ labels_mask = labels >= 0
1682
+ target_lengths = labels_mask.sum(-1)
1683
+ flattened_targets = labels.masked_select(labels_mask)
1684
+
1685
+ # ctc_loss doesn't support fp16
1686
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
1687
+
1688
+ with torch.backends.cudnn.flags(enabled=False):
1689
+ loss = nn.functional.ctc_loss(
1690
+ log_probs,
1691
+ flattened_targets,
1692
+ input_lengths,
1693
+ target_lengths,
1694
+ blank=self.config.pad_token_id,
1695
+ reduction=self.config.ctc_loss_reduction,
1696
+ zero_infinity=self.config.ctc_zero_infinity,
1697
+ )
1698
+
1699
+ if not return_dict:
1700
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1701
+ return ((loss,) + output) if loss is not None else output
1702
+
1703
+ return CausalLMOutput(
1704
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
1705
+ )
1706
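+ # Editor's sketch (hypothetical sizes): F.ctc_loss expects log-probabilities of shape
+ # (time, batch, vocab), hence the transpose(0, 1) above, together with per-sample
+ # input/target lengths and a flattened 1-D target tensor.
+ # >>> import torch
+ # >>> import torch.nn.functional as F
+ # >>> log_probs = F.log_softmax(torch.randn(50, 2, 32), dim=-1)
+ # >>> targets = torch.randint(1, 32, (7,))           # two targets, lengths 4 and 3
+ # >>> loss = F.ctc_loss(log_probs, targets, torch.tensor([50, 50]), torch.tensor([4, 3]), blank=0)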
+
1707
+
1708
+ @add_start_docstrings(
1709
+ """
1710
+ Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
1711
+ tasks like SUPERB Keyword Spotting.
1712
+ """,
1713
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1714
+ )
1715
+ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
1716
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1717
+ def __init__(self, config):
1718
+ super().__init__(config)
1719
+
1720
+ if hasattr(config, "add_adapter") and config.add_adapter:
1721
+ raise ValueError(
1722
+ "Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
1723
+ )
1724
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1725
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1726
+ if config.use_weighted_layer_sum:
1727
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1728
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
1729
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
1730
+
1731
+ # Initialize weights and apply final processing
1732
+ self.post_init()
1733
+
1734
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1735
+ def freeze_feature_encoder(self):
1736
+ """
1737
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1738
+ not be updated during training.
1739
+ """
1740
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1741
+
1742
+ def freeze_base_model(self):
1743
+ """
1744
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
1745
+ be updated during training. Only the classification head will be updated.
1746
+ """
1747
+ for param in self.wav2vec2_conformer.parameters():
1748
+ param.requires_grad = False
1749
+
1750
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1751
+ @add_code_sample_docstrings(
1752
+ checkpoint=_CHECKPOINT_FOR_DOC,
1753
+ output_type=SequenceClassifierOutput,
1754
+ config_class=_CONFIG_FOR_DOC,
1755
+ modality="audio",
1756
+ )
1757
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
1758
+ def forward(
1759
+ self,
1760
+ input_values: Optional[torch.Tensor],
1761
+ attention_mask: Optional[torch.Tensor] = None,
1762
+ output_attentions: Optional[bool] = None,
1763
+ output_hidden_states: Optional[bool] = None,
1764
+ return_dict: Optional[bool] = None,
1765
+ labels: Optional[torch.Tensor] = None,
1766
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1767
+ r"""
1768
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1769
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1770
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1771
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1772
+ """
1773
+
1774
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1775
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
1776
+
1777
+ outputs = self.wav2vec2_conformer(
1778
+ input_values,
1779
+ attention_mask=attention_mask,
1780
+ output_attentions=output_attentions,
1781
+ output_hidden_states=output_hidden_states,
1782
+ return_dict=return_dict,
1783
+ )
1784
+
1785
+ if self.config.use_weighted_layer_sum:
1786
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
1787
+ hidden_states = torch.stack(hidden_states, dim=1)
1788
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
1789
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
1790
+ else:
1791
+ hidden_states = outputs[0]
1792
+
1793
+ hidden_states = self.projector(hidden_states)
1794
+ if attention_mask is None:
1795
+ pooled_output = hidden_states.mean(dim=1)
1796
+ else:
1797
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
1798
+ hidden_states[~padding_mask] = 0.0
1799
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
1800
+
1801
+ logits = self.classifier(pooled_output)
1802
+
1803
+ loss = None
1804
+ if labels is not None:
1805
+ loss_fct = CrossEntropyLoss()
1806
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1807
+
1808
+ if not return_dict:
1809
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1810
+ return ((loss,) + output) if loss is not None else output
1811
+
1812
+ return SequenceClassifierOutput(
1813
+ loss=loss,
1814
+ logits=logits,
1815
+ hidden_states=outputs.hidden_states,
1816
+ attentions=outputs.attentions,
1817
+ )
1818
+
1819
+
1820
+ @add_start_docstrings(
1821
+ """
1822
+ Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
1823
+ """,
1824
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1825
+ )
1826
+ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
1827
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
1828
+ def __init__(self, config):
1829
+ super().__init__(config)
1830
+
1831
+ if hasattr(config, "add_adapter") and config.add_adapter:
1832
+ raise ValueError(
1833
+ "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
1834
+ )
1835
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1836
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1837
+ if config.use_weighted_layer_sum:
1838
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1839
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1840
+ self.num_labels = config.num_labels
1841
+
1842
+ self.init_weights()
1843
+
1844
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1845
+ def freeze_feature_encoder(self):
1846
+ """
1847
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1848
+ not be updated during training.
1849
+ """
1850
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1851
+
1852
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
1853
+ def freeze_base_model(self):
1854
+ """
1855
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
1856
+ be updated during training. Only the classification head will be updated.
1857
+ """
1858
+ for param in self.wav2vec2_conformer.parameters():
1859
+ param.requires_grad = False
1860
+
1861
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1862
+ @add_code_sample_docstrings(
1863
+ checkpoint=_CHECKPOINT_FOR_DOC,
1864
+ output_type=TokenClassifierOutput,
1865
+ config_class=_CONFIG_FOR_DOC,
1866
+ modality="audio",
1867
+ )
1868
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
1869
+ def forward(
1870
+ self,
1871
+ input_values: Optional[torch.Tensor],
1872
+ attention_mask: Optional[torch.Tensor] = None,
1873
+ labels: Optional[torch.Tensor] = None,
1874
+ output_attentions: Optional[bool] = None,
1875
+ output_hidden_states: Optional[bool] = None,
1876
+ return_dict: Optional[bool] = None,
1877
+ ) -> Union[Tuple, TokenClassifierOutput]:
1878
+ r"""
1879
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1880
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1881
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1882
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1883
+ """
1884
+
1885
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1886
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
1887
+
1888
+ outputs = self.wav2vec2_conformer(
1889
+ input_values,
1890
+ attention_mask=attention_mask,
1891
+ output_attentions=output_attentions,
1892
+ output_hidden_states=output_hidden_states,
1893
+ return_dict=return_dict,
1894
+ )
1895
+
1896
+ if self.config.use_weighted_layer_sum:
1897
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
1898
+ hidden_states = torch.stack(hidden_states, dim=1)
1899
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
1900
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
1901
+ else:
1902
+ hidden_states = outputs[0]
1903
+
1904
+ logits = self.classifier(hidden_states)
1905
+
1906
+ loss = None
1907
+ if labels is not None:
1908
+ loss_fct = CrossEntropyLoss()
1909
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
1910
+
1911
+ if not return_dict:
1912
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1913
+ return output
1914
+
1915
+ return TokenClassifierOutput(
1916
+ loss=loss,
1917
+ logits=logits,
1918
+ hidden_states=outputs.hidden_states,
1919
+ attentions=outputs.attentions,
1920
+ )
1921
+
1922
+
1923
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
1924
+ class AMSoftmaxLoss(nn.Module):
1925
+ def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
1926
+ super(AMSoftmaxLoss, self).__init__()
1927
+ self.scale = scale
1928
+ self.margin = margin
1929
+ self.num_labels = num_labels
1930
+ self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
1931
+ self.loss = nn.CrossEntropyLoss()
1932
+
1933
+ def forward(self, hidden_states, labels):
1934
+ labels = labels.flatten()
1935
+ weight = nn.functional.normalize(self.weight, dim=0)
1936
+ hidden_states = nn.functional.normalize(hidden_states, dim=1)
1937
+ cos_theta = torch.mm(hidden_states, weight)
1938
+ psi = cos_theta - self.margin
1939
+
1940
+ onehot = nn.functional.one_hot(labels, self.num_labels)
1941
+ logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
1942
+ loss = self.loss(logits, labels)
1943
+
1944
+ return loss
1945
+
1946
+
1947
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
1948
+ class TDNNLayer(nn.Module):
1949
+ def __init__(self, config, layer_id=0):
1950
+ super().__init__()
1951
+ self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
1952
+ self.out_conv_dim = config.tdnn_dim[layer_id]
1953
+ self.kernel_size = config.tdnn_kernel[layer_id]
1954
+ self.dilation = config.tdnn_dilation[layer_id]
1955
+
1956
+ self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
1957
+ self.activation = nn.ReLU()
1958
+
1959
+ def forward(self, hidden_states):
1960
+ hidden_states = hidden_states.unsqueeze(1)
1961
+ hidden_states = nn.functional.unfold(
1962
+ hidden_states,
1963
+ (self.kernel_size, self.in_conv_dim),
1964
+ stride=(1, self.in_conv_dim),
1965
+ dilation=(self.dilation, 1),
1966
+ )
1967
+ hidden_states = hidden_states.transpose(1, 2)
1968
+ hidden_states = self.kernel(hidden_states)
1969
+
1970
+ hidden_states = self.activation(hidden_states)
1971
+ return hidden_states
1972
+
1973
+
1974
+ @add_start_docstrings(
1975
+ """
1976
+ Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
1977
+ """,
1978
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1979
+ )
1980
+ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
1981
+ def __init__(self, config):
1982
+ super().__init__(config)
1983
+
1984
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1985
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1986
+ if config.use_weighted_layer_sum:
1987
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1988
+ self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
1989
+
1990
+ tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
1991
+ self.tdnn = nn.ModuleList(tdnn_layers)
1992
+
1993
+ self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
1994
+ self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
1995
+
1996
+ self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
1997
+
1998
+ self.init_weights()
1999
+
2000
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
2001
+ def freeze_feature_encoder(self):
2002
+ """
2003
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
2004
+ not be updated during training.
2005
+ """
2006
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
2007
+
2008
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
2009
+ def freeze_base_model(self):
2010
+ """
2011
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
2012
+ be updated during training. Only the classification head will be updated.
2013
+ """
2014
+ for param in self.wav2vec2_conformer.parameters():
2015
+ param.requires_grad = False
2016
+
2017
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
2018
+ def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
2019
+ """
2020
+ Computes the output length of the TDNN layers
2021
+ """
2022
+
2023
+ def _conv_out_length(input_length, kernel_size, stride):
2024
+ # 1D convolutional layer output length formula taken
2025
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
2026
+ return (input_length - kernel_size) // stride + 1
2027
+
2028
+ for kernel_size in self.config.tdnn_kernel:
2029
+ input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
2030
+
2031
+ return input_lengths
2032
+
2033
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
2034
+ @add_code_sample_docstrings(
2035
+ checkpoint=_CHECKPOINT_FOR_DOC,
2036
+ output_type=XVectorOutput,
2037
+ config_class=_CONFIG_FOR_DOC,
2038
+ modality="audio",
2039
+ )
2040
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
2041
+ def forward(
2042
+ self,
2043
+ input_values: Optional[torch.Tensor],
2044
+ attention_mask: Optional[torch.Tensor] = None,
2045
+ output_attentions: Optional[bool] = None,
2046
+ output_hidden_states: Optional[bool] = None,
2047
+ return_dict: Optional[bool] = None,
2048
+ labels: Optional[torch.Tensor] = None,
2049
+ ) -> Union[Tuple, XVectorOutput]:
2050
+ r"""
2051
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
2052
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
2053
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
2054
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
2055
+ """
2056
+
2057
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2058
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
2059
+
2060
+ outputs = self.wav2vec2_conformer(
2061
+ input_values,
2062
+ attention_mask=attention_mask,
2063
+ output_attentions=output_attentions,
2064
+ output_hidden_states=output_hidden_states,
2065
+ return_dict=return_dict,
2066
+ )
2067
+
2068
+ if self.config.use_weighted_layer_sum:
2069
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
2070
+ hidden_states = torch.stack(hidden_states, dim=1)
2071
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
2072
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
2073
+ else:
2074
+ hidden_states = outputs[0]
2075
+
2076
+ hidden_states = self.projector(hidden_states)
2077
+
2078
+ for tdnn_layer in self.tdnn:
2079
+ hidden_states = tdnn_layer(hidden_states)
2080
+
2081
+ # Statistic Pooling
2082
+ if attention_mask is None:
2083
+ mean_features = hidden_states.mean(dim=1)
2084
+ std_features = hidden_states.std(dim=1)
2085
+ else:
2086
+ feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
2087
+ tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
2088
+ mean_features = []
2089
+ std_features = []
2090
+ for i, length in enumerate(tdnn_output_lengths):
2091
+ mean_features.append(hidden_states[i, :length].mean(dim=0))
2092
+ std_features.append(hidden_states[i, :length].std(dim=0))
2093
+ mean_features = torch.stack(mean_features)
2094
+ std_features = torch.stack(std_features)
2095
+ statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
2096
+
2097
+ output_embeddings = self.feature_extractor(statistic_pooling)
2098
+ logits = self.classifier(output_embeddings)
2099
+
2100
+ loss = None
2101
+ if labels is not None:
2102
+ loss = self.objective(logits, labels)
2103
+
2104
+ if not return_dict:
2105
+ output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
2106
+ return ((loss,) + output) if loss is not None else output
2107
+
2108
+ return XVectorOutput(
2109
+ loss=loss,
2110
+ logits=logits,
2111
+ embeddings=output_embeddings,
2112
+ hidden_states=outputs.hidden_states,
2113
+ attentions=outputs.attentions,
2114
+ )
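The statistic-pooling step in Wav2Vec2ConformerForXVector.forward reduces variable-length TDNN outputs to one fixed-size utterance vector by concatenating the mean and standard deviation over each example's valid frames. A minimal standalone sketch of that pooling, assuming the per-example valid lengths are already known; the function name and toy shapes are illustrative and not part of the uploaded files:

import torch

def masked_statistic_pooling(hidden_states: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, frames, dim); lengths: (batch,) count of valid frames per example
    pooled = []
    for i, length in enumerate(lengths.tolist()):
        valid = hidden_states[i, :length]                      # drop padded frames
        pooled.append(torch.cat([valid.mean(dim=0), valid.std(dim=0)]))
    return torch.stack(pooled)                                 # (batch, 2 * dim)

# e.g. masked_statistic_pooling(torch.randn(2, 50, 512), torch.tensor([50, 37])).shape == (2, 1024)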
slam_llm/models/musicfm/modules/random_quantizer.py ADDED
@@ -0,0 +1,83 @@
1
+ # MIT License
2
+ #
3
+ # Copyright 2023 ByteDance Inc.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
6
+ # to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
10
+ #
11
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
13
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
14
+ # IN THE SOFTWARE.
15
+
16
+ import torch
17
+ from torch import nn, einsum
18
+ from einops import rearrange
19
+
20
+
21
+ class RandomProjectionQuantizer(nn.Module):
22
+ """
23
+ Random projection and codebook lookup module
24
+
25
+ Some code is borrowed from:
26
+ https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
27
+ But I did normalization using pre-computed global mean & variance instead of using layer norm.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ input_dim,
33
+ codebook_dim,
34
+ codebook_size,
35
+ seed=142,
36
+ ):
37
+ super().__init__()
38
+
39
+ # random seed
40
+ torch.manual_seed(seed)
41
+
42
+ # randomly initialized projection
43
+ random_projection = torch.empty(input_dim, codebook_dim)
44
+ nn.init.xavier_normal_(random_projection)
45
+ self.register_buffer("random_projection", random_projection)
46
+
47
+ # randomly initialized codebook
48
+ codebook = torch.empty(codebook_size, codebook_dim)
49
+ nn.init.normal_(codebook)
50
+ self.register_buffer("codebook", codebook)
51
+
52
+ def codebook_lookup(self, x):
53
+ # reshape
54
+ b = x.shape[0]
55
+ x = rearrange(x, "b n e -> (b n) e")
56
+
57
+ # L2 normalization
58
+ normalized_x = nn.functional.normalize(x, dim=1, p=2)
59
+ normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
60
+
61
+ # compute distances
62
+ distances = torch.cdist(normalized_codebook, normalized_x)
63
+
64
+ # get nearest
65
+ nearest_indices = torch.argmin(distances, dim=0)
66
+
67
+ # reshape
68
+ xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
69
+
70
+ return xq
71
+
72
+ @torch.no_grad()
73
+ def forward(self, x):
74
+ # always eval
75
+ self.eval()
76
+
77
+ # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
78
+ x = einsum("b n d, d e -> b n e", x, self.random_projection)
79
+
80
+ # codebook lookup
81
+ xq = self.codebook_lookup(x)
82
+
83
+ return xq
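For orientation, the quantizer defined above can be exercised end to end as below; the toy dimensions are assumptions chosen only to make the shapes concrete:

import torch
from slam_llm.models.musicfm.modules.random_quantizer import RandomProjectionQuantizer

# 2 examples, 8 frames of 128-dim features, mapped onto a 1024-entry, 16-dim codebook
quantizer = RandomProjectionQuantizer(input_dim=128, codebook_dim=16, codebook_size=1024)
features = torch.randn(2, 8, 128)          # (batch, frames, input_dim)
tokens = quantizer(features)               # (batch, frames) integer codebook indices
assert tokens.shape == (2, 8)

Because both the projection and the codebook are frozen buffers created from a fixed seed, the same input always maps to the same token sequence.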
slam_llm/models/projector.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class EncoderProjectorConcat(nn.Module):
6
+ def __init__(self, config):
7
+ super().__init__()
8
+ self.k = config.encoder_projector_ds_rate
9
+ self.encoder_dim = config.encoder_dim
10
+ self.llm_dim = config.llm_dim
11
+ self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048)
12
+ self.relu = nn.ReLU()
13
+ self.linear2 = nn.Linear(2048, config.llm_dim)
14
+
15
+ def forward(self, x):
16
+ batch_size, seq_len, dim = x.size()
17
+ num_frames_to_discard = seq_len % self.k
18
+ if num_frames_to_discard > 0:
19
+ x = x[:, :-num_frames_to_discard, :]
20
+ seq_len = x.size(1)
21
+
22
+ x = x.contiguous()
23
+ x = x.view(batch_size, seq_len // self.k, dim * self.k)
24
+ x = self.linear1(x)
25
+ x = self.relu(x)
26
+ x = self.linear2(x)
27
+ return x
28
+
29
+ class EncoderProjectorCov1d(nn.Module):
30
+ def __init__(self, config):
31
+ super().__init__()
32
+ self.k = config.encoder_projector_ds_rate
33
+ self.encoder_dim = config.encoder_dim
34
+ self.llm_dim = config.llm_dim
35
+ self.conv1d = nn.Conv1d(in_channels=self.encoder_dim, out_channels=self.encoder_dim, kernel_size=self.k, stride=self.k, padding=0)
36
+ self.linear1 = nn.Linear(self.encoder_dim, 2048)
37
+ self.relu1 = nn.ReLU()
38
+ self.linear2 = nn.Linear(2048, self.llm_dim)
39
+ self.relu2 = nn.ReLU()
40
+
41
+ def forward(self, x):
42
+ x = x.transpose(1, 2)
43
+ x = self.conv1d(x)
44
+ x = x.transpose(1, 2)
45
+ x = self.relu1(x)
46
+ x = self.linear1(x)
47
+ x = self.relu2(x)
48
+ x = self.linear2(x)
49
+ return x
50
+
51
+ class EncoderProjectorQFormer(nn.Module):
52
+ def __init__(self, config):
53
+ super().__init__()
54
+ self.encoder_dim = config.encoder_dim
55
+ self.llm_dim = config.llm_dim
56
+ from transformers import Blip2QFormerConfig, Blip2QFormerModel
57
+ configuration = Blip2QFormerConfig()
58
+ configuration.encoder_hidden_size = self.encoder_dim
59
+ configuration.num_hidden_layers = 8
60
+
61
+ self.query_len = 64
62
+ self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
63
+ self.query.data.normal_(mean=0.0, std=1.0)
64
+ self.qformer = Blip2QFormerModel(configuration)
65
+
66
+ self.linear = nn.Linear(configuration.hidden_size, self.llm_dim)
67
+ self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)
68
+
69
+ def forward(self, x, atts):
70
+ query = self.query.expand(x.shape[0], -1, -1)
71
+
72
+ query_output = self.qformer(
73
+ query_embeds=query,
74
+ encoder_hidden_states=x,
75
+ encoder_attention_mask=atts,
76
+ return_dict=True,
77
+ )
78
+
79
+ query_proj = self.norm(self.linear(query_output.last_hidden_state))
80
+
81
+ return query_proj
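EncoderProjectorConcat downsamples by a factor of k by stacking k adjacent encoder frames into one vector before the two-layer MLP. A rough shape check, using types.SimpleNamespace as a stand-in for the real config object (an assumption made only for illustration):

import types
import torch
from slam_llm.models.projector import EncoderProjectorConcat

cfg = types.SimpleNamespace(encoder_projector_ds_rate=5, encoder_dim=1280, llm_dim=4096)
proj = EncoderProjectorConcat(cfg)
x = torch.randn(2, 103, 1280)              # 103 frames: the 3 leftover frames are discarded
out = proj(x)                              # (2, 100 // 5, 4096)
assert out.shape == (2, 20, 4096)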
slam_llm/models/slam_model.py ADDED
@@ -0,0 +1,443 @@
1
+ import os
2
+ import types
3
+ import torch
4
+ import soundfile as sf
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.distributed as dist
8
+ from typing import List, Optional, Tuple, Union
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
10
+ from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
11
+
12
+ from slam_llm.utils.config_utils import generate_peft_config
13
+ from slam_llm.utils.train_utils import print_module_size, print_model_size
14
+ from peft import PeftModel, PeftConfig
15
+ from torch.nn import CrossEntropyLoss
16
+ from slam_llm.utils.metric import compute_accuracy
17
+
18
+ import logging
19
+ logger = logging.getLogger(__name__)
20
+
21
+ def model_factory(train_config, model_config, **kwargs):
22
+ # return necessary components for training
23
+ tokenizer = setup_tokenizer(train_config, model_config, **kwargs)
24
+
25
+ encoder = setup_encoder(train_config, model_config, **kwargs)
26
+
27
+ # llm
28
+ llm = setup_llm(train_config, model_config, **kwargs)
29
+
30
+ # projector
31
+ encoder_projector = setup_encoder_projector(
32
+ train_config, model_config, **kwargs
33
+ )
34
+ model = slam_model(
35
+ encoder,
36
+ llm,
37
+ encoder_projector,
38
+ tokenizer,
39
+ train_config,
40
+ model_config,
41
+ **kwargs,
42
+ )
43
+
44
+ ckpt_path = kwargs.get("ckpt_path", None) #FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
45
+ if ckpt_path is not None:
46
+ logger.info("loading other parts from: {}".format(ckpt_path))
47
+ ckpt_dict = torch.load(ckpt_path, map_location="cpu")
48
+ model.load_state_dict(ckpt_dict, strict=False)
49
+
50
+ print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
51
+ return model, tokenizer
52
+
53
+
54
+ def setup_tokenizer(train_config, model_config, **kwargs):
55
+ # Load the tokenizer and add special tokens
56
+ if "vallex" in model_config.llm_name.lower():
57
+ return None
58
+ elif "mupt" in model_config.llm_name.lower():
59
+ tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path,
60
+ trust_remote_code=True,
61
+ use_fast=False)
62
+ else:
63
+ tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path)
64
+ tokenizer.pad_token_id = tokenizer.eos_token_id
65
+ return tokenizer
66
+
67
+
68
+ def setup_encoder(train_config, model_config, **kwargs):
69
+ encoder_list = model_config.encoder_name.split(",") if model_config.encoder_name else []
70
+ if len(encoder_list) == 0:
71
+ return None
72
+ if len(encoder_list) == 1:
73
+ encoder_name = encoder_list[0]
74
+ if encoder_name == "whisper" or encoder_name == "qwen-audio":
75
+ from slam_llm.models.encoder import WhisperWrappedEncoder
76
+ encoder = WhisperWrappedEncoder.load(model_config)
77
+ if encoder_name == "beats":
78
+ from slam_llm.models.encoder import BEATsEncoder
79
+ encoder = BEATsEncoder.load(model_config)
80
+ if encoder_name == "eat":
81
+ from slam_llm.models.encoder import EATEncoder
82
+ encoder = EATEncoder.load(model_config)
83
+ if encoder_name == "SpatialAST":
84
+ from slam_llm.models.encoder import SpatialASTEncoder
85
+ encoder = SpatialASTEncoder.load(model_config)
86
+ if encoder_name == "wavlm":
87
+ from slam_llm.models.encoder import WavLMEncoder
88
+ encoder = WavLMEncoder.load(model_config)
89
+ if encoder_name == "av_hubert":
90
+ from slam_llm.models.encoder import AVHubertEncoder
91
+ encoder = AVHubertEncoder.load(model_config)
92
+ if encoder_name == "hubert":
93
+ from slam_llm.models.encoder import HubertEncoder
94
+ encoder = HubertEncoder.load(model_config)
95
+ if encoder_name == "musicfm":
96
+ from slam_llm.models.encoder import MusicFMEncoder
97
+ encoder = MusicFMEncoder.load(model_config)
98
+
99
+ if "llama" in encoder_name.lower():
100
+ from slam_llm.models.encoder import HfTextEncoder
101
+ encoder = HfTextEncoder.load(model_config)
102
+ print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
103
+
104
+ if train_config.freeze_encoder:
105
+ for name, param in encoder.named_parameters():
106
+ param.requires_grad = False
107
+ encoder.eval()
108
+ print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
109
+
110
+ return encoder
111
+
112
+ def setup_llm(train_config, model_config, **kwargs):
113
+ from pkg_resources import packaging
114
+ use_cache = False if train_config.enable_fsdp or train_config.enable_ddp else None
115
+ if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.low_cpu_fsdp:
116
+ """
117
+ for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
118
+ this avoids cpu oom when loading large models like llama 70B, in which case
119
+ model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
120
+ overhead and currently requires latest nightly.
121
+ """
122
+ # v = packaging.version.parse(torch.__version__)
123
+ # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
124
+ # if not verify_latest_nightly:
125
+ # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
126
+ # "please install latest nightly.")
127
+ rank = int(os.environ["RANK"])
128
+ if rank == 0:
129
+ if "vallex" in model_config.llm_name.lower():
130
+ from src.slam_llm.models.vallex.vallex_config import VallexConfig
131
+ from src.slam_llm.models.vallex.vallex_model import VALLE
132
+ vallex_config = VallexConfig(
133
+ **model_config
134
+ )
135
+ model = VALLE(vallex_config)
136
+ elif "aya" in model_config.llm_name.lower():
137
+ model = AutoModelForSeq2SeqLM.from_pretrained(
138
+ model_config.llm_path,
139
+ load_in_8bit=True if train_config.quantization else None,
140
+ device_map="auto" if train_config.quantization else None,
141
+ use_cache=use_cache,
142
+ )
143
+ else:
144
+ model = AutoModelForCausalLM.from_pretrained(
145
+ model_config.llm_path,
146
+ load_in_8bit=True if train_config.quantization else None,
147
+ device_map="auto" if train_config.quantization else None,
148
+ use_cache=use_cache,
149
+ )
150
+ else:
151
+ llama_config = AutoConfig.from_pretrained(model_config.llm_path)
152
+ llama_config.use_cache = use_cache
153
+ # with torch.device("meta"):
154
+ if "aya" in model_config.llm_name.lower():
155
+ model = AutoModelForSeq2SeqLM.from_config(llama_config)
156
+ else:
157
+ model = AutoModelForCausalLM.from_config(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta`
158
+
159
+ else:
160
+ if "vallex" in model_config.llm_name.lower():
161
+ from src.slam_llm.models.vallex.vallex_config import VallexConfig
162
+ from src.slam_llm.models.vallex.vallex_model import VALLE
163
+ vallex_config = VallexConfig(
164
+ **model_config
165
+ )
166
+ model = VALLE(vallex_config)
167
+ elif "aya" in model_config.llm_name.lower():
168
+ model = AutoModelForSeq2SeqLM.from_pretrained(
169
+ model_config.llm_path,
170
+ load_in_8bit=True if train_config.quantization else None,
171
+ device_map="auto" if train_config.quantization else None,
172
+ use_cache=use_cache,
173
+ )
174
+ else:
175
+ model = AutoModelForCausalLM.from_pretrained(
176
+ model_config.llm_path,
177
+ load_in_8bit=True if train_config.quantization else None,
178
+ device_map="auto" if train_config.quantization else None,
179
+ use_cache=use_cache,
180
+ )
181
+ if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.use_fast_kernels:
182
+ """
183
+ For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
184
+ using of Flash Attention or Xformer memory-efficient kernels
185
+ based on the hardware being used. This would speed up fine-tuning.
186
+ """
187
+ try:
188
+ from optimum.bettertransformer import BetterTransformer
189
+ model = BetterTransformer.transform(model)
190
+ except ImportError:
191
+ logger.warning("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
192
+
193
+ print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
194
+
195
+ # Prepare the model for int8 training if quantization is enabled
196
+ if train_config.quantization:
197
+ model = prepare_model_for_kbit_training(model)
198
+
199
+ if train_config.freeze_llm: # TODO:to test offical `freeze_layers` and `num_freeze_layers`
200
+ for name, param in model.named_parameters():
201
+ param.requires_grad = False
202
+ model.eval()
203
+
204
+ if kwargs.get("peft_ckpt", None): # (FIX:MZY):reload will get wrong results when decoding
205
+ logger.info("loading peft_ckpt from: {}".format(kwargs.get("peft_ckpt")))
206
+ model = PeftModel.from_pretrained(model=model, model_id=kwargs.get("peft_ckpt"), is_trainable=True)
207
+ model.print_trainable_parameters()
208
+ elif train_config.use_peft:
209
+ logger.info("setup peft...")
210
+ peft_config = generate_peft_config(train_config)
211
+ model = get_peft_model(model, peft_config)
212
+ model.print_trainable_parameters()
213
+
214
+ print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
215
+ return model
216
+
217
+ def setup_encoder_projector(train_config, model_config, **kwargs):
218
+ if model_config.encoder_projector == "linear":
219
+ from slam_llm.models.projector import EncoderProjectorConcat
220
+ encoder_projector = EncoderProjectorConcat(model_config)
221
+ elif model_config.encoder_projector == "cov1d-linear":
222
+ from slam_llm.models.projector import EncoderProjectorCov1d
223
+ encoder_projector = EncoderProjectorCov1d(model_config)
224
+ elif model_config.encoder_projector == "q-former":
225
+ from slam_llm.models.projector import EncoderProjectorQFormer
226
+ encoder_projector = EncoderProjectorQFormer(model_config)
227
+ else:
228
+ return None
229
+ print_module_size(encoder_projector, model_config.encoder_projector, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
230
+ return encoder_projector
231
+
232
+
233
+ class slam_model(nn.Module):
234
+ def __init__(
235
+ self,
236
+ encoder: nn.Module,
237
+ llm: nn.Module,
238
+ encoder_projector: nn.Module,
239
+ tokenizer,
240
+ train_config,
241
+ model_config,
242
+ **kwargs
243
+ ):
244
+ super().__init__()
245
+ # modality encoder
246
+ self.encoder = encoder
247
+
248
+ # llm
249
+ self.llm = llm
250
+
251
+ # projector
252
+ self.encoder_projector = encoder_projector
253
+
254
+ # tokenizer
255
+ self.tokenizer = tokenizer
256
+ self.metric = kwargs.get("metric", "acc")
257
+
258
+ self.train_config = train_config
259
+ self.model_config = model_config
260
+
261
+ if train_config.get("enable_deepspeed", False):
262
+ def new_forward(self, input):
263
+ output = F.layer_norm(
264
+ input.float(),
265
+ self.normalized_shape,
266
+ self.weight.float() if self.weight is not None else None,
267
+ self.bias.float() if self.bias is not None else None,
268
+ self.eps,
269
+ )
270
+ return output.type_as(input)
271
+ for item in self.modules():
272
+ if isinstance(item, nn.LayerNorm):
273
+ item.forward = types.MethodType(new_forward, item)
274
+
275
+
276
+
277
+ def forward(self,
278
+ input_ids: torch.LongTensor = None,
279
+ attention_mask: Optional[torch.Tensor] = None,
280
+ position_ids: Optional[torch.LongTensor] = None,
281
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
282
+ inputs_embeds: Optional[torch.FloatTensor] = None,
283
+ labels: Optional[torch.LongTensor] = None,
284
+ use_cache: Optional[bool] = None,
285
+ output_attentions: Optional[bool] = None,
286
+ output_hidden_states: Optional[bool] = None,
287
+ return_dict: Optional[bool] = None,
288
+ **kwargs,
289
+ ):
290
+ audio_mel = kwargs.get("audio_mel", None)
291
+ audio_mel_mask = kwargs.get("audio_mel_mask", None)
292
+ audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper
293
+
294
+ audio = kwargs.get("audio", None)
295
+ audio_mask = kwargs.get("audio_mask", None)
296
+ visual = kwargs.get("visual", None)
297
+ visual_mask = kwargs.get("visual_mask", None)
298
+
299
+
300
+ # for text encoder
301
+ instruct_ids = kwargs.get("instruct_ids", None)
302
+ instruct_mask = kwargs.get("instruct_mask", None)
303
+
304
+ modality_mask = kwargs.get("modality_mask", None)
305
+
306
+ zh_data = kwargs.get("zh", None)
307
+ en_data = kwargs.get("en", None)
308
+
309
+ encoder_outs = None
310
+ if audio_mel is not None or audio is not None or visual is not None:
311
+ if self.train_config.freeze_encoder: # freeze encoder
312
+ self.encoder.eval()
313
+
314
+ if self.model_config.encoder_name == "whisper":
315
+ encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim
316
+ if self.model_config.encoder_name == "beats":
317
+ encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask) # bs*seq*dim
318
+ if self.model_config.encoder_name == "eat":
319
+ encoder_outs = self.encoder.model.extract_features(audio_mel.unsqueeze(dim=1), padding_mask = None, mask=False, remove_extra_tokens = False)['x']
320
+ if self.model_config.encoder_name == "SpatialAST":
321
+ encoder_outs = self.encoder(audio) # output: [bs, seq_len=3+512, dim=768]
322
+ if self.model_config.encoder_name == "wavlm":
323
+ encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask
324
+ if self.model_config.encoder_name == "hubert":
325
+ results = self.encoder(source = audio, padding_mask = 1-audio_mask)
326
+ if self.model_config.encoder_type == "pretrain":
327
+ encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"]
328
+ if self.model_config.encoder_type == "finetune":
329
+ encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
330
+ encoder_outs = encoder_outs.transpose(0, 1)
331
+ if self.model_config.encoder_name == "av_hubert":
332
+ results = self.encoder(source={'video':visual, 'audio':audio}, padding_mask=visual_mask) # bs*seq*dim
333
+ encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
334
+ encoder_outs = encoder_outs.transpose(0, 1)
335
+ audio_mel_post_mask = (~audio_mel_post_mask).float()
336
+ if self.model_config.encoder_name == 'musicfm':
337
+ encoder_outs = self.encoder.extract_features(audio, padding_mask = None) # MusicFM doesn't support padding mask
338
+ if self.encoder is None:
339
+ encoder_outs = audio_mel if audio_mel is not None else audio
340
+
341
+ if self.model_config.encoder_projector == "q-former":
342
+ encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
343
+ if self.model_config.encoder_projector == "linear":
344
+ encoder_outs = self.encoder_projector(encoder_outs)
345
+ if self.model_config.encoder_projector == "cov1d-linear":
346
+ encoder_outs = self.encoder_projector(encoder_outs)
347
+
348
+ if instruct_ids is not None:
349
+ if self.encoder is not None:
350
+ encoder_outs = self.encoder(input_ids=instruct_ids, attention_mask=instruct_mask).last_hidden_state
351
+
352
+ if self.model_config.encoder_projector == "q-former":
353
+ encoder_outs = self.encoder_projector(encoder_outs, instruct_mask)
354
+ if self.model_config.encoder_projector == "linear":
355
+ encoder_outs = self.encoder_projector(encoder_outs)
356
+
357
+ if input_ids is not None:
358
+ input_ids[input_ids == -1] = 0
359
+ if isinstance(self.llm, T5ForConditionalGeneration):
360
+ inputs_embeds = self.llm.shared(input_ids)
361
+ else:
362
+ if hasattr(self.llm.model, "embed_tokens"):
363
+ inputs_embeds = self.llm.model.embed_tokens(input_ids)
364
+ elif hasattr(self.llm.model.model, "embed_tokens"):
365
+ inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
366
+ else:
367
+ inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
368
+
369
+ if modality_mask is not None:
370
+ modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1)
371
+ modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()
372
+
373
+ encoder_outs_pad = torch.zeros_like(inputs_embeds)
374
+ for i in range(encoder_outs.shape[0]):
375
+ encoder_outs_pad[
376
+ i, modality_mask_start_indices[i]:modality_mask_start_indices[i]+modality_lengths[i]
377
+ ] = encoder_outs[i][:modality_lengths[i]]
378
+
379
+ inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])
380
+
381
+ if kwargs.get("inference_mode", False):
382
+ return inputs_embeds, attention_mask
383
+
384
+ if zh_data is not None and en_data is not None:
385
+ model_outputs, acc = self.llm(zh=zh_data, en=en_data)
386
+ else:
387
+ model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
388
+ acc = -1
389
+ if self.metric:
390
+ with torch.no_grad():
391
+ preds = torch.argmax(model_outputs.logits, -1)
392
+ acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)
393
+
394
+ return model_outputs, acc
395
+
396
+ @torch.no_grad()
397
+ def generate(self,
398
+ input_ids: torch.LongTensor = None,
399
+ attention_mask: Optional[torch.Tensor] = None,
400
+ position_ids: Optional[torch.LongTensor] = None,
401
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
402
+ inputs_embeds: Optional[torch.FloatTensor] = None,
403
+ labels: Optional[torch.LongTensor] = None,
404
+ use_cache: Optional[bool] = None,
405
+ output_attentions: Optional[bool] = None,
406
+ output_hidden_states: Optional[bool] = None,
407
+ return_dict: Optional[bool] = None,
408
+ **kwargs,
409
+ ):
410
+ kwargs["inference_mode"] = True
411
+
412
+ inputs_embeds, attention_mask = self.forward(
413
+ input_ids=input_ids,
414
+ attention_mask=attention_mask,
415
+ position_ids=position_ids,
416
+ past_key_values=past_key_values,
417
+ inputs_embeds=inputs_embeds,
418
+ labels=labels,
419
+ use_cache=use_cache,
420
+ output_attentions=output_attentions,
421
+ output_hidden_states=output_hidden_states,
422
+ return_dict=return_dict,
423
+ **kwargs,
424
+ )
425
+
426
+ model_outputs = self.llm.generate(
427
+ inputs_embeds=inputs_embeds,
428
+ # max_length=kwargs.get("max_length", 200),
429
+ max_new_tokens=kwargs.get("max_new_tokens", 200),
430
+ num_beams=kwargs.get("num_beams", 4),
431
+ do_sample=kwargs.get("do_sample", False),
432
+ min_length=kwargs.get("min_length", 1),
433
+ top_p=kwargs.get("top_p", 1.0),
434
+ repetition_penalty=kwargs.get("repetition_penalty", 1.0),
435
+ length_penalty=kwargs.get("length_penalty", 1.0),
436
+ temperature=kwargs.get("temperature", 1.0),
437
+ attention_mask=attention_mask,
438
+ bos_token_id=self.tokenizer.bos_token_id,
439
+ eos_token_id=self.tokenizer.eos_token_id,
440
+ pad_token_id=self.tokenizer.pad_token_id
441
+ )
442
+
443
+ return model_outputs
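The modality_mask branch in slam_model.forward overwrites the placeholder token embeddings with the projected encoder features at the masked positions while leaving the surrounding prompt embeddings untouched. A simplified, self-contained sketch of that splice (tensor names and shapes are illustrative assumptions):

import torch

def splice_modality(inputs_embeds, encoder_outs, modality_mask):
    # inputs_embeds: (B, T, D) token embeddings containing placeholder positions
    # encoder_outs:  (B, S, D) projected encoder features, S <= T
    # modality_mask: (B, T) bool, True where encoder features should be inserted
    start = modality_mask.float().argmax(dim=1).tolist()                      # first True per row
    lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()
    pad = torch.zeros_like(inputs_embeds)
    for i in range(encoder_outs.shape[0]):
        pad[i, start[i]:start[i] + lengths[i]] = encoder_outs[i, :lengths[i]]
    return pad + inputs_embeds * (~modality_mask[:, :, None])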
slam_llm/models/vallex/__init__.py ADDED
File without changes
slam_llm/models/vallex/activation.py ADDED
@@ -0,0 +1,179 @@
1
+ from typing import Optional, Tuple, List
2
+ import math
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn import functional as F
8
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
9
+
10
+
11
+ class MultiheadAttention(Module):
12
+ __constants__ = ["batch_first"]
13
+ bias_k: Optional[torch.Tensor]
14
+ bias_v: Optional[torch.Tensor]
15
+
16
+ def __init__(
17
+ self,
18
+ embed_dim,
19
+ num_heads,
20
+ dropout=0.0,
21
+ bias=True,
22
+ add_bias_kv=False,
23
+ add_zero_attn=False,
24
+ kdim=None,
25
+ vdim=None,
26
+ batch_first=False,
27
+ linear1_cls=Linear,
28
+ linear2_cls=Linear,
29
+ device=None,
30
+ dtype=None,
31
+ ) -> None:
32
+ factory_kwargs = {"device": device, "dtype": dtype}
33
+ super(MultiheadAttention, self).__init__()
34
+ self.embed_dim = embed_dim
35
+ self.kdim = kdim if kdim is not None else embed_dim
36
+ self.vdim = vdim if vdim is not None else embed_dim
37
+ self._qkv_same_embed_dim = False
38
+
39
+ self.num_heads = num_heads
40
+ self.dropout = dropout
41
+ self.batch_first = batch_first
42
+ self.head_dim = embed_dim // num_heads
43
+ self.num_heads = num_heads
44
+ assert (
45
+ self.head_dim * num_heads == self.embed_dim
46
+ ), "embed_dim must be divisible by num_heads"
47
+
48
+ self.k_proj = Linear(self.kdim, embed_dim)
49
+ self.v_proj = Linear(self.kdim, embed_dim)
50
+ self.q_proj = Linear(self.kdim, embed_dim)
51
+
52
+ self.out_proj = NonDynamicallyQuantizableLinear(
53
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
54
+ )
55
+
56
+ self.add_zero_attn = add_zero_attn
57
+ self.scaling = self.head_dim**-0.5
58
+
59
+ def __setstate__(self, state):
60
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
61
+ if "_qkv_same_embed_dim" not in state:
62
+ state["_qkv_same_embed_dim"] = True
63
+
64
+ super(MultiheadAttention, self).__setstate__(state)
65
+
66
+ def forward(
67
+ self,
68
+ query: Tensor,
69
+ key: Tensor,
70
+ value: Tensor,
71
+ key_padding_mask: Optional[Tensor] = None,
72
+ need_weights: bool = True,
73
+ attn_mask: Optional[Tensor] = None,
74
+ average_attn_weights: bool = True,
75
+ ) -> Tuple[Tensor, Optional[Tensor]]:
76
+
77
+ # T,B,C
78
+ B, T, C = query.size()
79
+
80
+ q = self.q_proj(query)
81
+ k = self.k_proj(key)
82
+ v = self.v_proj(value)
83
+ q *= self.scaling
84
+
85
+ k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
86
+ q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
87
+ v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
88
+
89
+ attn_weights = q @ k.transpose(-2, -1) # B, nh, T, T
90
+
91
+ if attn_mask is not None:
92
+ # attn_mask is inf
93
+ # attn_mask = attn_mask.unsqueeze(0)
94
+ # attn_weights += attn_mask
95
+ if torch.is_floating_point(attn_mask):
96
+ # print(attn_weights.size(), attn_mask.size())
97
+ attn_weights += attn_mask.unsqueeze(0).unsqueeze(1)
98
+ else:
99
+ attn_weights = attn_weights.masked_fill(attn_mask, float('-inf'))
100
+
101
+ if key_padding_mask is not None:
102
+ # don't attend to padding symbols
103
+ attn_weights = attn_weights.view(B, self.num_heads, T, T)
104
+ attn_weights = attn_weights.masked_fill(
105
+ key_padding_mask.unsqueeze(1)
106
+ .unsqueeze(2)
107
+ .to(torch.bool),
108
+ float("-inf"),
109
+ )
110
+ attn_weights_float = F.softmax(attn_weights, dim=-1)
111
+ attn = attn_weights_float @ v
112
+
113
+ y = attn.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
114
+ y = self.out_proj(y)
115
+ return y, attn_weights
116
+
117
+ def infer(self,
118
+ x: Tensor,
119
+ key_padding_mask: Optional[Tensor] = None,
120
+ need_weights: bool = True,
121
+ attn_mask: Optional[Tensor] = None,
122
+ average_attn_weights: bool = True,
123
+ past_kv = None,
124
+ use_cache = False):
125
+
126
+ # print("debug:"+str(x.size()))
127
+
128
+ B, T, C = x.size()
129
+
130
+ q = self.q_proj(x)
131
+ k = self.k_proj(x)
132
+ v = self.v_proj(x)
133
+ q *= self.scaling
134
+
135
+ # k = k.view(T, B*self.num_heads, self.head_dim).transpose(0, 1) # (B, nh, T, hs)
136
+ # q = q.view(T, B*self.num_heads, self.head_dim).transpose(0, 1) # (B, nh, T, hs)
137
+ # v = v.view(T, B*self.num_heads, self.head_dim).transpose(0, 1) # (B, nh, T, hs)
138
+
139
+ k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
140
+ q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
141
+ v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
142
+
143
+ if past_kv is not None:
144
+ past_key = past_kv[0]
145
+ past_value = past_kv[1]
146
+ k = torch.cat((past_key, k), dim=-2)
147
+ v = torch.cat((past_value, v), dim=-2)
148
+
149
+ FULL_T = k.shape[-2]
150
+
151
+ if use_cache is True:
152
+ present = (k, v)
153
+ else:
154
+ present = None
155
+
156
+ # print(q.size(), k.size())
157
+ attn_weights = q @ k.transpose(-2, -1)
158
+ # print(attn_mask.size())
159
+ attn_weights = attn_weights.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf'))
160
+
161
+ # if key_padding_mask is not None:
162
+ # # don't attend to padding symbols
163
+ # attn_weights = attn_weights.view(B, self.num_heads, T, T)
164
+ # attn_weights = attn_weights.view(B, -1, self.num_heads, T, T)
165
+ # attn_weights = attn_weights.masked_fill(
166
+ # key_padding_mask.unsqueeze(1)
167
+ # .unsqueeze(2)
168
+ # .unsqueeze(3)
169
+ # .to(torch.bool),
170
+ # float("-inf"),
171
+ # )
172
+ attn_weights_float = F.softmax(attn_weights, dim=-1, )
173
+ # attn_weights = attn_weights_float.type_as(attn_weights)
174
+ # attn = torch.bmm(attn_weights, v)
175
+ attn = attn_weights_float @ v
176
+
177
+ y = attn.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
178
+ y = self.out_proj(y)
179
+ return (y, present)
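MultiheadAttention.infer returns (output, present) so that autoregressive decoding can append one step at a time instead of re-running attention over the full prefix. A hedged usage sketch of that loop; the causal-mask construction and the toy sizes are assumptions made for illustration:

import torch
from slam_llm.models.vallex.activation import MultiheadAttention

attn = MultiheadAttention(embed_dim=64, num_heads=4)
causal_mask = torch.ones(64, 64, dtype=torch.bool).triu(1)   # True above the diagonal = blocked

past_kv = None
x = torch.randn(1, 1, 64)                                    # one new token per step: (B, 1, C)
for _ in range(3):
    y, past_kv = attn.infer(x, attn_mask=causal_mask, past_kv=past_kv, use_cache=True)
    x = torch.randn(1, 1, 64)                                 # placeholder for the next step's input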
slam_llm/models/vallex/scaling.py ADDED
@@ -0,0 +1,1404 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import collections
19
+ import logging
20
+ import random
21
+ import math
22
+ from functools import reduce
23
+ from itertools import repeat
24
+ from typing import Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ from torch import Tensor
30
+ from torch.nn import Embedding as ScaledEmbedding
31
+
32
+ class Transpose(nn.Identity):
33
+ """(N, T, D) -> (N, D, T)"""
34
+
35
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
36
+ return input.transpose(1, 2)
37
+
38
+ class ActivationBalancerFunction(torch.autograd.Function):
39
+ @staticmethod
40
+ def forward(
41
+ ctx,
42
+ x: Tensor,
43
+ scale_factor: Tensor,
44
+ sign_factor: Optional[Tensor],
45
+ channel_dim: int,
46
+ ) -> Tensor:
47
+ if channel_dim < 0:
48
+ channel_dim += x.ndim
49
+ ctx.channel_dim = channel_dim
50
+ xgt0 = x > 0
51
+ if sign_factor is None:
52
+ ctx.save_for_backward(xgt0, scale_factor)
53
+ else:
54
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
55
+ return x
56
+
57
+ @staticmethod
58
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
59
+ if len(ctx.saved_tensors) == 3:
60
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
61
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
62
+ scale_factor = scale_factor.unsqueeze(-1)
63
+ sign_factor = sign_factor.unsqueeze(-1)
64
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
65
+ else:
66
+ xgt0, scale_factor = ctx.saved_tensors
67
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
68
+ scale_factor = scale_factor.unsqueeze(-1)
69
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
70
+ neg_delta_grad = x_grad.abs() * factor
71
+ return (
72
+ x_grad - neg_delta_grad,
73
+ None,
74
+ None,
75
+ None,
76
+ )
77
+
78
+
79
+ def _compute_scale_factor(
80
+ x: Tensor,
81
+ channel_dim: int,
82
+ min_abs: float,
83
+ max_abs: float,
84
+ gain_factor: float,
85
+ max_factor: float,
86
+ ) -> Tensor:
87
+ if channel_dim < 0:
88
+ channel_dim += x.ndim
89
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
90
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
91
+
92
+ if min_abs == 0.0:
93
+ below_threshold = 0.0
94
+ else:
95
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
96
+ # x_abs_mean < min_abs.
97
+ below_threshold = (
98
+ (min_abs - x_abs_mean) * (gain_factor / min_abs)
99
+ ).clamp(min=0, max=max_factor)
100
+
101
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
102
+ min=0, max=max_factor
103
+ )
104
+
105
+ return below_threshold - above_threshold
106
+
107
+
108
+ def _compute_sign_factor(
109
+ x: Tensor,
110
+ channel_dim: int,
111
+ min_positive: float,
112
+ max_positive: float,
113
+ gain_factor: float,
114
+ max_factor: float,
115
+ ) -> Tensor:
116
+ if channel_dim < 0:
117
+ channel_dim += x.ndim
118
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
119
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
120
+ if min_positive == 0.0:
121
+ factor1 = 0.0
122
+ else:
123
+ # 0 if proportion_positive >= min_positive, else can be
124
+ # as large as max_factor.
125
+ factor1 = (
126
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
127
+ ).clamp_(min=0, max=max_factor)
128
+
129
+ if max_positive == 1.0:
130
+ factor2 = 0.0
131
+ else:
132
+ # 0 if self.proportion_positive <= max_positive, else can be
133
+ # as large as -max_factor.
134
+ factor2 = (
135
+ (proportion_positive - max_positive)
136
+ * (gain_factor / (1.0 - max_positive))
137
+ ).clamp_(min=0, max=max_factor)
138
+ sign_factor = factor1 - factor2
139
+ # require min_positive != 0 or max_positive != 1:
140
+ assert not isinstance(sign_factor, float)
141
+ return sign_factor
142
+
143
+
144
+ class ActivationScaleBalancerFunction(torch.autograd.Function):
145
+ """
146
+ This object is used in class ActivationBalancer when the user specified
147
+ min_positive=0, max_positive=1, so there are no constraints on the signs
148
+ of the activations and only the absolute value has a constraint.
149
+ """
150
+
151
+ @staticmethod
152
+ def forward(
153
+ ctx,
154
+ x: Tensor,
155
+ sign_factor: Tensor,
156
+ scale_factor: Tensor,
157
+ channel_dim: int,
158
+ ) -> Tensor:
159
+ if channel_dim < 0:
160
+ channel_dim += x.ndim
161
+ ctx.channel_dim = channel_dim
162
+ xgt0 = x > 0
163
+ ctx.save_for_backward(xgt0, sign_factor, scale_factor)
164
+ return x
165
+
166
+ @staticmethod
167
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
168
+ xgt0, sign_factor, scale_factor = ctx.saved_tensors
169
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
170
+ sign_factor = sign_factor.unsqueeze(-1)
171
+ scale_factor = scale_factor.unsqueeze(-1)
172
+
173
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
174
+ neg_delta_grad = x_grad.abs() * factor
175
+ return (
176
+ x_grad - neg_delta_grad,
177
+ None,
178
+ None,
179
+ None,
180
+ )
181
+
182
+
183
+ class RandomClampFunction(torch.autograd.Function):
184
+ @staticmethod
185
+ def forward(
186
+ ctx,
187
+ x: Tensor,
188
+ min: Optional[float],
189
+ max: Optional[float],
190
+ prob: float,
191
+ reflect: float,
192
+ ) -> Tensor:
193
+ x_clamped = torch.clamp(x, min=min, max=max)
194
+ mask = torch.rand_like(x) < prob
195
+ ans = torch.where(mask, x_clamped, x)
196
+ if x.requires_grad:
197
+ ctx.save_for_backward(ans == x)
198
+ ctx.reflect = reflect
199
+ if reflect != 0.0:
200
+ ans = ans * (1.0 + reflect) - (x * reflect)
201
+ return ans
202
+
203
+ @staticmethod
204
+ def backward(
205
+ ctx, ans_grad: Tensor
206
+ ) -> Tuple[Tensor, None, None, None, None]:
207
+ (is_same,) = ctx.saved_tensors
208
+ x_grad = ans_grad * is_same.to(ans_grad.dtype)
209
+ reflect = ctx.reflect
210
+ if reflect != 0.0:
211
+ x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
212
+ return x_grad, None, None, None, None
213
+
214
+
215
+ def random_clamp(
216
+ x: Tensor,
217
+ min: Optional[float] = None,
218
+ max: Optional[float] = None,
219
+ prob: float = 0.5,
220
+ reflect: float = 0.0,
221
+ ):
222
+ return RandomClampFunction.apply(x, min, max, prob, reflect)
223
+
224
+
225
+ def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
226
+ """
227
+ A randomized way of casting a floating point value to half precision.
228
+ """
229
+ if x.dtype == torch.float16:
230
+ return x
231
+ x_abs = x.abs()
232
+ is_too_small = x_abs < min_abs
233
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
234
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
235
+ # for those elements].
236
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
237
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
238
+
239
+
240
+ class RandomGradFunction(torch.autograd.Function):
241
+ """
242
+ Does nothing in forward pass; in backward pass, gets rid of very small grads using
243
+ randomized approach that preserves expectations (intended to reduce roundoff).
244
+ """
245
+
246
+ @staticmethod
247
+ def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
248
+ ctx.min_abs = min_abs
249
+ return x
250
+
251
+ @staticmethod
252
+ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
253
+ if ans_grad.dtype == torch.float16:
254
+ return (
255
+ random_cast_to_half(
256
+ ans_grad.to(torch.float32), min_abs=ctx.min_abs
257
+ ),
258
+ None,
259
+ )
260
+ else:
261
+ return ans_grad, None
262
+
263
+
264
+ class RandomGrad(torch.nn.Module):
265
+ """
266
+ Gets rid of very small gradients using an expectation-preserving method, intended to increase
267
+ accuracy of training when using amp (automatic mixed precision)
268
+ """
269
+
270
+ def __init__(self, min_abs: float = 5.0e-06):
271
+ super(RandomGrad, self).__init__()
272
+ self.min_abs = min_abs
273
+
274
+ def forward(self, x: Tensor):
275
+ if (
276
+ torch.jit.is_scripting()
277
+ or not self.training
278
+ or torch.jit.is_tracing()
279
+ ):
280
+ return x
281
+ else:
282
+ return RandomGradFunction.apply(x, self.min_abs)
283
+
284
+
285
+ class SoftmaxFunction(torch.autograd.Function):
286
+ """
287
+ Tries to handle half-precision derivatives in a randomized way that should
288
+ be more accurate for training than the default behavior.
289
+ """
290
+
291
+ @staticmethod
292
+ def forward(ctx, x: Tensor, dim: int):
293
+ ans = x.softmax(dim=dim)
294
+ # if x dtype is float16, x.softmax() returns a float32 because
295
+ # (presumably) that op does not support float16, and autocast
296
+ # is enabled.
297
+ if torch.is_autocast_enabled():
298
+ ans = ans.to(torch.float16)
299
+ ctx.save_for_backward(ans)
300
+ ctx.x_dtype = x.dtype
301
+ ctx.dim = dim
302
+ return ans
303
+
304
+ @staticmethod
305
+ def backward(ctx, ans_grad: Tensor):
306
+ (ans,) = ctx.saved_tensors
307
+ with torch.cuda.amp.autocast(enabled=False):
308
+ ans_grad = ans_grad.to(torch.float32)
309
+ ans = ans.to(torch.float32)
310
+ x_grad = ans_grad * ans
311
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
312
+ return x_grad, None
313
+
314
+
315
+ def softmax(x: Tensor, dim: int):
316
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
317
+ return x.softmax(dim)
318
+
319
+ return SoftmaxFunction.apply(x, dim)
320
+
321
+
322
+ class MaxEigLimiterFunction(torch.autograd.Function):
323
+ @staticmethod
324
+ def forward(
325
+ ctx,
326
+ x: Tensor,
327
+ coeffs: Tensor,
328
+ direction: Tensor,
329
+ channel_dim: int,
330
+ grad_scale: float,
331
+ ) -> Tensor:
332
+ ctx.channel_dim = channel_dim
333
+ ctx.grad_scale = grad_scale
334
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
335
+ return x
336
+
337
+ @staticmethod
338
+ def backward(ctx, x_grad, *args):
339
+ with torch.enable_grad():
340
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
341
+ x_orig.requires_grad = True
342
+ num_channels = x_orig.shape[ctx.channel_dim]
343
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
344
+ new_direction.requires_grad = False
345
+ x = x - x.mean(dim=0)
346
+ x_var = (x ** 2).mean()
347
+ x_residual = x - coeffs * new_direction
348
+ x_residual_var = (x_residual ** 2).mean()
349
+ # `variance_proportion` is the proportion of the variance accounted for
350
+ # by the top eigen-direction. This is to be minimized.
351
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
352
+ variance_proportion.backward()
353
+ x_orig_grad = x_orig.grad
354
+ x_extra_grad = (
355
+ x_orig.grad
356
+ * ctx.grad_scale
357
+ * x_grad.norm()
358
+ / (x_orig_grad.norm() + 1.0e-20)
359
+ )
360
+ return x_grad + x_extra_grad.detach(), None, None, None, None
361
+
362
+
363
+ class BasicNorm(torch.nn.Module):
364
+ """
365
+ This is intended to be a simpler, and hopefully cheaper, replacement for
366
+ LayerNorm. The observation this is based on, is that Transformer-type
367
+ networks, especially with pre-norm, sometimes seem to set one of the
368
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
369
+ the LayerNorm because the output magnitude is then not strongly dependent
370
+ on the other (useful) features. Presumably the weight and bias of the
371
+ LayerNorm are required to allow it to do this.
372
+
373
+ So the idea is to introduce this large constant value as an explicit
374
+ parameter, that takes the role of the "eps" in LayerNorm, so the network
375
+ doesn't have to do this trick. We make the "eps" learnable.
376
+
377
+ Args:
378
+ num_channels: the number of channels, e.g. 512.
379
+ channel_dim: the axis/dimension corresponding to the channel,
380
+ interpreted as an offset from the input's ndim if negative.
381
+ This is NOT the num_channels; it should typically be one of
382
+ {-2, -1, 0, 1, 2, 3}.
383
+ eps: the initial "epsilon" that we add as ballast in:
384
+ scale = ((input_vec**2).mean() + epsilon)**-0.5
385
+ Note: our epsilon is actually large, but we keep the name
386
+ to indicate the connection with conventional LayerNorm.
387
+ learn_eps: if true, we learn epsilon; if false, we keep it
388
+ at the initial value.
389
+ eps_min: float
390
+ eps_max: float
391
+ """
392
+
393
+ def __init__(
394
+ self,
395
+ num_channels: int,
396
+ channel_dim: int = -1, # CAUTION: see documentation.
397
+ eps: float = 0.25,
398
+ learn_eps: bool = True,
399
+ eps_min: float = -3.0,
400
+ eps_max: float = 3.0,
401
+ ) -> None:
402
+ super(BasicNorm, self).__init__()
403
+ self.num_channels = num_channels
404
+ self.channel_dim = channel_dim
405
+ if learn_eps:
406
+ self.eps = nn.Parameter(torch.tensor(eps).log().detach())
407
+ else:
408
+ self.register_buffer("eps", torch.tensor(eps).log().detach())
409
+ self.eps_min = eps_min
410
+ self.eps_max = eps_max
411
+
412
+ def forward(self, x: Tensor) -> Tensor:
413
+ assert x.shape[self.channel_dim] == self.num_channels
414
+ eps = self.eps
415
+ if self.training and random.random() < 0.25:
416
+ # with probability 0.25, in training mode, clamp eps between the min
417
+ # and max; this will encourage it to learn parameters within the
418
+ # allowed range by making parameters that are outside the allowed
419
+ # range noisy.
420
+
421
+ # The clamp is applied only part of the time, so the parameter still gets unclamped
422
+ # gradients on other steps, which let it get back into the allowed region if it exits it.
423
+ eps = eps.clamp(min=self.eps_min, max=self.eps_max)
424
+ scales = (
425
+ torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
426
+ ) ** -0.5
427
+ return x * scales
428
+
429
+
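A minimal standalone sketch (not part of the diff) of what BasicNorm.forward computes, assuming the default channel_dim=-1 and an arbitrary (batch, num_channels) input; shapes here are illustrative:

    import torch

    x = torch.randn(4, 512)                      # (batch, num_channels), channel_dim=-1
    log_eps = torch.tensor(0.25).log()           # eps is stored as log(eps), so exp() keeps it positive
    scales = (x.pow(2).mean(dim=-1, keepdim=True) + log_eps.exp()) ** -0.5
    y = x * scales                               # roughly what BasicNorm(512)(x) returns at initialization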
430
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
431
+ """
432
+ Behaves like a constructor of a modified version of nn.Linear
433
+ that gives an easy way to set the default initial parameter scale.
434
+
435
+ Args:
436
+ Accepts the standard args and kwargs that nn.Linear accepts
437
+ e.g. in_features, out_features, bias=False.
438
+
439
+ initial_scale: you can override this if you want to increase
440
+ or decrease the initial magnitude of the module's output
441
+ (affects the initialization of weight_scale and bias_scale).
442
+ Another option, if you want to do something like this, is
443
+ to re-initialize the parameters.
444
+ """
445
+ ans = nn.Linear(*args, **kwargs)
446
+ with torch.no_grad():
447
+ ans.weight[:] *= initial_scale
448
+ if ans.bias is not None:
449
+ torch.nn.init.uniform_(
450
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
451
+ )
452
+ return ans
453
+
454
+
455
+ def ScaledConv1d(
456
+ *args,
457
+ initial_scale: float = 1.0,
458
+ kernel_size: int = 3,
459
+ padding: str = "same",
460
+ **kwargs,
461
+ ) -> nn.Conv1d:
462
+ """
463
+ Behaves like a constructor of a modified version of nn.Conv1d
464
+ that gives an easy way to set the default initial parameter scale.
465
+
466
+ Args:
467
+ Accepts the standard args and kwargs that nn.Conv1d accepts
468
+ e.g. in_channels, out_channels, bias=False.
469
+
470
+ initial_scale: you can override this if you want to increase
471
+ or decrease the initial magnitude of the module's output
472
+ (affects the initialization of weight_scale and bias_scale).
473
+ Another option, if you want to do something like this, is
474
+ to re-initialize the parameters.
475
+ """
476
+ ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
477
+ with torch.no_grad():
478
+ ans.weight[:] *= initial_scale
479
+ if ans.bias is not None:
480
+ torch.nn.init.uniform_(
481
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
482
+ )
483
+ return ans
484
+
485
+
486
+ def TransposeScaledConv1d(
487
+ *args,
488
+ initial_scale: float = 1.0,
489
+ kernel_size: int = 3,
490
+ padding: str = "same",
491
+ **kwargs,
492
+ ) -> nn.Sequential:
493
+ """
494
+ Transpose -> ScaledConv1d
495
+ """
496
+ return nn.Sequential(
497
+ Transpose(),
498
+ ScaledConv1d(
499
+ *args,
500
+ initial_scale=initial_scale,
501
+ kernel_size=kernel_size,
502
+ padding=padding,
503
+ **kwargs,
504
+ ),
505
+ )
506
+
507
+
508
+ def ScaledConv1dTranspose(
509
+ *args,
510
+ initial_scale: float = 1.0,
511
+ kernel_size: int = 3,
512
+ padding: str = "same",
513
+ **kwargs,
514
+ ) -> nn.Sequential:
515
+ """
516
+ ScaledConv1d -> Transpose
517
+ """
518
+ return nn.Sequential(
519
+ ScaledConv1d(
520
+ *args,
521
+ initial_scale=initial_scale,
522
+ kernel_size=kernel_size,
523
+ padding=padding,
524
+ **kwargs,
525
+ ),
526
+ Transpose(),
527
+ )
528
+
529
+
530
+ def TransposeConv1d(
531
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
532
+ ) -> nn.Sequential:
533
+ """
534
+ Transpose -> Conv1d
535
+ """
536
+ return nn.Sequential(
537
+ Transpose(),
538
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
539
+ )
540
+
541
+
542
+ def Conv1dTranspose(
543
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
544
+ ) -> nn.Sequential:
545
+ """
546
+ Conv1d -> Transpose
547
+ """
548
+ return nn.Sequential(
549
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
550
+ Transpose(),
551
+ )
552
+
553
+
554
+ class SRLinear(nn.Linear):
555
+ """https://arxiv.org/abs/2303.06296
556
+ Stabilizing Transformer Training by Preventing Attention Entropy Collapse
557
+ """
558
+
559
+ def __init__(self, in_features, out_features, bias=True, **kwargs):
560
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
561
+ self.register_buffer(
562
+ "u", nn.functional.normalize(torch.randn(in_features), dim=0)
563
+ )
564
+ with torch.no_grad():
565
+ sigma = self.get_sigma()
566
+ self.register_buffer("spectral_norm", sigma)
567
+ self.sigma = nn.Parameter(torch.ones(1))
568
+
569
+ def get_sigma(self):
570
+ with torch.no_grad():
571
+ u = self.u
572
+ v = self.weight.mv(u)
573
+ v = nn.functional.normalize(v, dim=0)
574
+ u = self.weight.T.mv(v)
575
+ u = nn.functional.normalize(u, dim=0)
576
+ self.u.data.copy_(u)
577
+ return torch.einsum("c,cd,d->", v, self.weight, u)
578
+
579
+ def get_weight(self):
580
+ sigma = self.get_sigma()
581
+ if self.training:
582
+ self.spectral_norm.data.copy_(sigma)
583
+ weight = (self.sigma / sigma) * self.weight
584
+ return weight
585
+
586
+ def forward(self, x):
587
+ return nn.functional.linear(x, self.get_weight(), self.bias)
588
+
589
+
590
+ class SRConv1d(SRLinear):
591
+ def __init__(
592
+ self,
593
+ in_features,
594
+ out_features,
595
+ kernel_size,
596
+ stride: int = 1,
597
+ padding: str = "same",
598
+ bias: bool = True,
599
+ **kwargs,
600
+ ):
601
+ in_features = in_features * kernel_size
602
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
603
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
604
+ self.kernel_size = kernel_size
605
+ self.stride = stride
606
+ self.padding = padding
607
+
608
+ def forward(self, x):
609
+ in_features = self.in_features // self.kernel_size
610
+ weight = self.get_weight().view(
611
+ self.out_features, in_features, self.kernel_size
612
+ )
613
+ return nn.functional.conv1d(
614
+ x, weight, bias=self.bias, stride=self.stride, padding=self.padding
615
+ )
616
+
617
+
618
+ def TransposeSRConv1d(
619
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
620
+ ) -> nn.Sequential:
621
+ """
622
+ Transpose -> SRConv1d
623
+ """
624
+ return nn.Sequential(
625
+ Transpose(),
626
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
627
+ )
628
+
629
+
630
+ def SRConv1dTranspose(
631
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
632
+ ) -> nn.Sequential:
633
+ """
634
+ SRConv1d -> Transpose
635
+ """
636
+ return nn.Sequential(
637
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
638
+ Transpose(),
639
+ )
640
+
641
+
642
+ class ActivationBalancer(torch.nn.Module):
643
+ """
644
+ Modifies the backpropped derivatives of a function to try to encourage, for
645
+ each channel, that it is positive at least a proportion `threshold` of the
646
+ time. It does this by multiplying negative derivative values by up to
647
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
648
+ interpolated from 1 at the threshold to those extremal values when none
649
+ of the inputs are positive.
650
+
651
+ Args:
652
+ num_channels: the number of channels
653
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
654
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
655
+ min_positive: the minimum, per channel, of the proportion of the time
656
+ that (x > 0), below which we start to modify the derivatives.
657
+ max_positive: the maximum, per channel, of the proportion of the time
658
+ that (x > 0), above which we start to modify the derivatives.
659
+ max_factor: the maximum factor by which we modify the derivatives for
660
+ either the sign constraint or the magnitude constraint;
661
+ e.g. with max_factor=0.02, the derivatives would be multiplied by
662
+ values in the range [0.98..1.02].
663
+ sign_gain_factor: determines the 'gain' with which we increase the
664
+ change in gradient once the constraints on min_positive and max_positive
665
+ are violated.
666
+ scale_gain_factor: determines the 'gain' with which we increase the
667
+ change in gradient once the constraints on min_abs and max_abs
668
+ are violated.
669
+ min_abs: the minimum average-absolute-value difference from the mean
670
+ value per channel, which we allow, before we start to modify
671
+ the derivatives to prevent this.
672
+ max_abs: the maximum average-absolute-value difference from the mean
673
+ value per channel, which we allow, before we start to modify
674
+ the derivatives to prevent this.
675
+ min_prob: determines the minimum probability with which we modify the
676
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
677
+ on each forward(). This is done randomly to prevent all layers
678
+ from doing it at the same time. Early in training we may use
679
+ higher probabilities than this; it will decay to this value.
680
+ """
681
+
682
+ def __init__(
683
+ self,
684
+ num_channels: int,
685
+ channel_dim: int,
686
+ min_positive: float = 0.05,
687
+ max_positive: float = 0.95,
688
+ max_factor: float = 0.04,
689
+ sign_gain_factor: float = 0.01,
690
+ scale_gain_factor: float = 0.02,
691
+ min_abs: float = 0.2,
692
+ max_abs: float = 100.0,
693
+ min_prob: float = 0.1,
694
+ ):
695
+ super(ActivationBalancer, self).__init__()
696
+ self.num_channels = num_channels
697
+ self.channel_dim = channel_dim
698
+ self.min_positive = min_positive
699
+ self.max_positive = max_positive
700
+ self.max_factor = max_factor
701
+ self.min_abs = min_abs
702
+ self.max_abs = max_abs
703
+ self.min_prob = min_prob
704
+ self.sign_gain_factor = sign_gain_factor
705
+ self.scale_gain_factor = scale_gain_factor
706
+
707
+ # count measures how many times the forward() function has been called.
708
+ # We occasionally sync this to a tensor called `count`, that exists to
709
+ # make sure it is synced to disk when we load and save the model.
710
+ self.cpu_count = 0
711
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
712
+
713
+ def forward(self, x: Tensor) -> Tensor:
714
+ if (
715
+ torch.jit.is_scripting()
716
+ or not x.requires_grad
717
+ or torch.jit.is_tracing()
718
+ ):
719
+ return _no_op(x)
720
+
721
+ count = self.cpu_count
722
+ self.cpu_count += 1
723
+
724
+ if random.random() < 0.01:
725
+ # Occasionally sync self.cpu_count with self.count.
726
+ # count affects the decay of 'prob'. don't do this on every iter,
727
+ # because syncing with the GPU is slow.
728
+ self.cpu_count = max(self.cpu_count, self.count.item())
729
+ self.count.fill_(self.cpu_count)
730
+
731
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
732
+ # a floor at min_prob (==0.1, by default)
733
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
734
+
735
+ if random.random() < prob:
736
+ sign_gain_factor = 0.5
737
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
738
+ sign_factor = _compute_sign_factor(
739
+ x,
740
+ self.channel_dim,
741
+ self.min_positive,
742
+ self.max_positive,
743
+ gain_factor=self.sign_gain_factor / prob,
744
+ max_factor=self.max_factor,
745
+ )
746
+ else:
747
+ sign_factor = None
748
+
749
+ scale_factor = _compute_scale_factor(
750
+ x.detach(),
751
+ self.channel_dim,
752
+ min_abs=self.min_abs,
753
+ max_abs=self.max_abs,
754
+ gain_factor=self.scale_gain_factor / prob,
755
+ max_factor=self.max_factor,
756
+ )
757
+ return ActivationBalancerFunction.apply(
758
+ x,
759
+ scale_factor,
760
+ sign_factor,
761
+ self.channel_dim,
762
+ )
763
+ else:
764
+ return _no_op(x)
765
+
766
+
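Since ActivationBalancer is an identity in the forward pass and only nudges gradients, it is dropped in as a pass-through layer. A hedged usage sketch, assuming this scaling module and the helper functions it references (defined earlier in the file) are importable:

    import torch
    import torch.nn as nn

    block = nn.Sequential(
        nn.Linear(256, 256),
        ActivationBalancer(256, channel_dim=-1, min_positive=0.05, max_positive=0.95),
        DoubleSwish(),                        # defined further down in this file
    )
    y = block(torch.randn(8, 256))            # forward values are unchanged by the balancer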
767
+ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
768
+ """
769
+ Returns x unmodified, but in backprop will put a penalty for the excess of
770
+ the absolute values of elements of x over the limit "limit". E.g. if
771
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
772
+
773
+ Caution: the value of this penalty will be affected by grad scaling used
774
+ in automatic mixed precision training. For the purposes we use this for,
775
+ that shouldn't really matter, or may even be helpful; we just use it
776
+ to disallow really implausible score values from being given to softmax.
777
+ """
778
+ x_sign = x.sign()
779
+ over_limit = (x.abs() - limit) > 0
780
+ # The following is a memory efficient way to penalize the absolute values of
781
+ # x that's over the limit. (The memory efficiency comes when you think
782
+ # about which items torch needs to cache for the autograd, and which ones it
783
+ # can throw away). The numerical value of aux_loss as computed here will
784
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
785
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
786
+ # limit).relu().
787
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
788
+ # note: we don't do sum() here on aux_loss, but it's as if we had done
789
+ # sum() due to how with_loss() works.
790
+ x = with_loss(x, aux_loss)
791
+ # you must use x for something, or this will be ineffective.
792
+ return x
793
+
794
+
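The extra gradient injected by penalize_abs_values_gt is, up to AMP grad scaling, the derivative of penalty * (x.abs() - limit).relu(). A small reference check of that target derivative using plain autograd (it does not call the function above):

    import torch

    x = torch.tensor([0.5, 12.0, -15.0], requires_grad=True)
    aux = 2.0 * (x.abs() - 10.0).relu().sum()   # limit=10.0, penalty=2.0
    aux.backward()
    print(x.grad)                               # tensor([ 0.,  2., -2.]): only out-of-range entries are pushed back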
795
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
796
+ if x.ndim == 2:
797
+ return x.diag()
798
+ else:
799
+ (batch, dim, dim) = x.shape
800
+ x = x.reshape(batch, dim * dim)
801
+ x = x[:, :: dim + 1]
802
+ assert x.shape == (batch, dim)
803
+ return x
804
+
805
+
806
+ def _whitening_metric(x: Tensor, num_groups: int):
807
+ """
808
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
809
+ the centered feature covariance are the same within each group's covariance matrix
810
+ and also between groups.
811
+ Args:
812
+ x: a Tensor of shape (*, num_channels)
813
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
814
+ Returns:
815
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
816
+ greater than 1.0 otherwise.
817
+ """
818
+ assert x.dtype != torch.float16
819
+ x = x.reshape(-1, x.shape[-1])
820
+ (num_frames, num_channels) = x.shape
821
+ assert num_channels % num_groups == 0
822
+ channels_per_group = num_channels // num_groups
823
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
824
+ # x now has shape (num_groups, num_frames, channels_per_group)
825
+ # subtract the mean so we use the centered, not uncentered, covariance.
826
+ # My experience has been that when we "mess with the gradients" like this,
827
+ # it's better not do anything that tries to move the mean around, because
828
+ # that can easily cause instability.
829
+ x = x - x.mean(dim=1, keepdim=True)
830
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
831
+ x_covar = torch.matmul(x.transpose(1, 2), x)
832
+ x_covar_mean_diag = _diag(x_covar).mean()
833
+ # the following expression is what we'd get if we took the matrix product
834
+ # of each covariance and measured the mean of its trace, i.e.
835
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
836
+ x_covarsq_mean_diag = (x_covar ** 2).sum() / (
837
+ num_groups * channels_per_group
838
+ )
839
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
840
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
841
+ return metric
842
+
843
+
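In eigenvalue terms, for a single group the metric above equals mean(lambda_i^2) / mean(lambda_i)^2 over the eigenvalues of the centered covariance, which is >= 1 with equality exactly when the data is "white". A quick numerical sketch of the same formula, using nothing beyond torch:

    import torch

    x = torch.randn(1000, 8)                 # roughly white features
    x = x - x.mean(dim=0)
    cov = x.t() @ x                          # centered covariance, up to a constant factor
    metric = (cov ** 2).sum() / 8 / (cov.diag().mean() ** 2)
    print(metric)                            # close to 1.0; grows as one direction starts to dominate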
844
+ class WhiteningPenaltyFunction(torch.autograd.Function):
845
+ @staticmethod
846
+ def forward(
847
+ ctx,
848
+ x: Tensor,
849
+ num_groups: int,
850
+ whitening_limit: float,
851
+ grad_scale: float,
852
+ ) -> Tensor:
853
+ ctx.save_for_backward(x)
854
+ ctx.num_groups = num_groups
855
+ ctx.whitening_limit = whitening_limit
856
+ ctx.grad_scale = grad_scale
857
+ return x
858
+
859
+ @staticmethod
860
+ def backward(ctx, x_grad: Tensor):
861
+ (x_orig,) = ctx.saved_tensors
862
+ with torch.enable_grad():
863
+ with torch.cuda.amp.autocast(enabled=False):
864
+ x_detached = x_orig.to(torch.float32).detach()
865
+ x_detached.requires_grad = True
866
+
867
+ metric = _whitening_metric(x_detached, ctx.num_groups)
868
+
869
+ if random.random() < 0.005 or __name__ == "__main__":
870
+ logging.info(
871
+ f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
872
+ f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
873
+ )
874
+
875
+ (metric - ctx.whitening_limit).relu().backward()
876
+ penalty_grad = x_detached.grad
877
+ scale = ctx.grad_scale * (
878
+ x_grad.to(torch.float32).norm()
879
+ / (penalty_grad.norm() + 1.0e-20)
880
+ )
881
+ penalty_grad = penalty_grad * scale
882
+ return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
883
+
884
+
885
+ class Whiten(nn.Module):
886
+ def __init__(
887
+ self,
888
+ num_groups: int,
889
+ whitening_limit: float,
890
+ prob: Union[float, Tuple[float, float]],
891
+ grad_scale: float,
892
+ ):
893
+ """
894
+ Args:
895
+ num_groups: the number of groups to divide the channel dim into before
896
+ whitening. We will attempt to make the feature covariance
897
+ within each group, after mean subtraction, as "white" as possible,
898
+ while having the same trace across all groups.
899
+ whitening_limit: a value greater than 1.0, that dictates how much
900
+ freedom we have to violate the constraints. 1.0 would mean perfectly
901
+ white, with exactly the same trace across groups; larger values
902
+ give more freedom. E.g. 2.0.
903
+ prob: the probability with which we apply the gradient modification
904
+ (also affects the grad scale). May be supplied as a float,
905
+ or as a pair (min_prob, max_prob)
906
+
907
+ grad_scale: determines the scale on the gradient term from this object,
908
+ relative to the rest of the gradient on the attention weights.
909
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
910
+ """
911
+ super(Whiten, self).__init__()
912
+ assert num_groups >= 1
913
+ assert whitening_limit >= 1
914
+ assert grad_scale >= 0
915
+ self.num_groups = num_groups
916
+ self.whitening_limit = whitening_limit
917
+ if isinstance(prob, float):
918
+ assert 0 < prob <= 1
919
+ self.prob = prob
920
+ else:
921
+ (self.min_prob, self.max_prob) = prob
922
+ assert 0 < self.min_prob < self.max_prob <= 1
923
+ self.prob = self.max_prob
924
+
925
+ self.grad_scale = grad_scale
926
+
927
+ def forward(self, x: Tensor) -> Tensor:
928
+ """
929
+ In the forward pass, this function just returns the input unmodified.
930
+ In the backward pass, it will modify the gradients to ensure that the
931
+ distribution in each group has close to (lambda times I) as the covariance
932
+ after mean subtraction, with the same lambda across groups.
933
+ For whitening_limit > 1, there will be more freedom to violate this
934
+ constraint.
935
+
936
+ Args:
937
+ x: the input of shape (*, num_channels)
938
+
939
+ Returns:
940
+ x, unmodified. You should make sure
941
+ you use the returned value, or the graph will be freed
942
+ and nothing will happen in backprop.
943
+ """
944
+ if (
945
+ not x.requires_grad
946
+ or random.random() > self.prob
947
+ or self.grad_scale == 0
948
+ ):
949
+ return _no_op(x)
950
+ else:
951
+ if hasattr(self, "min_prob") and random.random() < 0.25:
952
+ # occasionally switch between min_prob and max_prob, based on whether
953
+ # we are above or below the threshold.
954
+ if (
955
+ _whitening_metric(x.to(torch.float32), self.num_groups)
956
+ > self.whitening_limit
957
+ ):
958
+ # there would be a change to the grad.
959
+ self.prob = self.max_prob
960
+ else:
961
+ self.prob = self.min_prob
962
+
963
+ return WhiteningPenaltyFunction.apply(
964
+ x, self.num_groups, self.whitening_limit, self.grad_scale
965
+ )
966
+
967
+
968
+ class WithLoss(torch.autograd.Function):
969
+ @staticmethod
970
+ def forward(ctx, x: Tensor, y: Tensor):
971
+ ctx.y_shape = y.shape
972
+ return x
973
+
974
+ @staticmethod
975
+ def backward(ctx, ans_grad: Tensor):
976
+ return ans_grad, torch.ones(
977
+ ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
978
+ )
979
+
980
+
981
+ def with_loss(x, y):
982
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
983
+ return x
984
+ # returns x but adds y.sum() to the loss function.
985
+ return WithLoss.apply(x, y)
986
+
987
+
988
+ def _no_op(x: Tensor) -> Tensor:
989
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
990
+ return x
991
+ else:
992
+ # a no-op function that will have a node in the autograd graph,
993
+ # to avoid certain bugs relating to backward hooks
994
+ return x.chunk(1, dim=-1)[0]
995
+
996
+
997
+ class Identity(torch.nn.Module):
998
+ def __init__(self):
999
+ super(Identity, self).__init__()
1000
+
1001
+ def forward(self, x):
1002
+ return _no_op(x)
1003
+
1004
+
1005
+ class MaxEig(torch.nn.Module):
1006
+ """
1007
+ Modifies the backpropped derivatives of a function to try to discourage
1008
+ that any given direction in activation space accounts for more than
1009
+ a specified proportion of the covariance (e.g. 0.2).
1010
+
1011
+
1012
+ Args:
1013
+ num_channels: the number of channels
1014
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
1015
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
1016
+ max_var_per_eig: the maximum proportion of the variance of the
1017
+ features/channels, after mean subtraction, that can come from
1018
+ any given eigenvalue.
1019
+ min_prob: the minimum probability with which we apply this during any invocation
1020
+ of forward(), assuming last time we applied the constraint it was
1021
+ not active; supplied for speed.
1022
+ scale: determines the scale with which we modify the gradients, relative
1023
+ to the existing / unmodified gradients
1024
+ """
1025
+
1026
+ def __init__(
1027
+ self,
1028
+ num_channels: int,
1029
+ channel_dim: int,
1030
+ max_var_per_eig: float = 0.2,
1031
+ min_prob: float = 0.01,
1032
+ scale: float = 0.01,
1033
+ ):
1034
+ super(MaxEig, self).__init__()
1035
+ self.num_channels = num_channels
1036
+ self.channel_dim = channel_dim
1037
+ self.scale = scale
1038
+ assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
1039
+ self.max_var_per_eig = max_var_per_eig
1040
+
1041
+ # we figure out the dominant direction using the power method: starting with
1042
+ # a random vector, keep multiplying by the covariance and renormalizing.
1043
+ with torch.no_grad():
1044
+ # arbitrary.. would use randn() but want to leave the rest of the model's
1045
+ # random parameters unchanged for comparison
1046
+ direction = torch.arange(num_channels).to(torch.float)
1047
+ direction = direction / direction.norm()
1048
+ self.register_buffer("max_eig_direction", direction)
1049
+
1050
+ self.min_prob = min_prob
1051
+ # cur_prob is the current probability we'll use to apply the ActivationBalancer.
1052
+ # We'll regress this towards prob, each time we try to apply it and it is not
1053
+ # active.
1054
+ self.cur_prob = 1.0
1055
+
1056
+ def forward(self, x: Tensor) -> Tensor:
1057
+ if (
1058
+ torch.jit.is_scripting()
1059
+ or self.max_var_per_eig <= 0
1060
+ or random.random() > self.cur_prob
1061
+ or torch.jit.is_tracing()
1062
+ ):
1063
+ return _no_op(x)
1064
+
1065
+ with torch.cuda.amp.autocast(enabled=False):
1066
+ eps = 1.0e-20
1067
+ orig_x = x
1068
+ x = x.to(torch.float32)
1069
+ with torch.no_grad():
1070
+ x = x.transpose(self.channel_dim, -1).reshape(
1071
+ -1, self.num_channels
1072
+ )
1073
+ x = x - x.mean(dim=0)
1074
+ new_direction, coeffs = self._find_direction_coeffs(
1075
+ x, self.max_eig_direction
1076
+ )
1077
+ x_var = (x ** 2).mean()
1078
+ x_residual = x - coeffs * new_direction
1079
+ x_residual_var = (x_residual ** 2).mean()
1080
+
1081
+ # `variance_proportion` is the proportion of the variance accounted for
1082
+ # by the top eigen-direction.
1083
+ variance_proportion = (x_var - x_residual_var) / (
1084
+ x_var + 1.0e-20
1085
+ )
1086
+
1087
+ # ensure new direction is nonzero even if x == 0, by including `direction`.
1088
+ self._set_direction(
1089
+ 0.1 * self.max_eig_direction + new_direction
1090
+ )
1091
+
1092
+ if random.random() < 0.01 or __name__ == "__main__":
1093
+ logging.info(
1094
+ f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
1095
+ )
1096
+
1097
+ if variance_proportion >= self.max_var_per_eig:
1098
+ # The constraint is active. Note, we should quite rarely
1099
+ # reach here, only near the beginning of training if we are
1100
+ # starting to diverge, should this constraint be active.
1101
+ cur_prob = self.cur_prob
1102
+ self.cur_prob = (
1103
+ 1.0 # next time, do the update with probability 1.0.
1104
+ )
1105
+ return MaxEigLimiterFunction.apply(
1106
+ orig_x, coeffs, new_direction, self.channel_dim, self.scale
1107
+ )
1108
+ else:
1109
+ # let self.cur_prob exponentially approach self.min_prob, as
1110
+ # long as the constraint is inactive.
1111
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
1112
+ return orig_x
1113
+
1114
+ def _set_direction(self, direction: Tensor):
1115
+ """
1116
+ Sets self.max_eig_direction to a normalized version of `direction`
1117
+ """
1118
+ direction = direction.detach()
1119
+ direction = direction / direction.norm()
1120
+ direction_sum = direction.sum().item()
1121
+ if direction_sum - direction_sum == 0: # no inf/nan
1122
+ self.max_eig_direction[:] = direction
1123
+ else:
1124
+ logging.info(
1125
+ f"Warning: sum of direction in MaxEig is {direction_sum}, "
1126
+ "num_channels={self.num_channels}, channel_dim={self.channel_dim}"
1127
+ )
1128
+
1129
+ def _find_direction_coeffs(
1130
+ self, x: Tensor, prev_direction: Tensor
1131
+ ) -> Tuple[Tensor, Tensor]:
1132
+ """
1133
+ Figure out (an approximation to) the proportion of the variance of a set of
1134
+ feature vectors that can be attributed to the top eigen-direction.
1135
+ Args:
1136
+ x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
1137
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
1138
+ of the top eigen-direction, or a random direction if this is the first
1139
+ iteration. Does not have to be normalized, but should be nonzero.
1140
+
1141
+ Returns: (cur_direction, coeffs), where:
1142
+ cur_direction: a Tensor of shape (num_channels,) that is the current
1143
+ estimate of the top eigen-direction.
1144
+ coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
1145
+ approximately minimizes, (x - coeffs * cur_direction).norm()
1146
+ """
1147
+ (num_frames, num_channels) = x.shape
1148
+ assert num_channels > 1 and num_frames > 1
1149
+ assert prev_direction.shape == (num_channels,)
1150
+ # `coeffs` are the coefficients of `prev_direction` in x.
1151
+ # (they actually represent the coeffs only up to a constant positive factor.)
1152
+ coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
1153
+ cur_direction = (x * coeffs).sum(dim=0) / (
1154
+ (coeffs ** 2).sum() + 1.0e-20
1155
+ )
1156
+ return cur_direction, coeffs
1157
+
1158
+
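_find_direction_coeffs is one step of a power-method-style update: project the frames onto the previous direction, then regress the frames on those coefficients to get the new direction. A standalone sketch of that step in plain torch:

    import torch

    x = torch.randn(200, 64)                 # (num_frames, num_channels), already mean-centered
    prev_direction = torch.randn(64)
    coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
    cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
    # iterating this update converges toward the top eigen-direction of x's covariance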
1159
+ class DoubleSwishFunction(torch.autograd.Function):
1160
+ """
1161
+ double_swish(x) = x * torch.sigmoid(x-1)
1162
+ This is a definition, originally motivated by its close numerical
1163
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
1164
+
1165
+ Memory-efficient derivative computation:
1166
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
1167
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
1168
+ Now, s'(x) = s(x) * (1-s(x)).
1169
+ double_swish'(x) = x * s'(x) + s(x).
1170
+ = x * s(x) * (1-s(x)) + s(x).
1171
+ = double_swish(x) * (1-s(x)) + s(x)
1172
+ ... so we just need to remember s(x) but not x itself.
1173
+ """
1174
+
1175
+ @staticmethod
1176
+ def forward(ctx, x: Tensor) -> Tensor:
1177
+ requires_grad = x.requires_grad
1178
+ x_dtype = x.dtype
1179
+ if x.dtype == torch.float16:
1180
+ x = x.to(torch.float32)
1181
+
1182
+ s = torch.sigmoid(x - 1.0)
1183
+ y = x * s
1184
+
1185
+ if requires_grad:
1186
+ deriv = y * (1 - s) + s
1187
+ # notes on derivative of x * sigmoid(x - 1):
1188
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
1189
+ # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound
1190
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
1191
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
1192
+ # floors), should be expectation-preserving.
1193
+ floor = -0.043637
1194
+ ceil = 1.2
1195
+ d_scaled = (deriv - floor) * (
1196
+ 255.0 / (ceil - floor)
1197
+ ) + torch.rand_like(deriv)
1198
+ if __name__ == "__main__":
1199
+ # for self-testing only.
1200
+ assert d_scaled.min() >= 0.0
1201
+ assert d_scaled.max() < 256.0
1202
+ d_int = d_scaled.to(torch.uint8)
1203
+ ctx.save_for_backward(d_int)
1204
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1205
+ y = y.to(torch.float16)
1206
+ return y
1207
+
1208
+ @staticmethod
1209
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1210
+ (d,) = ctx.saved_tensors
1211
+ # the same constants as used in forward pass.
1212
+ floor = -0.043637
1213
+ ceil = 1.2
1214
+ d = d * ((ceil - floor) / 255.0) + floor
1215
+ return y_grad * d
1216
+
1217
+
1218
+ class DoubleSwish(torch.nn.Module):
1219
+ def forward(self, x: Tensor) -> Tensor:
1220
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
1221
+ that we approximate closely with x * sigmoid(x-1).
1222
+ """
1223
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1224
+ return x * torch.sigmoid(x - 1.0)
1225
+ return DoubleSwishFunction.apply(x)
1226
+
1227
+
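A quick sanity check, with ordinary autograd rather than the custom Function, that the derivative identity used in DoubleSwishFunction, d/dx [x * sigmoid(x - 1)] = y * (1 - s) + s with s = sigmoid(x - 1) and y = x * s, actually holds:

    import torch

    x = torch.randn(16, dtype=torch.double, requires_grad=True)
    s = torch.sigmoid(x - 1.0)
    y = x * s
    (autograd_grad,) = torch.autograd.grad(y.sum(), x)
    manual_grad = (y * (1 - s) + s).detach()
    print(torch.allclose(autograd_grad, manual_grad))   # True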
1228
+ def BalancedDoubleSwish(
1229
+ d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
1230
+ ) -> nn.Sequential:
1231
+ """
1232
+ ActivationBalancer -> DoubleSwish
1233
+ """
1234
+ balancer = ActivationBalancer(
1235
+ d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
1236
+ )
1237
+ return nn.Sequential(
1238
+ balancer,
1239
+ DoubleSwish(),
1240
+ )
1241
+
1242
+
1243
+ def _test_max_eig():
1244
+ for proportion in [0.1, 0.5, 10.0]:
1245
+ logging.info(f"proportion = {proportion}")
1246
+ x = torch.randn(100, 128)
1247
+ direction = torch.randn(128)
1248
+ coeffs = torch.randn(100, 1)
1249
+ x += proportion * direction * coeffs
1250
+
1251
+ x.requires_grad = True
1252
+
1253
+ num_channels = 128
1254
+ m = MaxEig(
1255
+ num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
1256
+ ) # grad_scale
1257
+
1258
+ for _ in range(4):
1259
+ y = m(x)
1260
+
1261
+ y_grad = torch.randn_like(x)
1262
+ y.backward(gradient=y_grad)
1263
+
1264
+ if proportion < 0.2:
1265
+ assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
1266
+ elif proportion > 1.0:
1267
+ assert not torch.allclose(x.grad, y_grad)
1268
+
1269
+
1270
+ def _test_whiten():
1271
+ for proportion in [0.1, 0.5, 10.0]:
1272
+ logging.info(f"_test_whiten(): proportion = {proportion}")
1273
+ x = torch.randn(100, 128)
1274
+ direction = torch.randn(128)
1275
+ coeffs = torch.randn(100, 1)
1276
+ x += proportion * direction * coeffs
1277
+
1278
+ x.requires_grad = True
1279
+
1280
+ num_channels = 128
1281
+ m = Whiten(
1282
+ 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
1283
+ ) # grad_scale
1284
+
1285
+ for _ in range(4):
1286
+ y = m(x)
1287
+
1288
+ y_grad = torch.randn_like(x)
1289
+ y.backward(gradient=y_grad)
1290
+
1291
+ if proportion < 0.2:
1292
+ assert torch.allclose(x.grad, y_grad)
1293
+ elif proportion > 1.0:
1294
+ assert not torch.allclose(x.grad, y_grad)
1295
+
1296
+
1297
+ def _test_activation_balancer_sign():
1298
+ probs = torch.arange(0, 1, 0.01)
1299
+ N = 1000
1300
+ x = 1.0 * (
1301
+ (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
1302
+ )
1303
+ x = x.detach()
1304
+ x.requires_grad = True
1305
+ m = ActivationBalancer(
1306
+ probs.numel(),
1307
+ channel_dim=0,
1308
+ min_positive=0.05,
1309
+ max_positive=0.95,
1310
+ max_factor=0.2,
1311
+ min_abs=0.0,
1312
+ )
1313
+
1314
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
1315
+
1316
+ y = m(x)
1317
+ y.backward(gradient=y_grad)
1318
+ print("_test_activation_balancer_sign: x = ", x)
1319
+ print("_test_activation_balancer_sign: y grad = ", y_grad)
1320
+ print("_test_activation_balancer_sign: x grad = ", x.grad)
1321
+
1322
+
1323
+ def _test_activation_balancer_magnitude():
1324
+ magnitudes = torch.arange(0, 1, 0.01)
1325
+ N = 1000
1326
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
1327
+ -1
1328
+ )
1329
+ x = x.detach()
1330
+ x.requires_grad = True
1331
+ m = ActivationBalancer(
1332
+ magnitudes.numel(),
1333
+ channel_dim=0,
1334
+ min_positive=0.0,
1335
+ max_positive=1.0,
1336
+ max_factor=0.2,
1337
+ min_abs=0.2,
1338
+ max_abs=0.8,
1339
+ min_prob=1.0,
1340
+ )
1341
+
1342
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
1343
+
1344
+ y = m(x)
1345
+ y.backward(gradient=y_grad)
1346
+ print("_test_activation_balancer_magnitude: x = ", x)
1347
+ print("_test_activation_balancer_magnitude: y grad = ", y_grad)
1348
+ print("_test_activation_balancer_magnitude: x grad = ", x.grad)
1349
+
1350
+
1351
+ def _test_basic_norm():
1352
+ num_channels = 128
1353
+ m = BasicNorm(num_channels=num_channels, channel_dim=1)
1354
+
1355
+ x = torch.randn(500, num_channels)
1356
+
1357
+ y = m(x)
1358
+
1359
+ assert y.shape == x.shape
1360
+ x_rms = (x ** 2).mean().sqrt()
1361
+ y_rms = (y ** 2).mean().sqrt()
1362
+ print("x rms = ", x_rms)
1363
+ print("y rms = ", y_rms)
1364
+ assert y_rms < x_rms
1365
+ assert y_rms > 0.5 * x_rms
1366
+
1367
+
1368
+ def _test_double_swish_deriv():
1369
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1370
+ x.requires_grad = True
1371
+ m = DoubleSwish()
1372
+
1373
+ tol = (1.2 - (-0.043637)) / 255.0
1374
+ torch.autograd.gradcheck(m, x, atol=tol)
1375
+
1376
+ # for self-test.
1377
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1378
+ x.requires_grad = True
1379
+ y = m(x)
1380
+
1381
+
1382
+ def _test_softmax():
1383
+ a = torch.randn(2, 10, dtype=torch.float64)
1384
+ b = a.clone()
1385
+ a.requires_grad = True
1386
+ b.requires_grad = True
1387
+ a.softmax(dim=1)[:, 0].sum().backward()
1388
+ print("a grad = ", a.grad)
1389
+ softmax(b, dim=1)[:, 0].sum().backward()
1390
+ print("b grad = ", b.grad)
1391
+ assert torch.allclose(a.grad, b.grad)
1392
+
1393
+
1394
+ if __name__ == "__main__":
1395
+ logging.getLogger().setLevel(logging.INFO)
1396
+ torch.set_num_threads(1)
1397
+ torch.set_num_interop_threads(1)
1398
+ _test_softmax()
1399
+ _test_whiten()
1400
+ _test_max_eig()
1401
+ _test_activation_balancer_sign()
1402
+ _test_activation_balancer_magnitude()
1403
+ _test_basic_norm()
1404
+ _test_double_swish_deriv()
slam_llm/models/vallex/transformers.py ADDED
@@ -0,0 +1,613 @@
1
+ import copy
2
+ import numbers
3
+ from functools import partial
4
+ from typing import Any, Callable, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from torch import Tensor, nn
8
+ from torch.nn import functional as F
9
+
10
+ from .activation import MultiheadAttention
11
+ from .scaling import BasicNorm as _BasicNorm
12
+
13
+ _shape_t = Union[int, List[int], torch.Size]
14
+
15
+
16
+ class LayerNorm(nn.Module):
17
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
18
+ normalized_shape: Tuple[int, ...]
19
+ eps: float
20
+ elementwise_affine: bool
21
+
22
+ def __init__(
23
+ self,
24
+ normalized_shape: _shape_t,
25
+ eps: float = 1e-5,
26
+ elementwise_affine: bool = True,
27
+ device=None,
28
+ dtype=None,
29
+ ) -> None:
30
+ factory_kwargs = {"device": device, "dtype": dtype}
31
+ super(LayerNorm, self).__init__()
32
+ if isinstance(normalized_shape, numbers.Integral):
33
+ # mypy error: incompatible types in assignment
34
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
35
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
36
+ self.eps = eps
37
+ self.elementwise_affine = elementwise_affine
38
+ if self.elementwise_affine:
39
+ self.weight = nn.Parameter(
40
+ torch.empty(self.normalized_shape, **factory_kwargs)
41
+ )
42
+ self.bias = nn.Parameter(
43
+ torch.empty(self.normalized_shape, **factory_kwargs)
44
+ )
45
+ else:
46
+ self.register_parameter("weight", None)
47
+ self.register_parameter("bias", None)
48
+
49
+ self.reset_parameters()
50
+
51
+ def reset_parameters(self) -> None:
52
+ if self.elementwise_affine:
53
+ nn.init.ones_(self.weight)
54
+ nn.init.zeros_(self.bias)
55
+
56
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
57
+ if isinstance(input, tuple):
58
+ input, embedding = input
59
+ return (
60
+ F.layer_norm(
61
+ input,
62
+ self.normalized_shape,
63
+ self.weight,
64
+ self.bias,
65
+ self.eps,
66
+ ),
67
+ embedding,
68
+ )
69
+
70
+ assert embedding is None
71
+ return F.layer_norm(
72
+ input, self.normalized_shape, self.weight, self.bias, self.eps
73
+ )
74
+
75
+ def extra_repr(self) -> str:
76
+ return (
77
+ "{normalized_shape}, eps={eps}, "
78
+ "elementwise_affine={elementwise_affine}".format(**self.__dict__)
79
+ )
80
+
81
+
82
+ class AdaptiveLayerNorm(nn.Module):
83
+ r"""Adaptive Layer Normalization"""
84
+
85
+ def __init__(self, d_model, norm) -> None:
86
+ super(AdaptiveLayerNorm, self).__init__()
87
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
88
+ self.norm = norm
89
+ self.d_model = d_model
90
+ self.eps = self.norm.eps
91
+
92
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
93
+ if isinstance(input, tuple):
94
+ input, embedding = input
95
+ weight, bias = torch.split(
96
+ self.project_layer(embedding),
97
+ split_size_or_sections=self.d_model,
98
+ dim=-1,
99
+ )
100
+ return (weight * self.norm(input) + bias, embedding)
101
+
102
+ weight, bias = torch.split(
103
+ self.project_layer(embedding),
104
+ split_size_or_sections=self.d_model,
105
+ dim=-1,
106
+ )
107
+ return weight * self.norm(input) + bias
108
+
109
+
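AdaptiveLayerNorm predicts a per-example scale and shift from a conditioning (stage) embedding and applies them on top of the wrapped norm. A hedged usage sketch, assuming this module is importable and using torch.nn.LayerNorm as the inner norm:

    import torch
    import torch.nn as nn

    d_model = 8
    ada = AdaptiveLayerNorm(d_model, nn.LayerNorm(d_model))
    x = torch.randn(2, 5, d_model)        # (batch, seq, d_model)
    emb = torch.randn(2, 1, d_model)      # conditioning embedding, broadcast over the sequence
    out = ada(x, emb)                     # weight * norm(x) + bias, same shape as x
    print(out.shape)                      # torch.Size([2, 5, 8])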
110
+ class BasicNorm(_BasicNorm):
111
+ def __init__(
112
+ self,
113
+ d_model: int,
114
+ eps: float = 1e-5,
115
+ device=None,
116
+ dtype=None,
117
+ ):
118
+ super(BasicNorm, self).__init__(d_model, eps=eps)
119
+
120
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
121
+ if isinstance(input, tuple):
122
+ input, embedding = input
123
+ return (
124
+ super(BasicNorm, self).forward(input),
125
+ embedding,
126
+ )
127
+
128
+ assert embedding is None
129
+ return super(BasicNorm, self).forward(input)
130
+
131
+
132
+ class BalancedBasicNorm(nn.Module):
133
+ def __init__(
134
+ self,
135
+ d_model: int,
136
+ eps: float = 1e-5,
137
+ device=None,
138
+ dtype=None,
139
+ ):
140
+ super(BalancedBasicNorm, self).__init__()
141
+ self.balancer = ActivationBalancer(
142
+ d_model,
143
+ channel_dim=-1,
144
+ min_positive=0.45,
145
+ max_positive=0.55,
146
+ max_abs=6.0,
147
+ )
148
+ self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
149
+
150
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
151
+ if isinstance(input, tuple):
152
+ input, embedding = input
153
+ return self.norm((self.balancer(input), embedding))
154
+
155
+ assert embedding is None
156
+ return self.norm(self.balancer(input))
157
+
158
+
159
+ class IdentityNorm(nn.Module):
160
+ def __init__(
161
+ self,
162
+ d_model: int,
163
+ eps: float = 1e-5,
164
+ device=None,
165
+ dtype=None,
166
+ ) -> None:
167
+ super(IdentityNorm, self).__init__()
168
+
169
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
170
+ if isinstance(input, tuple):
171
+ return input
172
+
173
+ assert embedding is None
174
+ return input
175
+
176
+
177
+ class TransformerEncoderLayer(nn.Module):
178
+ __constants__ = ["batch_first", "norm_first"]
179
+
180
+ def __init__(
181
+ self,
182
+ d_model: int,
183
+ nhead: int,
184
+ dim_feedforward: int = 2048,
185
+ dropout: float = 0.1,
186
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
187
+ batch_first: bool = False,
188
+ norm_first: bool = False,
189
+ device=None,
190
+ dtype=None,
191
+ linear1_self_attention_cls: nn.Module = nn.Linear,
192
+ linear2_self_attention_cls: nn.Module = nn.Linear,
193
+ linear1_feedforward_cls: nn.Module = nn.Linear,
194
+ linear2_feedforward_cls: nn.Module = nn.Linear,
195
+ layer_norm_cls: nn.Module = LayerNorm,
196
+ layer_norm_eps: float = 1e-5,
197
+ adaptive_layer_norm=False,
198
+ ) -> None:
199
+ factory_kwargs = {"device": device, "dtype": dtype}
200
+ super(TransformerEncoderLayer, self).__init__()
201
+ self.self_attn = MultiheadAttention(
202
+ d_model,
203
+ nhead,
204
+ dropout=dropout,
205
+ batch_first=batch_first,
206
+ linear1_cls=linear1_self_attention_cls,
207
+ linear2_cls=linear2_self_attention_cls,
208
+ **factory_kwargs,
209
+ )
210
+
211
+ # Implementation of Feedforward model
212
+
213
+ self.dropout = nn.Dropout(dropout)
214
+
215
+ self.norm_first = norm_first
216
+ self.dropout1 = nn.Dropout(dropout)
217
+ self.dropout2 = nn.Dropout(dropout)
218
+
219
+ # Legacy string support for activation function.
220
+ if isinstance(activation, str):
221
+ activation = _get_activation_fn(activation)
222
+ elif isinstance(activation, partial):
223
+ activation = activation(d_model)
224
+
225
+ self.activation = activation
226
+
227
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
228
+ if layer_norm_cls == IdentityNorm:
229
+ norm2 = BalancedBasicNorm(
230
+ d_model, eps=layer_norm_eps, **factory_kwargs
231
+ )
232
+ else:
233
+ norm2 = layer_norm_cls(
234
+ d_model, eps=layer_norm_eps, **factory_kwargs
235
+ )
236
+
237
+ self.norm1 = norm1
238
+ self.linear1 = linear1_feedforward_cls(
239
+ d_model, dim_feedforward, **factory_kwargs
240
+ )
241
+ self.linear2 = linear2_feedforward_cls(
242
+ dim_feedforward, d_model, **factory_kwargs
243
+ )
244
+ self.norm2 = norm2
245
+
246
+
247
+ def __setstate__(self, state):
248
+ super(TransformerEncoderLayer, self).__setstate__(state)
249
+ if not hasattr(self, "activation"):
250
+ self.activation = F.relu
251
+
252
+ def forward(
253
+ self,
254
+ src: Tensor,
255
+ src_mask: Optional[Tensor] = None,
256
+ src_key_padding_mask: Optional[Tensor] = None,
257
+ ) -> Tensor:
258
+ r"""Pass the input through the encoder layer.
259
+
260
+ Args:
261
+ src: the sequence to the encoder layer (required).
262
+ src_mask: the mask for the src sequence (optional).
263
+ src_key_padding_mask: the mask for the src keys per batch (optional).
264
+
265
+ Shape:
266
+ see the docs in Transformer class.
267
+ """
268
+ x, stage_embedding = src, None
269
+ is_src_tuple = False
270
+ if isinstance(src, tuple):
271
+ x, stage_embedding = src
272
+ is_src_tuple = True
273
+
274
+ if src_key_padding_mask is not None:
275
+ _skpm_dtype = src_key_padding_mask.dtype
276
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
277
+ src_key_padding_mask
278
+ ):
279
+ raise AssertionError(
280
+ "only bool and floating types of key_padding_mask are supported"
281
+ )
282
+
283
+ if self.norm_first:
284
+ x = x + self._sa_block(
285
+ self.norm1(x, stage_embedding),
286
+ src_mask,
287
+ src_key_padding_mask,
288
+ )
289
+
290
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
291
+ else:
292
+ x = self.norm1(
293
+ x + self._sa_block(x, src_mask, src_key_padding_mask),
294
+ stage_embedding,
295
+ )
296
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
297
+
298
+ if is_src_tuple:
299
+ return (x, stage_embedding)
300
+ return x
301
+
302
+ def infer(
303
+ self,
304
+ src: Tensor,
305
+ src_mask: Optional[Tensor] = None,
306
+ src_key_padding_mask: Optional[Tensor] = None,
307
+ past_kv: Optional[Tensor] = None,
308
+ use_cache: bool = False,
309
+ ):
310
+ x, stage_embedding = src, None
311
+ is_src_tuple = False
312
+ if isinstance(src, tuple):
313
+ x, stage_embedding = src
314
+ is_src_tuple = True
315
+
316
+ if src_key_padding_mask is not None:
317
+ _skpm_dtype = src_key_padding_mask.dtype
318
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
319
+ src_key_padding_mask
320
+ ):
321
+ raise AssertionError(
322
+ "only bool and floating types of key_padding_mask are supported"
323
+ )
324
+
325
+ if self.norm_first:
326
+ x_attn_out, kv = self.self_attn.infer(
327
+ self.norm1(x, stage_embedding),
328
+ attn_mask=src_mask,
329
+ key_padding_mask=src_key_padding_mask,
330
+ need_weights=False,
331
+ past_kv=past_kv,
332
+ use_cache=use_cache,
333
+ )
334
+ x = x + x_attn_out
335
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
336
+
337
+ if is_src_tuple:
338
+ return (x, stage_embedding)
339
+ return (x, kv)
340
+
341
+ # self-attention block
342
+ def _sa_block(
343
+ self,
344
+ x: Tensor,
345
+ attn_mask: Optional[Tensor],
346
+ key_padding_mask: Optional[Tensor],
347
+ ) -> Tensor:
348
+ x = self.self_attn(
349
+ x,
350
+ x,
351
+ x,
352
+ attn_mask=attn_mask,
353
+ key_padding_mask=key_padding_mask,
354
+ need_weights=False,
355
+ )[0]
356
+ return self.dropout1(x)
357
+
358
+ # feed forward block
359
+ def _ff_block(self, x: Tensor) -> Tensor:
360
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
361
+ return self.dropout2(x)
362
+
363
+
364
+ class TransformerEncoder(nn.Module):
365
+ __constants__ = ["norm"]
366
+
367
+ def __init__(self, encoder_layer, num_layers, norm=None):
368
+ super(TransformerEncoder, self).__init__()
369
+ self.layers = _get_clones(encoder_layer, num_layers)
370
+ self.num_layers = num_layers
371
+ self.norm = norm
372
+
373
+ def forward(
374
+ self,
375
+ src: Tensor,
376
+ mask: Optional[Tensor] = None,
377
+ src_key_padding_mask: Optional[Tensor] = None,
378
+ return_layer_states: bool = False,
379
+ ) -> Tensor:
380
+ output = src
381
+ for i, mod in enumerate(self.layers):
382
+ output = mod(
383
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
384
+ )
385
+ # print(i, output.mean())
386
+ if self.norm is not None:
387
+ output = self.norm(output)
388
+
389
+ return output
390
+
391
+ def infer(
392
+ self,
393
+ src: Tensor,
394
+ mask: Optional[Tensor] = None,
395
+ src_key_padding_mask: Optional[Tensor] = None,
396
+ return_layer_states: bool = False,
397
+ past_kv: Optional[Tensor] = None,
398
+ use_cache: bool = False,
399
+ ):
400
+ if past_kv is None:
401
+ past_length = 0
402
+ past_kv = tuple([None] * self.num_layers)
403
+ else:
404
+ past_length = past_kv[0][0].size(-2)
405
+ new_kv = () if use_cache else None
406
+ output = src
407
+ for i, (mod, past_layer_kv) in enumerate(zip(self.layers, past_kv)):
408
+ output, kv = mod.infer(
409
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache
410
+ )
411
+ # print(i, output.mean())
412
+ if use_cache:
413
+ new_kv = new_kv + (kv,)
414
+
415
+ if self.norm is not None:
416
+ output = self.norm(output)
417
+
418
+ return output, new_kv
419
+
420
+
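A hedged sketch of how these pieces are wired together, assuming the companion activation module providing MultiheadAttention is available alongside this file; it mirrors how torch.nn.TransformerEncoder would be assembled:

    import torch

    layer = TransformerEncoderLayer(
        d_model=256, nhead=4, dim_feedforward=1024,
        batch_first=True, norm_first=True,
    )
    encoder = TransformerEncoder(layer, num_layers=6, norm=LayerNorm(256))
    x = torch.randn(2, 50, 256)           # (batch, seq, d_model) since batch_first=True
    out = encoder(x)                      # same shape as the input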
421
+ class TransformerDecoderLayer(nn.Module):
422
+ __constants__ = ["batch_first", "norm_first"]
423
+
424
+ def __init__(
425
+ self,
426
+ d_model: int,
427
+ nhead: int,
428
+ dim_feedforward: int = 2048,
429
+ dropout: float = 0.1,
430
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
431
+ linear1_self_attention_cls: nn.Module = nn.Linear,
432
+ linear2_self_attention_cls: nn.Module = nn.Linear,
433
+ linear1_feedforward_cls: nn.Module = nn.Linear,
434
+ linear2_feedforward_cls: nn.Module = nn.Linear,
435
+ batch_first: bool = False,
436
+ norm_first: bool = False,
437
+ device=None,
438
+ dtype=None,
439
+ layer_norm_cls: nn.Module = LayerNorm,
440
+ layer_norm_eps: float = 1e-5,
441
+ adaptive_layer_norm=False,
442
+ ) -> None:
443
+ factory_kwargs = {"device": device, "dtype": dtype}
444
+ super(TransformerDecoderLayer, self).__init__()
445
+ self.self_attn = MultiheadAttention(
446
+ d_model,
447
+ nhead,
448
+ dropout=dropout,
449
+ batch_first=batch_first,
450
+ linear1_cls=linear1_self_attention_cls,
451
+ linear2_cls=linear2_self_attention_cls,
452
+ **factory_kwargs,
453
+ )
454
+ self.multihead_attn = MultiheadAttention(
455
+ d_model,
456
+ nhead,
457
+ dropout=dropout,
458
+ batch_first=batch_first,
459
+ linear1_cls=linear1_self_attention_cls,
460
+ linear2_cls=linear2_self_attention_cls,
461
+ **factory_kwargs,
462
+ )
463
+ # Implementation of Feedforward model
464
+ self.linear1 = linear1_feedforward_cls(
465
+ d_model, dim_feedforward, **factory_kwargs
466
+ )
467
+ self.dropout = nn.Dropout(dropout)
468
+ self.linear2 = linear2_feedforward_cls(
469
+ dim_feedforward, d_model, **factory_kwargs
470
+ )
471
+
472
+ self.norm_first = norm_first
473
+ self.dropout1 = nn.Dropout(dropout)
474
+ self.dropout2 = nn.Dropout(dropout)
475
+ self.dropout3 = nn.Dropout(dropout)
476
+
477
+ # Legacy string support for activation function.
478
+ if isinstance(activation, str):
479
+ self.activation = _get_activation_fn(activation)
480
+ elif isinstance(activation, partial):
481
+ self.activation = activation(d_model)
482
+ else:
483
+ self.activation = activation
484
+
485
+ if adaptive_layer_norm:
486
+ norm1 = layer_norm_cls(
487
+ d_model, eps=layer_norm_eps, **factory_kwargs
488
+ )
489
+ norm2 = layer_norm_cls(
490
+ d_model, eps=layer_norm_eps, **factory_kwargs
491
+ )
492
+ norm3 = layer_norm_cls(
493
+ d_model, eps=layer_norm_eps, **factory_kwargs
494
+ )
495
+
496
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
497
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
498
+ self.norm3 = AdaptiveLayerNorm(d_model, norm3)
499
+ else:
500
+ self.norm1 = layer_norm_cls(
501
+ d_model, eps=layer_norm_eps, **factory_kwargs
502
+ )
503
+ self.norm2 = layer_norm_cls(
504
+ d_model, eps=layer_norm_eps, **factory_kwargs
505
+ )
506
+ if layer_norm_cls == IdentityNorm:
507
+ self.norm3 = BalancedBasicNorm(
508
+ d_model, eps=layer_norm_eps, **factory_kwargs
509
+ )
510
+ else:
511
+ self.norm3 = layer_norm_cls(
512
+ d_model, eps=layer_norm_eps, **factory_kwargs
513
+ )
514
+
515
+ def forward(
516
+ self,
517
+ tgt: Tensor,
518
+ memory: Tensor,
519
+ tgt_mask: Optional[Tensor] = None,
520
+ memory_mask: Optional[Tensor] = None,
521
+ tgt_key_padding_mask: Optional[Tensor] = None,
522
+ memory_key_padding_mask: Optional[Tensor] = None,
523
+ ) -> Tensor:
524
+ tgt_is_tuple = False
525
+ if isinstance(tgt, tuple):
526
+ x, stage_embedding = tgt
527
+ tgt_is_tuple = True
528
+ else:
529
+ x, stage_embedding = tgt, None
530
+
531
+ if self.norm_first:
532
+ x = x + self._sa_block(
533
+ self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
534
+ )
535
+ x = x + self._mha_block(
536
+ self.norm2(x, stage_embedding),
537
+ memory,
538
+ memory_mask,
539
+ memory_key_padding_mask,
540
+ )
541
+ x = x + self._ff_block(self.norm3(x, stage_embedding))
542
+ else:
543
+ x = self.norm1(
544
+ x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
545
+ stage_embedding,
546
+ )
547
+ x = self.norm2(
548
+ x
549
+ + self._mha_block(
550
+ x, memory, memory_mask, memory_key_padding_mask
551
+ ),
552
+ stage_embedding,
553
+ )
554
+ x = self.norm3(x + self._ff_block(x), stage_embedding)
555
+
556
+ if tgt_is_tuple:
557
+ return (x, stage_embedding)
558
+ return x
559
+
560
+ # self-attention block
561
+ def _sa_block(
562
+ self,
563
+ x: Tensor,
564
+ attn_mask: Optional[Tensor],
565
+ key_padding_mask: Optional[Tensor],
566
+ ) -> Tensor:
567
+ x = self.self_attn(
568
+ x,
569
+ x,
570
+ x,
571
+ attn_mask=attn_mask,
572
+ key_padding_mask=key_padding_mask,
573
+ need_weights=False,
574
+ )[0]
575
+ return self.dropout1(x)
576
+
577
+ # multihead attention block
578
+ def _mha_block(
579
+ self,
580
+ x: Tensor,
581
+ mem: Tensor,
582
+ attn_mask: Optional[Tensor],
583
+ key_padding_mask: Optional[Tensor],
584
+ ) -> Tensor:
585
+ x = self.multihead_attn(
586
+ x,
587
+ mem,
588
+ mem,
589
+ attn_mask=attn_mask,
590
+ key_padding_mask=key_padding_mask,
591
+ need_weights=False,
592
+ )[0]
593
+ return self.dropout2(x)
594
+
595
+ # feed forward block
596
+ def _ff_block(self, x: Tensor) -> Tensor:
597
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
598
+ return self.dropout3(x)
599
+
600
+
601
+ def _get_clones(module, N):
602
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
603
+
604
+
605
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
606
+ if activation == "relu":
607
+ return F.relu
608
+ elif activation == "gelu":
609
+ return F.gelu
610
+
611
+ raise RuntimeError(
612
+ "activation should be relu/gelu, not {}".format(activation)
613
+ )
slam_llm/models/vallex/vallex_config.py ADDED
@@ -0,0 +1,56 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+ from fairseq.data import Dictionary
4
+ from transformers import AutoConfig, AutoModel, AutoModelForImageClassification
5
+
6
+ logger = logging.get_logger(__name__)
7
+
8
+
9
+
10
+ class VallexConfig(PretrainedConfig):
11
+
12
+ model_type = "vallex"
13
+
14
+ def __init__(self,
15
+ n_layer=24,
16
+ n_head=16,
17
+ n_dim=1024,
18
+ prefix_mode=1,
19
+ num_quantizers=8,
20
+ sample_rate=24000,
21
+ ar_at_dict="",
22
+ ar_st_dict="",
23
+ nar_at_dict="",
24
+ nar_st_dict="",
25
+ nar_scale_factor=1.0,
26
+ prepend_bos=True,
27
+ norm_first=True,
28
+ eps=0.0,
29
+ only_ar=False,
30
+ only_nar=False,
31
+ **kwargs
32
+ ):
33
+ self.n_layer = n_layer
34
+ self.n_head = n_head
35
+ self.n_dim = n_dim
36
+ self.prefix_mode = prefix_mode
37
+ self.num_quantizers = num_quantizers
38
+ self.sample_rate = sample_rate
39
+ self.nar_scale_factor = nar_scale_factor
40
+ self.prepend_bos = prepend_bos
41
+ self.norm_first = norm_first
42
+
43
+ self.ar_at_dict = ar_at_dict
44
+ self.ar_st_dict = ar_st_dict
45
+ self.nar_at_dict = nar_at_dict
46
+ self.nar_st_dict = nar_st_dict
47
+ self.eps = eps
48
+ self.only_ar = only_ar
49
+ self.only_nar = only_nar
50
+
51
+ super().__init__(
52
+ **kwargs
53
+ )
54
+
55
+
56
+ AutoConfig.register("vallex", VallexConfig)
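Because VallexConfig is a standard PretrainedConfig subclass, it can be built directly and round-tripped through config.json. A hedged usage sketch, assuming the module and its dependencies (e.g. fairseq) are importable; the save directory name is made up for illustration:

    cfg = VallexConfig(n_layer=12, n_head=8, n_dim=512, num_quantizers=8)
    print(cfg.model_type)                      # "vallex"
    cfg.save_pretrained("./vallex_config")     # writes ./vallex_config/config.json
    cfg2 = VallexConfig.from_pretrained("./vallex_config")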
slam_llm/models/vallex/vallex_model.py ADDED
@@ -0,0 +1,772 @@
1
+ import random
2
+ from typing import Dict, Iterator, List, Tuple, Union
3
+ from fairseq import utils
4
+ import numpy as np
5
+ import torch
6
+ import math
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from fairseq.data import Dictionary
10
+ from src.slam_llm.models.vallex.transformers import (
11
+ LayerNorm,
12
+ TransformerEncoder,
13
+ TransformerEncoderLayer,
14
+ )
15
+ from src.slam_llm.models.vallex.vallex_config import VallexConfig
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers import AutoConfig, AutoModel, AutoModelForImageClassification
18
+ from dataclasses import dataclass
19
+
20
+ @dataclass
21
+ class ModelOutput:
22
+ logits: torch.Tensor
23
+ loss: torch.Tensor
24
+ acc: torch.Tensor
25
+
26
+ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True, scale=1, prob_mask=None):
27
+ if target.dim() == lprobs.dim() - 1:
28
+ target = target.unsqueeze(-1)
29
+ if prob_mask is not None:
30
+ lprobs = lprobs.masked_fill(prob_mask, 0.0)
31
+ n_class = (1-prob_mask.float()).sum()
32
+ else:
33
+ n_class = lprobs.size(-1)
34
+ nll_loss = -lprobs.gather(dim=-1, index=target)
35
+ # nll_loss = nll_loss * scale
36
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True) * scale
37
+ if ignore_index is not None:
38
+ pad_mask = target.eq(ignore_index)
39
+ nll_loss.masked_fill_(pad_mask, 0.0)
40
+ smooth_loss.masked_fill_(pad_mask, 0.0)
41
+ pad_mask_float = (1 - pad_mask.to(torch.float)).sum()
42
+ else:
43
+ nll_loss = nll_loss.squeeze(-1)
44
+ smooth_loss = smooth_loss.squeeze(-1)
45
+ if reduce:
46
+ nll_loss = nll_loss.sum()
47
+ smooth_loss = smooth_loss.sum()
48
+ eps_i = epsilon / (n_class - 1)
49
+ loss = (1.0 - epsilon - eps_i) * nll_loss + \
50
+ eps_i * smooth_loss
51
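+ # note: pad_mask_float is only defined in the ignore_index branch above, so this
+ # normalization assumes an ignore_index is always supplied (calculate_loss below
+ # passes the padding index).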
+ return loss / pad_mask_float, nll_loss / pad_mask_float
52
+
53
+
54
+ class SinusoidalPositionalEmbedding(nn.Module):
55
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
56
+ super().__init__()
57
+ self.embedding_dim = embedding_dim
58
+ self.padding_idx = padding_idx if padding_idx is not None else 0
59
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
60
+ init_size, embedding_dim, padding_idx
61
+ )
62
+ self.onnx_trace = False
63
+ self.register_buffer("_float_tensor", torch.FloatTensor(1))
64
+ self.max_positions = int(1e5)
65
+
66
+ def prepare_for_onnx_export_(self):
67
+ self.onnx_trace = True
68
+
69
+ @staticmethod
70
+ def get_embedding(
71
+ num_embeddings: int, embedding_dim: int, padding_idx = None
72
+ ):
73
+ half_dim = embedding_dim // 2
74
+ emb = math.log(10000) / (half_dim - 1)
75
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
76
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
77
+ 1
78
+ ) * emb.unsqueeze(0)
79
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
80
+ num_embeddings, -1
81
+ )
82
+ if embedding_dim % 2 == 1:
83
+ # zero pad
84
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
85
+ if padding_idx is not None:
86
+ emb[padding_idx, :] = 0
87
+ return emb
88
+
89
+ def forward(
90
+ self,
91
+ input,
92
+ incremental_state = None,
93
+ timestep = None,
94
+ positions = None,
95
+ ):
96
+ bspair = torch.onnx.operators.shape_as_tensor(input)
97
+ bsz, seq_len = bspair[0], bspair[1]
98
+ max_pos = self.padding_idx + 1 + seq_len
99
+ if self.weights is None or max_pos > self.weights.size(0):
100
+ # recompute/expand embeddings if needed
101
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
102
+ max_pos, self.embedding_dim, self.padding_idx
103
+ )
104
+ self.weights = self.weights.to(self._float_tensor)
105
+
106
+ if incremental_state is not None:
107
+ # positions is the same for every token when decoding a single step
108
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
109
+ if self.onnx_trace:
110
+ return (
111
+ self.weights.index_select(index=self.padding_idx + pos, dim=0)
112
+ .unsqueeze(1)
113
+ .repeat(bsz, 1, 1)
114
+ )
115
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
116
+
117
+ positions = utils.make_positions(
118
+ input, self.padding_idx, onnx_trace=self.onnx_trace
119
+ )
120
+ if self.onnx_trace:
121
+ flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
122
+ embedding_shape = torch.cat(
123
+ (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
124
+ )
125
+ embeddings = torch.onnx.operators.reshape_from_tensor_shape(
126
+ flat_embeddings, embedding_shape
127
+ )
128
+ return embeddings
129
+ return (
130
+ self.weights.index_select(0, positions.view(-1))
131
+ .view(bsz, seq_len, -1)
132
+ .detach()
133
+ )
134
+
135
+
136
+ class Transpose(nn.Identity):
137
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
138
+ return input.transpose(1, 2)
139
+
140
+
141
+ class VALLF(PreTrainedModel):
142
+ config_class = VallexConfig
143
+
144
+ def __init__(
145
+ self,
146
+ config: VallexConfig
147
+ ):
148
+ super().__init__(config)
149
+
150
+ self.ar_at_dict = Dictionary.load(self.config.ar_at_dict)
151
+ self.ar_st_dict = Dictionary.load(self.config.ar_st_dict)
152
+ self.nar_at_dict = Dictionary.load(self.config.nar_at_dict)
153
+ self.nar_st_dict = Dictionary.load(self.config.nar_st_dict)
154
+
155
+ self.ar_at_dict.tts_flag = self.ar_at_dict.add_symbol("<TTS>")
156
+ self.ar_st_dict.asr_flag = self.ar_st_dict.add_symbol("<ASR>")
157
+ self.ar_st_dict.mt_flag = self.ar_st_dict.add_symbol("<MT>")
158
+
159
+ self.padding_idx = self.ar_at_dict.pad()
160
+ self.config = config
161
+ d_model = self.config.n_dim
162
+ nar_scale_factor = self.config.nar_scale_factor
163
+ prepend_bos = self.config.prepend_bos
164
+
165
+ norm_first = self.config.norm_first
166
+ num_layers = self.config.n_layer
167
+ self.NUM_AUDIO_TOKENS = self.ar_at_dict.eos()
168
+
169
+ nar_d_model = int(d_model * nar_scale_factor)
170
+
171
+ self.ar_text_embedding = nn.Embedding(len(self.ar_st_dict), d_model, self.ar_st_dict.pad()) # W_x
172
+ if config.only_ar:
173
+ pass
174
+ else:
175
+ self.nar_text_embedding = nn.Embedding(len(self.nar_st_dict), d_model, self.nar_st_dict.pad())
176
+
177
+ # ID self.NUM_AUDIO_TOKENS -> PAD
178
+ # ID self.NUM_AUDIO_TOKENS + 1 -> BOS
179
+ self.ar_audio_prepend_bos = prepend_bos
180
+ self.ar_audio_embedding = EncodecDecoderLstm(
181
+ dictionary=self.ar_at_dict, emb_dim=d_model
182
+ )
183
+
184
+ self.ar_text_prenet = nn.Identity()
185
+ self.ar_audio_prenet = nn.Identity()
186
+
187
+ self.ar_text_position = SinusoidalPositionalEmbedding(
188
+ d_model,
189
+ padding_idx=self.ar_at_dict.pad(),
190
+ init_size=1024+self.ar_at_dict.pad()+1
191
+ )
192
+ self.ar_audio_position = SinusoidalPositionalEmbedding(
193
+ d_model,
194
+ padding_idx=self.ar_at_dict.pad(),
195
+ init_size=1024+self.ar_at_dict.pad()+1
196
+ )
197
+
198
+ self.ar_decoder = TransformerEncoder(
199
+ TransformerEncoderLayer(
200
+ d_model,
201
+ self.config.n_head,
202
+ dim_feedforward=d_model * 4,
203
+ dropout=0.1,
204
+ batch_first=True,
205
+ norm_first=norm_first,
206
+ ),
207
+ num_layers=num_layers,
208
+ norm=LayerNorm(d_model) if norm_first else None,
209
+ )
210
+ self.ar_predict_layer = nn.Linear(
211
+ d_model, len(self.ar_at_dict), bias=False
212
+ )
213
+
214
+ self.rng = random.Random(0)
215
+ self.num_heads = self.config.n_head
216
+ self.prefix_mode = self.config.prefix_mode
217
+ self.num_quantizers = self.config.num_quantizers
218
+
219
+ assert self.num_quantizers >= 1
220
+ if config.only_ar:
221
+ pass
222
+ else:
223
+ if self.num_quantizers > 1:
224
+ self.nar_audio_embeddings = NATEncodecDecoderLstm(
225
+ codecs=[0, 1, 2, 3, 4, 5, 6, 7], dictionary=self.nar_at_dict, emb_dim=d_model
226
+ ) # W_a
227
+
228
+ self.nar_text_prenet = nn.Identity()
229
+ self.nar_audio_prenet = nn.Identity()
230
+
231
+ self.nar_text_position = SinusoidalPositionalEmbedding(
232
+ d_model,
233
+ padding_idx=self.nar_at_dict.pad(),
234
+ init_size=1024+self.nar_at_dict.pad()+1
235
+ )
236
+ self.nar_audio_position = SinusoidalPositionalEmbedding(
237
+ d_model,
238
+ padding_idx=self.nar_at_dict.pad(),
239
+ init_size=1024+self.nar_at_dict.pad()+1
240
+ )
241
+
242
+ self.nar_decoder = TransformerEncoder(
243
+ TransformerEncoderLayer(
244
+ nar_d_model,
245
+ int(self.num_heads * nar_scale_factor),
246
+ dim_feedforward=nar_d_model * 4,
247
+ dropout=0.1,
248
+ batch_first=True,
249
+ norm_first=norm_first,
250
+ adaptive_layer_norm=True,
251
+ ),
252
+ num_layers=int(num_layers * nar_scale_factor),
253
+ norm=nn.LayerNorm(nar_d_model)
254
+ if norm_first
255
+ else None,
256
+ )
257
+ self.nar_predict_layers = nn.ModuleList(
258
+ [
259
+ nn.Linear(nar_d_model, len(self.nar_at_dict), bias=False)
260
+ for i in range(self.num_quantizers)
261
+ ]
262
+ )
263
+ self.nar_stage_embeddings = None
264
+
265
+ def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
266
+ assert stage > 0
267
+ if stage == 1:
268
+ for name, param in self.named_parameters():
269
+ if name.startswith("ar_"):
270
+ print(f" AR parameter: {name}")
271
+ yield param
272
+
273
+ if stage == 2:
274
+ for name, param in self.named_parameters():
275
+ if name.startswith("nar_"):
276
+ print(f"NAR parameter: {name}")
277
+ yield param
278
+
279
+ def stage_named_parameters(
280
+ self, stage: int = 1
281
+ ) -> Iterator[Tuple[str, nn.Parameter]]:
282
+ assert stage > 0
283
+ if stage == 1:
284
+ for pair in self.named_parameters():
285
+ if pair[0].startswith("ar_"):
286
+ yield pair
287
+
288
+ if stage == 2:
289
+ for pair in self.named_parameters():
290
+ if pair[0].startswith("nar_"):
291
+ yield pair
292
+
293
+ def pad_y_eos(self, y, y_mask_int, eos_id):
294
+ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
295
+ y_mask_int, (0, 1), value=1
296
+ )
297
+ # inputs, targets
298
+ if self.ar_audio_prepend_bos:
299
+ return (
300
+ F.pad(targets[:, :-1], (1, 0), value=self.NUM_AUDIO_TOKENS + 1),
301
+ targets,
302
+ )
303
+
304
+ return targets[:, :-1], targets[:, 1:]
305
+
306
+ class VALLE(VALLF):
307
+ config_class = VallexConfig
308
+
309
+ def __init__(
310
+ self,
311
+ config: VallexConfig,
312
+ **kwargs,
313
+ ):
314
+ super(VALLE, self).__init__(
315
+ config,
316
+ **kwargs,
317
+ )
318
+ print(config)
319
+ self.config = config
320
+ d_model = self.config.n_dim
321
+ self.eps = config.eps
322
+
323
+ self.language_ID = {
324
+ 'en': 0,
325
+ 'zh': 1,
326
+ }
327
+ self.ar_language_embedding = nn.Embedding(3, d_model, padding_idx=2)
328
+ self.nar_language_embedding = nn.Embedding(3, d_model, padding_idx=2)
329
+ self.embed_scale = 32.0
330
+
331
+ def forward(
332
+ self,
333
+ zh,
334
+ en
335
+ ):
336
+ """
337
+ "zh": {
338
+ "st_tokens": zh_st,
339
+ "at_tokens_wbos": zh_prev_at,
340
+ "at_tokens_tgt": zh_tgt_at,
341
+ "self_atten_mask": zh_self_atten_mask,
342
+ "padding_mask": zh_padding_mask,
343
+ "langid": zh_id.long()
344
+ },
345
+ "en": {
346
+ "st_tokens": en_st,
347
+ "at_tokens_wbos": en_prev_at,
348
+ "at_tokens_tgt": en_tgt_at,
349
+ "self_atten_mask": en_self_atten_mask,
350
+ "padding_mask": en_padding_mask,
351
+ "langid": en_id.long()
352
+ }
353
+ """
354
+ flag = (np.random.randint(low=0, high=1000) % 2 == 0) # zh or en
355
+ if flag:
356
+ data = zh
357
+ else:
358
+ data = en
359
+
360
+ st_tokens = data["st_tokens"]
361
+ at_tokens_wbos = data["at_tokens_wbos"]
362
+ at_tokens_tgt = data["at_tokens_tgt"]
363
+ self_atten_mask = data["self_atten_mask"]
364
+ padding_mask = data["padding_mask"]
365
+ langid = data["langid"]
366
+
367
+ st_len = st_tokens.size(1)
368
+ st_emb = self.embed_scale * self.ar_text_embedding(st_tokens)
369
+ src_lang_emb = self.embed_scale * self.ar_language_embedding(langid)
370
+ st_emb += src_lang_emb
371
+ st_pos = self.ar_text_position(st_tokens)
372
+ st_emb += st_pos
373
+
374
+ at_emb, _ = self.ar_audio_embedding(at_tokens_wbos, None)
375
+ at_emb = self.embed_scale * at_emb
376
+ tgt_lang_emb = self.embed_scale * self.ar_language_embedding(langid)
377
+ at_emb += tgt_lang_emb
378
+ at_pos = self.ar_audio_position(at_tokens_wbos)
379
+ at_emb += at_pos
380
+
381
+ x = torch.concat([st_emb, at_emb], dim=1)
382
+
383
+ x = self.ar_decoder(
384
+ x,
385
+ mask=self_atten_mask,
386
+ src_key_padding_mask=padding_mask
387
+ )
388
+ x = self.ar_predict_layer(x)
389
+ x = x[:, st_len:, :]
390
+ loss, nll_loss, lprob, right_rate = self.calculate_loss(
391
+ x, at_tokens_tgt
392
+ )
393
+ return ModelOutput(logits=lprob, loss=loss, acc=right_rate), right_rate
394
+
395
+ def calculate_loss(self, encoder_out, target, reduce=True, scale=1.0, prob_mask=None, acc=True):
396
+ lprob = self.get_normalized_probs(encoder_out, log_probs=True)
397
+ with torch.no_grad():
398
+ mask = target.ne(self.padding_idx)
399
+ n_correct = torch.sum(
400
+ lprob.argmax(-1).masked_select(mask).eq(target.masked_select(mask))
401
+ )
402
+ total = torch.sum(mask)
403
+ right_rate = n_correct * 100.0 / total
404
+
405
+ lprob, target = lprob.view(-1, lprob.size(-1)), target.view(-1)
406
+ loss, nll_loss = label_smoothed_nll_loss(
407
+ lprob,
408
+ target,
409
+ self.eps,
410
+ ignore_index=self.padding_idx,
411
+ reduce=reduce,
412
+ scale=scale,
413
+ prob_mask=prob_mask
414
+ )
415
+
416
+ return loss, nll_loss, lprob, right_rate
417
+
418
+ def get_normalized_probs(self, encoder_out, log_probs, sample=None):
419
+ if torch.is_tensor(encoder_out):
420
+ logits = encoder_out.float()
421
+ if log_probs:
422
+ return F.log_softmax(logits, dim=-1)
423
+ else:
424
+ return F.softmax(logits, dim=-1)
425
+
426
+
427
+ def inference_24L(
428
+ self,
429
+ x: torch.Tensor,
430
+ x_lens: torch.Tensor,
431
+ y: torch.Tensor,
432
+ enroll_x_lens: torch.Tensor,
433
+ top_k: int = -100,
434
+ temperature: float = 1.0,
435
+ prompt_language: str = None,
436
+ text_language: str = None,
437
+ best_of: int = 1,
438
+ length_penalty: float = 1.0,
439
+ return_worst: bool = False,
440
+ at_eos: int = -1
441
+ ) -> torch.Tensor:
442
+ assert x.ndim == 2, x.shape
443
+ assert x_lens.ndim == 1, x_lens.shape
444
+ assert y.ndim == 3, y.shape
445
+ assert y.shape[0] == 1, y.shape
446
+
447
+ assert torch.all(x_lens > 0)
448
+ self.NUM_AUDIO_TOKENS = at_eos
449
+ text = x
450
+ x = self.embed_scale * self.ar_text_embedding(text)
451
+ prompt_language_id = prompt_language.to(x.device)
452
+ text_language_id = text_language.to(x.device)
453
+ src_lang_emb = self.embed_scale * self.ar_language_embedding(prompt_language_id)
454
+ tgt_lang_emb = self.embed_scale * self.ar_language_embedding(text_language_id)
455
+ x[:, :enroll_x_lens, :] += src_lang_emb
456
+ x[:, enroll_x_lens:, :] += tgt_lang_emb
457
+ x = self.ar_text_prenet(x)
458
+ x_pos = self.ar_text_position(text)
459
+
460
+ text_len = x_lens.max()
461
+ prompts = y
462
+ prefix_len = y.shape[1]
463
+
464
+ # AR Decoder
465
+ # TODO: manage decoder steps to avoid repetitive computation
466
+ y = prompts[..., 0]
467
+ if self.ar_audio_prepend_bos:
468
+ y = F.pad(y, (1, 0), value=self.ar_at_dict.tts_flag)
469
+
470
+ x_len = x_lens.max()
471
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
472
+
473
+ kv_cache = None
474
+ use_kv_caching = True
475
+
476
+ sum_logprobs = torch.zeros(best_of, device=y.device) # implement batch decoding here
477
+ x = x.repeat(best_of, 1, 1)
478
+ y = y.repeat(best_of, 1)
479
+ lstm_h = None
480
+ first_ar = True
481
+ while True:
482
+ if first_ar:
483
+ y_emb, lstm_h = self.ar_audio_embedding(y, lstm_h)
484
+ y_emb = y_emb * self.embed_scale
485
+ y_emb = self.ar_audio_prenet(y_emb)
486
+ y_pos = self.ar_audio_position(y)
487
+ y_emb[:, :prefix_len] = y_emb[:, :prefix_len] + src_lang_emb
488
+ y_emb[:, prefix_len:] = y_emb[:, prefix_len:] + tgt_lang_emb
489
+ xy_pos_token = torch.concat([x_pos+x, y_pos+y_emb], dim=1)
490
+ first_ar = False
491
+ else:
492
+ y_emb_cur, lstm_h = self.ar_audio_embedding(y[:, -1:], lstm_h)
493
+ y_emb_cur = y_emb_cur * self.embed_scale
494
+ y_emb_cur = self.ar_audio_prenet(y_emb_cur)
495
+ y_pos_cur = self.ar_audio_position(y)[:, -1:]
496
+ y_emb_cur = y_emb_cur + src_lang_emb
497
+ y_emb_cur = y_emb_cur + tgt_lang_emb
498
+ xy_pos_token = torch.concat([xy_pos_token, y_pos_cur+y_emb_cur], dim=1)
499
+ # print(xy_pos_token.size())
500
+
501
+ y_len = y.shape[1]
502
+ x_attn_mask_pad = F.pad(
503
+ x_attn_mask,
504
+ (0, y_len),
505
+ value=True,
506
+ )
507
+ y_attn_mask = F.pad(
508
+ torch.triu(
509
+ torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
510
+ ),
511
+ (x_len, 0),
512
+ value=False,
513
+ )
514
+ xy_attn_mask = torch.concat(
515
+ [x_attn_mask_pad, y_attn_mask], dim=0
516
+ ).to(y.device)
517
+
518
+
519
+ if use_kv_caching and kv_cache is not None:
520
+ xy_pos = xy_pos_token[:, [-1]]
521
+ xy_attn_mask = xy_attn_mask[:, [-1]]
522
+ else:
523
+ xy_pos = xy_pos_token
524
+
525
+ xy_dec, kv_cache = self.ar_decoder.infer(
526
+ xy_pos,
527
+ mask=xy_attn_mask,
528
+ past_kv=kv_cache,
529
+ use_cache=use_kv_caching,
530
+ )
531
+
532
+ logits = self.ar_predict_layer(xy_dec[:, -1])
533
+ samples, current_logprobs = topk_sampling(
534
+ logits, top_k=top_k, top_p=1, temperature=temperature
535
+ )
536
+ sum_logprobs += current_logprobs * (y[:, -1] != self.NUM_AUDIO_TOKENS)
537
+ samples[y[:, -1] == self.NUM_AUDIO_TOKENS] = self.NUM_AUDIO_TOKENS
538
+ completed = (samples[:, -1] == self.NUM_AUDIO_TOKENS).all()
539
+ if (
540
+ completed
541
+ or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 32
542
+ ):
543
+ if prompts.shape[1] == y.shape[1]:
544
+ raise SyntaxError(
545
+ "well trained model shouldn't reach here."
546
+ )
547
+ lengths = torch.sum(y != self.NUM_AUDIO_TOKENS, dim=1)
548
+ avg_logprobs = sum_logprobs / lengths ** length_penalty
549
+ # choose the best beam according to sum_logprobs
550
+ best_beam = y[torch.argmax(avg_logprobs), :]
551
+ worst_beam = y[torch.argmin(avg_logprobs), :]
552
+ # strip all eos tokens
553
+ best_beam = best_beam[best_beam != self.NUM_AUDIO_TOKENS]
554
+ worst_beam = worst_beam[worst_beam != self.NUM_AUDIO_TOKENS]
555
+ if return_worst:
556
+ y = worst_beam.unsqueeze(0)
557
+ else:
558
+ y = best_beam.unsqueeze(0)
559
+ print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
560
+ break
561
+
562
+ y = torch.concat([y, samples], dim=1)
563
+
564
+ codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
565
+ if self.num_quantizers == 1:
566
+ return torch.stack(codes, dim=-1)
567
+
568
+ if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
569
+ enrolled_len = enroll_x_lens.max().item()
570
+ # SOS + Synthesis Text + EOS
571
+ text = torch.concat(
572
+ [
573
+ text[:, :1],
574
+ text[:, enrolled_len - 1 :],
575
+ ],
576
+ dim=1,
577
+ )
578
+ text_len = text_len - (enrolled_len - 2)
579
+ assert text.shape[0] == 1
580
+
581
+ x = self.embed_scale * self.nar_text_embedding(text)
582
+ # Add language embedding
583
+ prompt_language_id = prompt_language.to(x.device)
584
+ text_language_id = text_language.to(x.device)
585
+ src_lang_emb = self.embed_scale * self.nar_language_embedding(prompt_language_id)
586
+ tgt_lang_emb = self.embed_scale * self.nar_language_embedding(text_language_id)
587
+ x[:, :enroll_x_lens, :] += src_lang_emb
588
+ x[:, enroll_x_lens:, :] += tgt_lang_emb
589
+ x = self.nar_text_prenet(x)
590
+ x_pos = self.nar_text_position(text)
591
+
592
+ if self.prefix_mode == 0:
593
+ for i, predict_layer in enumerate(
594
+ self.nar_predict_layers
595
+ ):
596
+ y_pos = self.nar_audio_prenet(y_emb)
597
+ y_pos = self.nar_audio_position(y_pos)
598
+ xy_pos = torch.concat([x, y_pos], dim=1)
599
+
600
+ xy_dec, _ = self.nar_decoder(
601
+ (xy_pos, self.nar_stage_embeddings[i].weight)
602
+ )
603
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
604
+
605
+ samples = torch.argmax(logits, dim=-1)
606
+ codes.append(samples)
607
+
608
+ if i < self.num_quantizers - 2:
609
+ y_emb[:, :prefix_len] += self.embed_scale * self.nar_audio_embeddings(
610
+ prompts[..., i + 1]
611
+ )[0]
612
+ y_emb[:, prefix_len:] += self.embed_scale * self.nar_audio_embeddings(samples)[0]
613
+ else:
614
+ y_pos = self.nar_audio_position(y[:, int(self.ar_audio_prepend_bos):])
615
+
616
+ ref_at_emb = self.embed_scale * self.nar_audio_embeddings(prompts)[0] + src_lang_emb
617
+ est_at = y[:, prefix_len+int(self.ar_audio_prepend_bos):].unsqueeze(-1)
618
+ #
619
+ for i in range(1, 8):
620
+ y_emb, _ = self.nar_audio_embeddings(est_at)
621
+ y_emb = self.embed_scale * y_emb + tgt_lang_emb
622
+
623
+ y_emb = torch.concat([ref_at_emb, y_emb], dim=1)
624
+ xy_pos = torch.concat([x+x_pos, y_emb+y_pos], dim=1)
625
+
626
+ xy_dec = self.nar_decoder(
627
+ xy_pos
628
+ )
629
+ logits = self.nar_predict_layers[i-1](xy_dec[:, text_len + prefix_len :])
630
+ # print(logits.size(), xy_pos.size(), xy_dec.size())
631
+ samples = torch.argmax(logits, dim=-1)
632
+ est_at = torch.concat([est_at, samples.unsqueeze(-1)], dim=-1)
633
+ codes.append(samples)
634
+
635
+ assert len(codes) == self.num_quantizers
636
+ return torch.stack(codes, dim=-1)
637
+
638
+ def top_k_top_p_filtering(
639
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
640
+ ):
641
+ if top_k > 0:
642
+ top_k = min(
643
+ max(top_k, min_tokens_to_keep), logits.size(-1)
644
+ ) # Safety check
645
+ # Remove all tokens with a probability less than the last token of the top-k
646
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
647
+ logits[indices_to_remove] = filter_value
648
+
649
+ if top_p < 1.0:
650
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
651
+ cumulative_probs = torch.cumsum(
652
+ F.softmax(sorted_logits, dim=-1), dim=-1
653
+ )
654
+
655
+ # Remove tokens with cumulative probability above the threshold (tokens with 0 probability are kept)
656
+ sorted_indices_to_remove = cumulative_probs > top_p
657
+ if min_tokens_to_keep > 1:
658
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
659
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
660
+ # Shift the indices to the right to keep also the first token above the threshold
661
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
662
+ ..., :-1
663
+ ].clone()
664
+ sorted_indices_to_remove[..., 0] = 0
665
+
666
+ # scatter sorted tensors to original indexing
667
+ indices_to_remove = sorted_indices_to_remove.scatter(
668
+ 1, sorted_indices, sorted_indices_to_remove
669
+ )
670
+ logits[indices_to_remove] = filter_value
671
+ return logits
672
+
673
+
674
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
675
+ if temperature != 1.0:
676
+ logits = logits / temperature
677
+ # Top-p/top-k filtering
678
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
679
+ # Sample
680
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
681
+ logprobs = F.log_softmax(logits.float(), dim=-1)
682
+ current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
683
+ return token, current_logprobs
684
+
685
+ class SLSTM(nn.Module):
686
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional=False):
687
+ super().__init__()
688
+ self.skip = skip
689
+ self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional)
690
+ if bidirectional:
691
+ self.out_fc = nn.Linear(dimension*2, dimension)
692
+ else:
693
+ self.out_fc = None
694
+
695
+ def forward(self, x, hidden=None):
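+ # x: (B, dimension, T) -> run the LSTM over time as (T, B, dimension) -> return (B, dimension, T)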
696
+ x = x.permute(2, 0, 1)
697
+ y, hidden = self.lstm(x, hidden)
698
+ if self.out_fc is not None:
699
+ y = self.out_fc(y)
700
+ if self.skip:
701
+ y = y + x
702
+ y = y.permute(1, 2, 0)
703
+ return y, hidden
704
+
705
+ class EncodecDecoderLstm(nn.Module):
706
+ def __init__(self, dictionary, emb_dim,
707
+ out_dim=None,
708
+ num_layers=3, lstm_skip=True, lstm_bidire=False,
709
+ activation_param={'alpha': 1.0}, **kwargs):
710
+ super().__init__()
711
+
712
+ # Identity()
713
+ if out_dim is None:
714
+ out_dim = emb_dim
715
+ self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire)
716
+ self.elu = nn.ELU(**activation_param)
717
+ self.embedding_dim = emb_dim
718
+ self.padding_idx = dictionary.pad()
719
+ self.emb = nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index)
720
+
721
+ def forward(self, x, hidden=None):
722
+ """
723
+ Args:
724
+ x (torch.LongTensor): audio token indices of shape (B, T)
725
+ """
726
+ # print(x.size())
727
+ quantized_out = self.emb(x)
728
+ out, hidden = self.slstm(quantized_out.permute(0,2,1), hidden)
729
+ out = self.elu(out)
730
+ return out.permute(0,2,1), hidden
731
+
732
+ class NATEncodecDecoderLstm(nn.Module):
733
+ def __init__(self, codecs, dictionary, emb_dim,
734
+ out_dim=None,
735
+ num_layers=3, lstm_skip=True, lstm_bidire=False,
736
+ activation_param={'alpha': 1.0}, **kwargs):
737
+ super().__init__()
738
+
739
+ # Identity()
740
+ if out_dim is None:
741
+ out_dim = emb_dim
742
+ self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire)
743
+ self.elu = nn.ELU(**activation_param)
744
+ self.codecs = codecs
745
+ self.embedding_dim = emb_dim
746
+ self.padding_idx = dictionary.pad()
747
+ self.emb_list = nn.ModuleList(
748
+ [nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index) for i in range(len(self.codecs))]
749
+ )
750
+
751
+ def forward(self, x, hidden=None):
752
+ """
753
+ Args:
754
+ x (torch.LongTensor): codec token indices of shape (B, T) or (B, T, num_codecs)
755
+ """
756
+ if len(x.size()) == 2:
757
+ x = x.unsqueeze(-1)
758
+
759
+ if x.size(2) != len(self.codecs) and x.size(1) == len(self.codecs):
760
+ x = x.permute(0, 2, 1)
761
+
762
+ quantized_out = 0
763
+ for i in range(x.size(2)):
764
+ quantized = self.emb_list[i](x[: , :, i])
765
+ quantized_out = quantized_out + quantized
766
+ # quantized_out = quantized_out / len(self.codecs)
767
+
768
+ out, hidden = self.slstm(quantized_out.permute(0,2,1), hidden)
769
+ out = self.elu(out)
770
+ return out.permute(0,2,1), hidden
771
+
772
+ AutoModel.register(VallexConfig, VALLE)
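Editor's note: the AR decoding loop in `inference_24L` draws one audio token per step with `topk_sampling`, which applies temperature scaling and top-k/top-p filtering before `torch.multinomial`. A minimal sketch on dummy logits (shapes are illustrative assumptions):

    import torch

    logits = torch.randn(2, 1030)   # (batch, audio-token vocabulary), dummy values
    token, logprob = topk_sampling(logits, top_k=50, top_p=1.0, temperature=1.0)
    # token: (2, 1) sampled ids; logprob: (2,) log-probability of each sampled id,
    # which inference_24L accumulates into sum_logprobs for length-penalised reranking.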
slam_llm/models/wavlm/WavLM.py ADDED
@@ -0,0 +1,743 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import logging
12
+ from typing import List, Optional, Tuple
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.nn import LayerNorm
20
+ from .modules import (
21
+ Fp32GroupNorm,
22
+ Fp32LayerNorm,
23
+ GradMultiply,
24
+ MultiheadAttention,
25
+ SamePad,
26
+ init_bert_params,
27
+ get_activation_fn,
28
+ TransposeLast,
29
+ GLU_Linear,
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def compute_mask_indices(
36
+ shape: Tuple[int, int],
37
+ padding_mask: Optional[torch.Tensor],
38
+ mask_prob: float,
39
+ mask_length: int,
40
+ mask_type: str = "static",
41
+ mask_other: float = 0.0,
42
+ min_masks: int = 0,
43
+ no_overlap: bool = False,
44
+ min_space: int = 0,
45
+ ) -> np.ndarray:
46
+ """
47
+ Computes random mask spans for a given shape
48
+
49
+ Args:
50
+ shape: the shape for which to compute masks.
51
+ should be of size 2 where first element is batch size and 2nd is timesteps
52
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
53
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
54
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
55
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
56
+ mask_type: how to compute mask lengths
57
+ static = fixed size
58
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
59
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
60
+ poisson = sample from poisson distribution with lambda = mask length
61
+ min_masks: minimum number of masked spans
62
+ no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
63
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
64
+ """
65
+
66
+ bsz, all_sz = shape
67
+ mask = np.full((bsz, all_sz), False)
68
+
69
+ all_num_mask = int(
70
+ # add a random number for probabilistic rounding
71
+ mask_prob * all_sz / float(mask_length)
72
+ + np.random.rand()
73
+ )
74
+
75
+ all_num_mask = max(min_masks, all_num_mask)
76
+
77
+ mask_idcs = []
78
+ for i in range(bsz):
79
+ if padding_mask is not None:
80
+ sz = all_sz - padding_mask[i].long().sum().item()
81
+ num_mask = int(
82
+ # add a random number for probabilistic rounding
83
+ mask_prob * sz / float(mask_length)
84
+ + np.random.rand()
85
+ )
86
+ num_mask = max(min_masks, num_mask)
87
+ else:
88
+ sz = all_sz
89
+ num_mask = all_num_mask
90
+
91
+ if mask_type == "static":
92
+ lengths = np.full(num_mask, mask_length)
93
+ elif mask_type == "uniform":
94
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
95
+ elif mask_type == "normal":
96
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
97
+ lengths = [max(1, int(round(x))) for x in lengths]
98
+ elif mask_type == "poisson":
99
+ lengths = np.random.poisson(mask_length, size=num_mask)
100
+ lengths = [int(round(x)) for x in lengths]
101
+ else:
102
+ raise Exception("unknown mask selection " + mask_type)
103
+
104
+ if sum(lengths) == 0:
105
+ lengths[0] = min(mask_length, sz - 1)
106
+
107
+ if no_overlap:
108
+ mask_idc = []
109
+
110
+ def arrange(s, e, length, keep_length):
111
+ span_start = np.random.randint(s, e - length)
112
+ mask_idc.extend(span_start + i for i in range(length))
113
+
114
+ new_parts = []
115
+ if span_start - s - min_space >= keep_length:
116
+ new_parts.append((s, span_start - min_space + 1))
117
+ if e - span_start - keep_length - min_space > keep_length:
118
+ new_parts.append((span_start + length + min_space, e))
119
+ return new_parts
120
+
121
+ parts = [(0, sz)]
122
+ min_length = min(lengths)
123
+ for length in sorted(lengths, reverse=True):
124
+ lens = np.fromiter(
125
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
126
+ int,  # np.int was removed in recent NumPy; the builtin int is equivalent here
127
+ )
128
+ l_sum = np.sum(lens)
129
+ if l_sum == 0:
130
+ break
131
+ probs = lens / np.sum(lens)
132
+ c = np.random.choice(len(parts), p=probs)
133
+ s, e = parts.pop(c)
134
+ parts.extend(arrange(s, e, length, min_length))
135
+ mask_idc = np.asarray(mask_idc)
136
+ else:
137
+ min_len = min(lengths)
138
+ if sz - min_len <= num_mask:
139
+ min_len = sz - num_mask - 1
140
+
141
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
142
+
143
+ mask_idc = np.asarray(
144
+ [
145
+ mask_idc[j] + offset
146
+ for j in range(len(mask_idc))
147
+ for offset in range(lengths[j])
148
+ ]
149
+ )
150
+
151
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
152
+
153
+ min_len = min([len(m) for m in mask_idcs])
154
+ for i, mask_idc in enumerate(mask_idcs):
155
+ if len(mask_idc) > min_len:
156
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
157
+ mask[i, mask_idc] = True
158
+
159
+ return mask
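+ # Editorial usage sketch (not part of the upstream file): WavLM.apply_mask below calls
+ # compute_mask_indices with a (batch, frames) shape; a standalone call looks roughly like
+ #   mask = compute_mask_indices((4, 400), None, mask_prob=0.65, mask_length=10, min_masks=2)
+ # which returns a (4, 400) boolean numpy array with about mask_prob * 400 frames masked
+ # per utterance (fewer when spans overlap).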
160
+
161
+
162
+ class WavLMConfig:
163
+ def __init__(self, cfg=None):
164
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
165
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
166
+
167
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
168
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
169
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
170
+ self.activation_fn: str = "gelu" # activation function to use
171
+
172
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
173
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
174
+ self.conv_bias: bool = False # include bias in conv encoder
175
+ self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
176
+
177
+ self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
178
+
179
+ # dropouts
180
+ self.dropout: float = 0.1 # dropout probability for the transformer
181
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
182
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
183
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
184
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
185
+ self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
186
+
187
+ # masking
188
+ self.mask_length: int = 10 # mask length
189
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
190
+ self.mask_selection: str = "static" # how to choose mask length
191
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
192
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
193
+ self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
194
+
195
+ # channel masking
196
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
197
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
198
+ self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
199
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
200
+ self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
201
+ self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
202
+
203
+ # positional embeddings
204
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
205
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
206
+
207
+ # relative position embedding
208
+ self.relative_position_embedding: bool = False # apply relative position embedding
209
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
210
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
211
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
212
+
213
+ if cfg is not None:
214
+ self.update(cfg)
215
+
216
+ def update(self, cfg: dict):
217
+ self.__dict__.update(cfg)
218
+
219
+
220
+ class WavLM(nn.Module):
221
+ def __init__(
222
+ self,
223
+ cfg: WavLMConfig,
224
+ ) -> None:
225
+ super().__init__()
226
+ logger.info(f"WavLM Config: {cfg.__dict__}")
227
+
228
+ self.cfg = cfg
229
+ feature_enc_layers = eval(cfg.conv_feature_layers)
230
+ self.embed = feature_enc_layers[-1][0]
231
+
232
+ self.feature_extractor = ConvFeatureExtractionModel(
233
+ conv_layers=feature_enc_layers,
234
+ dropout=0.0,
235
+ mode=cfg.extractor_mode,
236
+ conv_bias=cfg.conv_bias,
237
+ )
238
+
239
+ self.post_extract_proj = (
240
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
241
+ if self.embed != cfg.encoder_embed_dim
242
+ else None
243
+ )
244
+
245
+ self.mask_prob = cfg.mask_prob
246
+ self.mask_selection = cfg.mask_selection
247
+ self.mask_other = cfg.mask_other
248
+ self.mask_length = cfg.mask_length
249
+ self.no_mask_overlap = cfg.no_mask_overlap
250
+ self.mask_min_space = cfg.mask_min_space
251
+
252
+ self.mask_channel_prob = cfg.mask_channel_prob
253
+ self.mask_channel_selection = cfg.mask_channel_selection
254
+ self.mask_channel_other = cfg.mask_channel_other
255
+ self.mask_channel_length = cfg.mask_channel_length
256
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
257
+ self.mask_channel_min_space = cfg.mask_channel_min_space
258
+
259
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
260
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
261
+
262
+ self.feature_grad_mult = cfg.feature_grad_mult
263
+
264
+ self.mask_emb = nn.Parameter(
265
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
266
+ )
267
+
268
+ self.encoder = TransformerEncoder(cfg)
269
+ self.layer_norm = LayerNorm(self.embed)
270
+
271
+ def apply_mask(self, x, padding_mask):
272
+ B, T, C = x.shape
273
+ if self.mask_prob > 0:
274
+ mask_indices = compute_mask_indices(
275
+ (B, T),
276
+ padding_mask,
277
+ self.mask_prob,
278
+ self.mask_length,
279
+ self.mask_selection,
280
+ self.mask_other,
281
+ min_masks=2,
282
+ no_overlap=self.no_mask_overlap,
283
+ min_space=self.mask_min_space,
284
+ )
285
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
286
+ x[mask_indices] = self.mask_emb
287
+ else:
288
+ mask_indices = None
289
+
290
+ if self.mask_channel_prob > 0:
291
+ mask_channel_indices = compute_mask_indices(
292
+ (B, C),
293
+ None,
294
+ self.mask_channel_prob,
295
+ self.mask_channel_length,
296
+ self.mask_channel_selection,
297
+ self.mask_channel_other,
298
+ no_overlap=self.no_mask_channel_overlap,
299
+ min_space=self.mask_channel_min_space,
300
+ )
301
+ mask_channel_indices = (
302
+ torch.from_numpy(mask_channel_indices)
303
+ .to(x.device)
304
+ .unsqueeze(1)
305
+ .expand(-1, T, -1)
306
+ )
307
+ x[mask_channel_indices] = 0
308
+
309
+ return x, mask_indices
310
+
311
+ def forward_padding_mask(
312
+ self, features: torch.Tensor, padding_mask: torch.Tensor,
313
+ ) -> torch.Tensor:
314
+ extra = padding_mask.size(1) % features.size(1)
315
+ if extra > 0:
316
+ padding_mask = padding_mask[:, :-extra]
317
+ padding_mask = padding_mask.view(
318
+ padding_mask.size(0), features.size(1), -1
319
+ )
320
+ padding_mask = padding_mask.all(-1)
321
+ return padding_mask
322
+
323
+ def extract_features(
324
+ self,
325
+ source: torch.Tensor,
326
+ padding_mask: Optional[torch.Tensor] = None,
327
+ mask: bool = False,
328
+ ret_conv: bool = False,
329
+ output_layer: Optional[int] = None,
330
+ ret_layer_results: bool = False,
331
+ ):
332
+
333
+ if self.feature_grad_mult > 0:
334
+ features = self.feature_extractor(source)
335
+ if self.feature_grad_mult != 1.0:
336
+ features = GradMultiply.apply(features, self.feature_grad_mult)
337
+ else:
338
+ with torch.no_grad():
339
+ features = self.feature_extractor(source)
340
+
341
+ features = features.transpose(1, 2)
342
+ features = self.layer_norm(features)
343
+
344
+ if padding_mask is not None:
345
+ padding_mask = self.forward_padding_mask(features, padding_mask)
346
+
347
+ if self.post_extract_proj is not None:
348
+ features = self.post_extract_proj(features)
349
+
350
+ features = self.dropout_input(features)
351
+
352
+ if mask:
353
+ x, mask_indices = self.apply_mask(
354
+ features, padding_mask
355
+ )
356
+ else:
357
+ x = features
358
+
359
+ # feature: (B, T, D), float
360
+ # target: (B, T), long
361
+ # x: (B, T, D), float
362
+ # padding_mask: (B, T), bool
363
+ # mask_indices: (B, T), bool
364
+ x, layer_results = self.encoder(
365
+ x,
366
+ padding_mask=padding_mask,
367
+ layer=None if output_layer is None else output_layer - 1
368
+ )
369
+
370
+ res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
371
+
372
+ feature = res["features"] if ret_conv else res["x"]
373
+ if ret_layer_results:
374
+ feature = (feature, res["layer_results"])
375
+ return feature, res["padding_mask"]
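+ # Editorial usage sketch (not part of the upstream file). Official WavLM checkpoints are
+ # typically a dict with "cfg" and "model" entries, so loading usually looks like
+ #   ckpt = torch.load("WavLM-Base.pt", map_location="cpu")      # placeholder path
+ #   cfg = WavLMConfig(ckpt["cfg"])
+ #   model = WavLM(cfg); model.load_state_dict(ckpt["model"])
+ #   feats, _ = model.extract_features(torch.randn(1, 16000))    # (B, frames, encoder_embed_dim)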
376
+
377
+
378
+ class ConvFeatureExtractionModel(nn.Module):
379
+ def __init__(
380
+ self,
381
+ conv_layers: List[Tuple[int, int, int]],
382
+ dropout: float = 0.0,
383
+ mode: str = "default",
384
+ conv_bias: bool = False,
385
+ conv_type: str = "default"
386
+ ):
387
+ super().__init__()
388
+
389
+ assert mode in {"default", "layer_norm"}
390
+
391
+ def block(
392
+ n_in,
393
+ n_out,
394
+ k,
395
+ stride,
396
+ is_layer_norm=False,
397
+ is_group_norm=False,
398
+ conv_bias=False,
399
+ ):
400
+ def make_conv():
401
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
402
+ nn.init.kaiming_normal_(conv.weight)
403
+ return conv
404
+
405
+ assert (
406
+ is_layer_norm and is_group_norm
407
+ ) == False, "layer norm and group norm are exclusive"
408
+
409
+ if is_layer_norm:
410
+ return nn.Sequential(
411
+ make_conv(),
412
+ nn.Dropout(p=dropout),
413
+ nn.Sequential(
414
+ TransposeLast(),
415
+ Fp32LayerNorm(dim, elementwise_affine=True),
416
+ TransposeLast(),
417
+ ),
418
+ nn.GELU(),
419
+ )
420
+ elif is_group_norm:
421
+ return nn.Sequential(
422
+ make_conv(),
423
+ nn.Dropout(p=dropout),
424
+ Fp32GroupNorm(dim, dim, affine=True),
425
+ nn.GELU(),
426
+ )
427
+ else:
428
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
429
+
430
+ self.conv_type = conv_type
431
+ if self.conv_type == "default":
432
+ in_d = 1
433
+ self.conv_layers = nn.ModuleList()
434
+ for i, cl in enumerate(conv_layers):
435
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
436
+ (dim, k, stride) = cl
437
+
438
+ self.conv_layers.append(
439
+ block(
440
+ in_d,
441
+ dim,
442
+ k,
443
+ stride,
444
+ is_layer_norm=mode == "layer_norm",
445
+ is_group_norm=mode == "default" and i == 0,
446
+ conv_bias=conv_bias,
447
+ )
448
+ )
449
+ in_d = dim
450
+ elif self.conv_type == "conv2d":
451
+ in_d = 1
452
+ self.conv_layers = nn.ModuleList()
453
+ for i, cl in enumerate(conv_layers):
454
+ assert len(cl) == 3
455
+ (dim, k, stride) = cl
456
+
457
+ self.conv_layers.append(
458
+ torch.nn.Conv2d(in_d, dim, k, stride)
459
+ )
460
+ self.conv_layers.append(torch.nn.ReLU())
461
+ in_d = dim
462
+ elif self.conv_type == "custom":
463
+ in_d = 1
464
+ idim = 80
465
+ self.conv_layers = nn.ModuleList()
466
+ for i, cl in enumerate(conv_layers):
467
+ assert len(cl) == 3
468
+ (dim, k, stride) = cl
469
+ self.conv_layers.append(
470
+ torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
471
+ )
472
+ self.conv_layers.append(
473
+ torch.nn.LayerNorm([dim, idim])
474
+ )
475
+ self.conv_layers.append(torch.nn.ReLU())
476
+ in_d = dim
477
+ if (i + 1) % 2 == 0:
478
+ self.conv_layers.append(
479
+ torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
480
+ )
481
+ idim = int(math.ceil(idim / 2))
482
+ else:
483
+ pass
484
+
485
+ def forward(self, x, mask=None):
486
+
487
+ # BxT -> BxCxT
488
+ x = x.unsqueeze(1)
489
+ if self.conv_type == "custom":
490
+ for conv in self.conv_layers:
491
+ if isinstance(conv, nn.LayerNorm):
492
+ x = x.transpose(1, 2)
493
+ x = conv(x).transpose(1, 2)
494
+ else:
495
+ x = conv(x)
496
+ x = x.transpose(2, 3).contiguous()
497
+ x = x.view(x.size(0), -1, x.size(-1))
498
+ else:
499
+ for conv in self.conv_layers:
500
+ x = conv(x)
501
+ if self.conv_type == "conv2d":
502
+ b, c, t, f = x.size()
503
+ x = x.transpose(2, 3).contiguous().view(b, c * f, t)
504
+ return x
505
+
506
+
507
+ class TransformerEncoder(nn.Module):
508
+ def __init__(self, args):
509
+ super().__init__()
510
+
511
+ self.dropout = args.dropout
512
+ self.embedding_dim = args.encoder_embed_dim
513
+
514
+ self.pos_conv = nn.Conv1d(
515
+ self.embedding_dim,
516
+ self.embedding_dim,
517
+ kernel_size=args.conv_pos,
518
+ padding=args.conv_pos // 2,
519
+ groups=args.conv_pos_groups,
520
+ )
521
+ dropout = 0
522
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
523
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
524
+ nn.init.constant_(self.pos_conv.bias, 0)
525
+
526
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
527
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
528
+
529
+ if hasattr(args, "relative_position_embedding"):
530
+ self.relative_position_embedding = args.relative_position_embedding
531
+ self.num_buckets = args.num_buckets
532
+ self.max_distance = args.max_distance
533
+ else:
534
+ self.relative_position_embedding = False
535
+ self.num_buckets = 0
536
+ self.max_distance = 0
537
+
538
+ self.layers = nn.ModuleList(
539
+ [
540
+ TransformerSentenceEncoderLayer(
541
+ embedding_dim=self.embedding_dim,
542
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
543
+ num_attention_heads=args.encoder_attention_heads,
544
+ dropout=self.dropout,
545
+ attention_dropout=args.attention_dropout,
546
+ activation_dropout=args.activation_dropout,
547
+ activation_fn=args.activation_fn,
548
+ layer_norm_first=args.layer_norm_first,
549
+ has_relative_attention_bias=(self.relative_position_embedding and i == 0),
550
+ num_buckets=self.num_buckets,
551
+ max_distance=self.max_distance,
552
+ gru_rel_pos=args.gru_rel_pos,
553
+ )
554
+ for i in range(args.encoder_layers)
555
+ ]
556
+ )
557
+
558
+ self.layer_norm_first = args.layer_norm_first
559
+ self.layer_norm = LayerNorm(self.embedding_dim)
560
+ self.layerdrop = args.encoder_layerdrop
561
+
562
+ self.apply(init_bert_params)
563
+
564
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
565
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
566
+
567
+ if self.layer_norm_first and layer is None:
568
+ x = self.layer_norm(x)
569
+
570
+ return x, layer_results
571
+
572
+ def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
573
+
574
+ if padding_mask is not None:
575
+ x[padding_mask] = 0
576
+
577
+ x_conv = self.pos_conv(x.transpose(1, 2))
578
+ x_conv = x_conv.transpose(1, 2)
579
+ x = x + x_conv
580
+
581
+ if not self.layer_norm_first:
582
+ x = self.layer_norm(x)
583
+
584
+ x = F.dropout(x, p=self.dropout, training=self.training)
585
+
586
+ # B x T x C -> T x B x C
587
+ x = x.transpose(0, 1)
588
+
589
+ layer_results = []
590
+ z = None
591
+ if tgt_layer is not None:
592
+ layer_results.append((x, z))
593
+ r = None
594
+ pos_bias = None
595
+ for i, layer in enumerate(self.layers):
596
+ dropout_probability = np.random.random()
597
+ if not self.training or (dropout_probability > self.layerdrop):
598
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
599
+ self_attn_mask=streaming_mask, pos_bias=pos_bias)
600
+ if tgt_layer is not None:
601
+ layer_results.append((x, z))
602
+ if i == tgt_layer:
603
+ r = x
604
+ break
605
+
606
+ if r is not None:
607
+ x = r
608
+
609
+ # T x B x C -> B x T x C
610
+ x = x.transpose(0, 1)
611
+
612
+ return x, layer_results
613
+
614
+
615
+ class TransformerSentenceEncoderLayer(nn.Module):
616
+ """
617
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
618
+ models.
619
+ """
620
+
621
+ def __init__(
622
+ self,
623
+ embedding_dim: float = 768,
624
+ ffn_embedding_dim: float = 3072,
625
+ num_attention_heads: float = 8,
626
+ dropout: float = 0.1,
627
+ attention_dropout: float = 0.1,
628
+ activation_dropout: float = 0.1,
629
+ activation_fn: str = "relu",
630
+ layer_norm_first: bool = False,
631
+ has_relative_attention_bias: bool = False,
632
+ num_buckets: int = 0,
633
+ max_distance: int = 0,
634
+ rescale_init: bool = False,
635
+ gru_rel_pos: bool = False,
636
+ ) -> None:
637
+
638
+ super().__init__()
639
+ # Initialize parameters
640
+ self.embedding_dim = embedding_dim
641
+ self.dropout = dropout
642
+ self.activation_dropout = activation_dropout
643
+
644
+ # Initialize blocks
645
+ self.activation_name = activation_fn
646
+ self.activation_fn = get_activation_fn(activation_fn)
647
+ self.self_attn = MultiheadAttention(
648
+ self.embedding_dim,
649
+ num_attention_heads,
650
+ dropout=attention_dropout,
651
+ self_attention=True,
652
+ has_relative_attention_bias=has_relative_attention_bias,
653
+ num_buckets=num_buckets,
654
+ max_distance=max_distance,
655
+ rescale_init=rescale_init,
656
+ gru_rel_pos=gru_rel_pos,
657
+ )
658
+
659
+ self.dropout1 = nn.Dropout(dropout)
660
+ self.dropout2 = nn.Dropout(self.activation_dropout)
661
+ self.dropout3 = nn.Dropout(dropout)
662
+
663
+ self.layer_norm_first = layer_norm_first
664
+
665
+ # layer norm associated with the self attention layer
666
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
667
+
668
+ if self.activation_name == "glu":
669
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
670
+ else:
671
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
672
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
673
+
674
+ # layer norm associated with the position wise feed-forward NN
675
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
676
+
677
+ def forward(
678
+ self,
679
+ x: torch.Tensor,
680
+ self_attn_mask: torch.Tensor = None,
681
+ self_attn_padding_mask: torch.Tensor = None,
682
+ need_weights: bool = False,
683
+ pos_bias=None
684
+ ):
685
+ """
686
+ LayerNorm is applied either before or after the self-attention/ffn
687
+ modules similar to the original Transformer implementation.
688
+ """
689
+ residual = x
690
+
691
+ if self.layer_norm_first:
692
+ x = self.self_attn_layer_norm(x)
693
+ x, attn, pos_bias = self.self_attn(
694
+ query=x,
695
+ key=x,
696
+ value=x,
697
+ key_padding_mask=self_attn_padding_mask,
698
+ need_weights=False,
699
+ attn_mask=self_attn_mask,
700
+ position_bias=pos_bias
701
+ )
702
+ x = self.dropout1(x)
703
+ x = residual + x
704
+
705
+ residual = x
706
+ x = self.final_layer_norm(x)
707
+ if self.activation_name == "glu":
708
+ x = self.fc1(x)
709
+ else:
710
+ x = self.activation_fn(self.fc1(x))
711
+ x = self.dropout2(x)
712
+ x = self.fc2(x)
713
+ x = self.dropout3(x)
714
+ x = residual + x
715
+ else:
716
+ x, attn, pos_bias = self.self_attn(
717
+ query=x,
718
+ key=x,
719
+ value=x,
720
+ key_padding_mask=self_attn_padding_mask,
721
+ need_weights=need_weights,
722
+ attn_mask=self_attn_mask,
723
+ position_bias=pos_bias
724
+ )
725
+
726
+ x = self.dropout1(x)
727
+ x = residual + x
728
+
729
+ x = self.self_attn_layer_norm(x)
730
+
731
+ residual = x
732
+ if self.activation_name == "glu":
733
+ x = self.fc1(x)
734
+ else:
735
+ x = self.activation_fn(self.fc1(x))
736
+ x = self.dropout2(x)
737
+ x = self.fc2(x)
738
+ x = self.dropout3(x)
739
+ x = residual + x
740
+ x = self.final_layer_norm(x)
741
+
742
+ return x, attn, pos_bias
743
+
slam_llm/models/wavlm/modules.py ADDED
@@ -0,0 +1,827 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ from torch.nn import Parameter
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class TransposeLast(nn.Module):
20
+ def __init__(self, deconstruct_idx=None):
21
+ super().__init__()
22
+ self.deconstruct_idx = deconstruct_idx
23
+
24
+ def forward(self, x):
25
+ if self.deconstruct_idx is not None:
26
+ x = x[self.deconstruct_idx]
27
+ return x.transpose(-2, -1)
28
+
29
+
30
+ class Fp32LayerNorm(nn.LayerNorm):
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+
34
+ def forward(self, input):
35
+ output = F.layer_norm(
36
+ input.float(),
37
+ self.normalized_shape,
38
+ self.weight.float() if self.weight is not None else None,
39
+ self.bias.float() if self.bias is not None else None,
40
+ self.eps,
41
+ )
42
+ return output.type_as(input)
43
+
44
+
45
+ class Fp32GroupNorm(nn.GroupNorm):
46
+ def __init__(self, *args, **kwargs):
47
+ super().__init__(*args, **kwargs)
48
+
49
+ def forward(self, input):
50
+ output = F.group_norm(
51
+ input.float(),
52
+ self.num_groups,
53
+ self.weight.float() if self.weight is not None else None,
54
+ self.bias.float() if self.bias is not None else None,
55
+ self.eps,
56
+ )
57
+ return output.type_as(input)
58
+
59
+
60
+ class GradMultiply(torch.autograd.Function):
61
+ @staticmethod
62
+ def forward(ctx, x, scale):
63
+ ctx.scale = scale
64
+ res = x.new(x)
65
+ return res
66
+
67
+ @staticmethod
68
+ def backward(ctx, grad):
69
+ return grad * ctx.scale, None
70
+
71
+
72
+ class SamePad(nn.Module):
73
+ def __init__(self, kernel_size, causal=False):
74
+ super().__init__()
75
+ if causal:
76
+ self.remove = kernel_size - 1
77
+ else:
78
+ self.remove = 1 if kernel_size % 2 == 0 else 0
79
+
80
+ def forward(self, x):
81
+ if self.remove > 0:
82
+ x = x[:, :, : -self.remove]
83
+ return x
84
+
85
+
86
+ class Swish(nn.Module):
87
+ """Swish function
88
+ """
89
+
90
+ def __init__(self):
91
+ """Construct a Swish activation module."""
92
+ super(Swish, self).__init__()
93
+ self.act = torch.nn.Sigmoid()
94
+
95
+ def forward(self, x):
96
+ return x * self.act(x)
97
+
98
+
99
+ class GLU_Linear(nn.Module):
100
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
101
+ super(GLU_Linear, self).__init__()
102
+
103
+ self.glu_type = glu_type
104
+ self.output_dim = output_dim
105
+
106
+ if glu_type == "sigmoid":
107
+ self.glu_act = torch.nn.Sigmoid()
108
+ elif glu_type == "swish":
109
+ self.glu_act = Swish()
110
+ elif glu_type == "relu":
111
+ self.glu_act = torch.nn.ReLU()
112
+ elif glu_type == "gelu":
113
+ self.glu_act = torch.nn.GELU()
114
+
115
+ if bias_in_glu:
116
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
117
+ else:
118
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
119
+
120
+ def forward(self, x):
121
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
122
+ x = self.linear(x)
123
+
124
+ if self.glu_type == "bilinear":
125
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
126
+ else:
127
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
128
+
129
+ return x
130
+
131
+
132
+ def gelu_accurate(x):
133
+ if not hasattr(gelu_accurate, "_a"):
134
+ gelu_accurate._a = math.sqrt(2 / math.pi)
135
+ return (
136
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
137
+ )
138
+
139
+
140
+ def gelu(x: torch.Tensor) -> torch.Tensor:
141
+ return torch.nn.functional.gelu(x.float()).type_as(x)
142
+
143
+
144
+ def get_activation_fn(activation: str):
145
+ """Returns the activation function corresponding to `activation`"""
146
+
147
+ if activation == "relu":
148
+ return F.relu
149
+ elif activation == "gelu":
150
+ return gelu
151
+ elif activation == "gelu_fast":
152
+ warnings.warn(
153
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
154
+ )
155
+ return gelu_accurate
156
+ elif activation == "gelu_accurate":
157
+ return gelu_accurate
158
+ elif activation == "tanh":
159
+ return torch.tanh
160
+ elif activation == "linear":
161
+ return lambda x: x
162
+ elif activation == "glu":
163
+ return lambda x: x
164
+ else:
165
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
166
+
167
+
168
+ def init_bert_params(module):
169
+ """
170
+ Initialize the weights specific to the BERT Model.
171
+ This overrides the default initializations depending on the specified arguments.
172
+ 1. If normal_init_linear_weights is set then weights of linear
173
+ layer will be initialized using the normal distribution and
174
+ bias will be set to the specified value.
175
+ 2. If normal_init_embed_weights is set then weights of embedding
176
+ layer will be initialized using the normal distribution.
177
+ 3. If normal_init_proj_weights is set then weights of
178
+ in_project_weight for MultiHeadAttention initialized using
179
+ the normal distribution (to be validated).
180
+ """
181
+
182
+ def normal_(data):
183
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
184
+ # so that the RNG is consistent with and without FSDP
185
+ data.copy_(
186
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
187
+ )
188
+
189
+ if isinstance(module, nn.Linear):
190
+ normal_(module.weight.data)
191
+ if module.bias is not None:
192
+ module.bias.data.zero_()
193
+ if isinstance(module, nn.Embedding):
194
+ normal_(module.weight.data)
195
+ if module.padding_idx is not None:
196
+ module.weight.data[module.padding_idx].zero_()
197
+ if isinstance(module, MultiheadAttention):
198
+ normal_(module.q_proj.weight.data)
199
+ normal_(module.k_proj.weight.data)
200
+ normal_(module.v_proj.weight.data)
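A minimal sketch of how the initializer above is typically applied, via `nn.Module.apply` (the toy model is an illustrative assumption, not from this upload):

from torch import nn  # `nn` is already imported at the top of this file

toy = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 32))  # hypothetical toy model
toy.apply(init_bert_params)  # Embedding/Linear weights ~ N(0, 0.02); Linear biases zeroed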
201
+
202
+
203
+ def quant_noise(module, p, block_size):
204
+ """
205
+ Wraps modules and applies quantization noise to the weights for
206
+ subsequent quantization with Iterative Product Quantization as
207
+ described in "Training with Quantization Noise for Extreme Model Compression"
208
+
209
+ Args:
210
+ - module: nn.Module
211
+ - p: amount of Quantization Noise
212
+ - block_size: size of the blocks for subsequent quantization with iPQ
213
+
214
+ Remarks:
215
+ - Module weights must have the right sizes wrt the block size
216
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
217
+ - For more detail on how to quantize by blocks with convolutional weights,
218
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
219
+ - We implement the simplest form of noise here as stated in the paper
220
+ which consists in randomly dropping blocks
221
+ """
222
+
223
+ # if no quantization noise, don't register hook
224
+ if p <= 0:
225
+ return module
226
+
227
+ # supported modules
228
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
229
+
230
+ # test whether module.weight has the right sizes wrt block_size
231
+ is_conv = module.weight.ndim == 4
232
+
233
+ # 2D matrix
234
+ if not is_conv:
235
+ assert (
236
+ module.weight.size(1) % block_size == 0
237
+ ), "Input features must be a multiple of block sizes"
238
+
239
+ # 4D matrix
240
+ else:
241
+ # 1x1 convolutions
242
+ if module.kernel_size == (1, 1):
243
+ assert (
244
+ module.in_channels % block_size == 0
245
+ ), "Input channels must be a multiple of block sizes"
246
+ # regular convolutions
247
+ else:
248
+ k = module.kernel_size[0] * module.kernel_size[1]
249
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
250
+
251
+ def _forward_pre_hook(mod, input):
252
+ # no noise for evaluation
253
+ if mod.training:
254
+ if not is_conv:
255
+ # gather weight and sizes
256
+ weight = mod.weight
257
+ in_features = weight.size(1)
258
+ out_features = weight.size(0)
259
+
260
+ # split weight matrix into blocks and randomly drop selected blocks
261
+ mask = torch.zeros(
262
+ in_features // block_size * out_features, device=weight.device
263
+ )
264
+ mask.bernoulli_(p)
265
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
266
+
267
+ else:
268
+ # gather weight and sizes
269
+ weight = mod.weight
270
+ in_channels = mod.in_channels
271
+ out_channels = mod.out_channels
272
+
273
+ # split weight matrix into blocks and randomly drop selected blocks
274
+ if mod.kernel_size == (1, 1):
275
+ mask = torch.zeros(
276
+ int(in_channels // block_size * out_channels),
277
+ device=weight.device,
278
+ )
279
+ mask.bernoulli_(p)
280
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
281
+ else:
282
+ mask = torch.zeros(
283
+ weight.size(0), weight.size(1), device=weight.device
284
+ )
285
+ mask.bernoulli_(p)
286
+ mask = (
287
+ mask.unsqueeze(2)
288
+ .unsqueeze(3)
289
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
290
+ )
291
+
292
+ # scale weights and apply mask
293
+ mask = mask.to(
294
+ torch.bool
295
+ ) # x.bool() is not currently supported in TorchScript
296
+ s = 1 / (1 - p)
297
+ mod.weight.data = s * weight.masked_fill(mask, 0)
298
+
299
+ module.register_forward_pre_hook(_forward_pre_hook)
300
+ return module
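A minimal usage sketch for `quant_noise` as defined above (layer sizes are illustrative assumptions): `block_size` must divide the input features, and the noise is only applied while the module is in training mode.

import torch
from torch import nn

layer = quant_noise(nn.Linear(512, 512), p=0.1, block_size=8)  # 512 % 8 == 0
layer.train()                   # the forward pre-hook is a no-op in eval mode
y = layer(torch.randn(4, 512))  # random weight blocks are zeroed, the rest rescaled by 1 / (1 - p)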
301
+
302
+
303
+ class MultiheadAttention(nn.Module):
304
+ """Multi-headed attention.
305
+
306
+ See "Attention Is All You Need" for more details.
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ embed_dim,
312
+ num_heads,
313
+ kdim=None,
314
+ vdim=None,
315
+ dropout=0.0,
316
+ bias=True,
317
+ add_bias_kv=False,
318
+ add_zero_attn=False,
319
+ self_attention=False,
320
+ encoder_decoder_attention=False,
321
+ q_noise=0.0,
322
+ qn_block_size=8,
323
+ has_relative_attention_bias=False,
324
+ num_buckets=32,
325
+ max_distance=128,
326
+ gru_rel_pos=False,
327
+ rescale_init=False,
328
+ ):
329
+ super().__init__()
330
+ self.embed_dim = embed_dim
331
+ self.kdim = kdim if kdim is not None else embed_dim
332
+ self.vdim = vdim if vdim is not None else embed_dim
333
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
+
335
+ self.num_heads = num_heads
336
+ self.dropout_module = nn.Dropout(dropout)
337
+
338
+ self.has_relative_attention_bias = has_relative_attention_bias
339
+ self.num_buckets = num_buckets
340
+ self.max_distance = max_distance
341
+ if self.has_relative_attention_bias:
342
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
343
+
344
+ self.head_dim = embed_dim // num_heads
345
+ self.q_head_dim = self.head_dim
346
+ self.k_head_dim = self.head_dim
347
+ assert (
348
+ self.head_dim * num_heads == self.embed_dim
349
+ ), "embed_dim must be divisible by num_heads"
350
+ self.scaling = self.head_dim ** -0.5
351
+
352
+ self.self_attention = self_attention
353
+ self.encoder_decoder_attention = encoder_decoder_attention
354
+
355
+ assert not self.self_attention or self.qkv_same_dim, (
356
+ "Self-attention requires query, key and " "value to be of the same size"
357
+ )
358
+
359
+ k_bias = True
360
+ if rescale_init:
361
+ k_bias = False
362
+
363
+ k_embed_dim = embed_dim
364
+ q_embed_dim = embed_dim
365
+
366
+ self.k_proj = quant_noise(
367
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
368
+ )
369
+ self.v_proj = quant_noise(
370
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
371
+ )
372
+ self.q_proj = quant_noise(
373
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
374
+ )
375
+
376
+ self.out_proj = quant_noise(
377
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
378
+ )
379
+
380
+ if add_bias_kv:
381
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
382
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
383
+ else:
384
+ self.bias_k = self.bias_v = None
385
+
386
+ self.add_zero_attn = add_zero_attn
387
+
388
+ self.gru_rel_pos = gru_rel_pos
389
+ if self.gru_rel_pos:
390
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
391
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
392
+
393
+ self.reset_parameters()
394
+
395
+ def reset_parameters(self):
396
+ if self.qkv_same_dim:
397
+ # Empirically observed the convergence to be much better with
398
+ # the scaled initialization
399
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
400
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
401
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
402
+ else:
403
+ nn.init.xavier_uniform_(self.k_proj.weight)
404
+ nn.init.xavier_uniform_(self.v_proj.weight)
405
+ nn.init.xavier_uniform_(self.q_proj.weight)
406
+
407
+ nn.init.xavier_uniform_(self.out_proj.weight)
408
+ if self.out_proj.bias is not None:
409
+ nn.init.constant_(self.out_proj.bias, 0.0)
410
+ if self.bias_k is not None:
411
+ nn.init.xavier_normal_(self.bias_k)
412
+ if self.bias_v is not None:
413
+ nn.init.xavier_normal_(self.bias_v)
414
+ if self.has_relative_attention_bias:
415
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
416
+
417
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
418
+ num_buckets = self.num_buckets
419
+ max_distance = self.max_distance
420
+ relative_buckets = 0
421
+
422
+ if bidirectional:
423
+ num_buckets = num_buckets // 2
424
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
425
+ relative_positions = torch.abs(relative_positions)
426
+ else:
427
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
428
+
429
+ max_exact = num_buckets // 2
430
+ is_small = relative_positions < max_exact
431
+
432
+ relative_postion_if_large = max_exact + (
433
+ torch.log(relative_positions.float() / max_exact)
434
+ / math.log(max_distance / max_exact)
435
+ * (num_buckets - max_exact)
436
+ ).to(torch.long)
437
+ relative_postion_if_large = torch.min(
438
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
439
+ )
440
+
441
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
442
+ return relative_buckets
443
+
444
+ def compute_bias(self, query_length, key_length):
445
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
446
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
447
+ relative_position = memory_position - context_position
448
+ relative_position_bucket = self._relative_positions_bucket(
449
+ relative_position,
450
+ bidirectional=True
451
+ )
452
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
453
+ values = self.relative_attention_bias(relative_position_bucket)
454
+ values = values.permute([2, 0, 1])
455
+ return values
456
+
457
+ def forward(
458
+ self,
459
+ query,
460
+ key: Optional[Tensor],
461
+ value: Optional[Tensor],
462
+ key_padding_mask: Optional[Tensor] = None,
463
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
464
+ need_weights: bool = True,
465
+ static_kv: bool = False,
466
+ attn_mask: Optional[Tensor] = None,
467
+ before_softmax: bool = False,
468
+ need_head_weights: bool = False,
469
+ position_bias: Optional[Tensor] = None
470
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
471
+ """Input shape: Time x Batch x Channel
472
+
473
+ Args:
474
+ key_padding_mask (ByteTensor, optional): mask to exclude
475
+ keys that are pads, of shape `(batch, src_len)`, where
476
+ padding elements are indicated by 1s.
477
+ need_weights (bool, optional): return the attention weights,
478
+ averaged over heads (default: False).
479
+ attn_mask (ByteTensor, optional): typically used to
480
+ implement causal attention, where the mask prevents the
481
+ attention from looking forward in time (default: None).
482
+ before_softmax (bool, optional): return the raw attention
483
+ weights and values before the attention softmax.
484
+ need_head_weights (bool, optional): return the attention
485
+ weights for each head. Implies *need_weights*. Default:
486
+ return the average attention weights over all heads.
487
+ """
488
+ if need_head_weights:
489
+ need_weights = True
490
+
491
+ is_tpu = query.device.type == "xla"
492
+
493
+ tgt_len, bsz, embed_dim = query.size()
494
+ src_len = tgt_len
495
+ assert embed_dim == self.embed_dim
496
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
497
+ if key is not None:
498
+ src_len, key_bsz, _ = key.size()
499
+ if not torch.jit.is_scripting():
500
+ assert key_bsz == bsz
501
+ assert value is not None
502
+ assert (src_len, bsz) == value.shape[:2]
503
+
504
+ if self.has_relative_attention_bias and position_bias is None:
505
+ position_bias = self.compute_bias(tgt_len, src_len)
506
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
507
+
508
+ if (
509
+ not is_tpu # don't use PyTorch version on TPUs
510
+ and incremental_state is None
511
+ and not static_kv
512
+ # A workaround for quantization to work. Otherwise JIT compilation
513
+ # treats bias in linear module as method.
514
+ and not torch.jit.is_scripting()
515
+ and self.q_head_dim == self.head_dim
516
+ ):
517
+ assert key is not None and value is not None
518
+ assert attn_mask is None
519
+
520
+ attn_mask_rel_pos = None
521
+ if position_bias is not None:
522
+ attn_mask_rel_pos = position_bias
523
+ if self.gru_rel_pos:
524
+ query_layer = query.transpose(0, 1)
525
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
526
+ query_layer = query_layer.view(*new_x_shape)
527
+ query_layer = query_layer.permute(0, 2, 1, 3)
528
+ _B, _H, _L, __ = query_layer.size()
529
+
530
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
531
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
532
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
533
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
534
+
535
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
536
+ k_proj_bias = self.k_proj.bias
537
+ if k_proj_bias is None:
538
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
539
+
540
+ x, attn = F.multi_head_attention_forward(
541
+ query,
542
+ key,
543
+ value,
544
+ self.embed_dim,
545
+ self.num_heads,
546
+ torch.empty([0]),
547
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
548
+ self.bias_k,
549
+ self.bias_v,
550
+ self.add_zero_attn,
551
+ self.dropout_module.p,
552
+ self.out_proj.weight,
553
+ self.out_proj.bias,
554
+ self.training,
555
+ # self.training or self.dropout_module.apply_during_inference,
556
+ key_padding_mask,
557
+ need_weights,
558
+ attn_mask_rel_pos,
559
+ use_separate_proj_weight=True,
560
+ q_proj_weight=self.q_proj.weight,
561
+ k_proj_weight=self.k_proj.weight,
562
+ v_proj_weight=self.v_proj.weight,
563
+ )
564
+ return x, attn, position_bias
565
+
566
+ if incremental_state is not None:
567
+ saved_state = self._get_input_buffer(incremental_state)
568
+ if saved_state is not None and "prev_key" in saved_state:
569
+ # previous time steps are cached - no need to recompute
570
+ # key and value if they are static
571
+ if static_kv:
572
+ assert self.encoder_decoder_attention and not self.self_attention
573
+ key = value = None
574
+ else:
575
+ saved_state = None
576
+
577
+ if self.self_attention:
578
+ q = self.q_proj(query)
579
+ k = self.k_proj(query)
580
+ v = self.v_proj(query)
581
+ elif self.encoder_decoder_attention:
582
+ # encoder-decoder attention
583
+ q = self.q_proj(query)
584
+ if key is None:
585
+ assert value is None
586
+ k = v = None
587
+ else:
588
+ k = self.k_proj(key)
589
+ v = self.v_proj(key)
590
+
591
+ else:
592
+ assert key is not None and value is not None
593
+ q = self.q_proj(query)
594
+ k = self.k_proj(key)
595
+ v = self.v_proj(value)
596
+ q *= self.scaling
597
+
598
+ if self.bias_k is not None:
599
+ assert self.bias_v is not None
600
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
601
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
602
+ if attn_mask is not None:
603
+ attn_mask = torch.cat(
604
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
605
+ )
606
+ if key_padding_mask is not None:
607
+ key_padding_mask = torch.cat(
608
+ [
609
+ key_padding_mask,
610
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
611
+ ],
612
+ dim=1,
613
+ )
614
+
615
+ q = (
616
+ q.contiguous()
617
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
618
+ .transpose(0, 1)
619
+ )
620
+ if k is not None:
621
+ k = (
622
+ k.contiguous()
623
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
624
+ .transpose(0, 1)
625
+ )
626
+ if v is not None:
627
+ v = (
628
+ v.contiguous()
629
+ .view(-1, bsz * self.num_heads, self.head_dim)
630
+ .transpose(0, 1)
631
+ )
632
+
633
+ if saved_state is not None:
634
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
635
+ if "prev_key" in saved_state:
636
+ _prev_key = saved_state["prev_key"]
637
+ assert _prev_key is not None
638
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
639
+ if static_kv:
640
+ k = prev_key
641
+ else:
642
+ assert k is not None
643
+ k = torch.cat([prev_key, k], dim=1)
644
+ src_len = k.size(1)
645
+ if "prev_value" in saved_state:
646
+ _prev_value = saved_state["prev_value"]
647
+ assert _prev_value is not None
648
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
649
+ if static_kv:
650
+ v = prev_value
651
+ else:
652
+ assert v is not None
653
+ v = torch.cat([prev_value, v], dim=1)
654
+ prev_key_padding_mask: Optional[Tensor] = None
655
+ if "prev_key_padding_mask" in saved_state:
656
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
657
+ assert k is not None and v is not None
658
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
659
+ key_padding_mask=key_padding_mask,
660
+ prev_key_padding_mask=prev_key_padding_mask,
661
+ batch_size=bsz,
662
+ src_len=k.size(1),
663
+ static_kv=static_kv,
664
+ )
665
+
666
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
667
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
668
+ saved_state["prev_key_padding_mask"] = key_padding_mask
669
+ # In this branch incremental_state is never None
670
+ assert incremental_state is not None
671
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
672
+ assert k is not None
673
+ assert k.size(1) == src_len
674
+
675
+ # This is part of a workaround to get around fork/join parallelism
676
+ # not supporting Optional types.
677
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
678
+ key_padding_mask = None
679
+
680
+ if key_padding_mask is not None:
681
+ assert key_padding_mask.size(0) == bsz
682
+ assert key_padding_mask.size(1) == src_len
683
+
684
+ if self.add_zero_attn:
685
+ assert v is not None
686
+ src_len += 1
687
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
688
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
689
+ if attn_mask is not None:
690
+ attn_mask = torch.cat(
691
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
692
+ )
693
+ if key_padding_mask is not None:
694
+ key_padding_mask = torch.cat(
695
+ [
696
+ key_padding_mask,
697
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
698
+ key_padding_mask
699
+ ),
700
+ ],
701
+ dim=1,
702
+ )
703
+
704
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
705
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
706
+
707
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
708
+
709
+ if attn_mask is not None:
710
+ attn_mask = attn_mask.unsqueeze(0)
711
+ attn_weights += attn_mask
712
+
713
+ if key_padding_mask is not None:
714
+ # don't attend to padding symbols
715
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
716
+ if not is_tpu:
717
+ attn_weights = attn_weights.masked_fill(
718
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
719
+ float("-inf"),
720
+ )
721
+ else:
722
+ attn_weights = attn_weights.transpose(0, 2)
723
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
724
+ attn_weights = attn_weights.transpose(0, 2)
725
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
726
+
727
+ if before_softmax:
728
+ return attn_weights, v, position_bias
729
+
730
+ if position_bias is not None:
731
+ if self.gru_rel_pos == 1:
732
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
733
+ _B, _H, _L, __ = query_layer.size()
734
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
735
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
736
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
737
+ position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
738
+
739
+ position_bias = position_bias.view(attn_weights.size())
740
+
741
+ attn_weights = attn_weights + position_bias
742
+
743
+ attn_weights_float = F.softmax(
744
+ attn_weights, dim=-1
745
+ )
746
+ attn_weights = attn_weights_float.type_as(attn_weights)
747
+ attn_probs = self.dropout_module(attn_weights)
748
+
749
+ assert v is not None
750
+ attn = torch.bmm(attn_probs, v)
751
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
752
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
753
+ attn = self.out_proj(attn)
754
+ attn_weights: Optional[Tensor] = None
755
+ if need_weights:
756
+ attn_weights = attn_weights_float.view(
757
+ bsz, self.num_heads, tgt_len, src_len
758
+ ).transpose(1, 0)
759
+ if not need_head_weights:
760
+ # average attention weights over heads
761
+ attn_weights = attn_weights.mean(dim=0)
762
+
763
+ return attn, attn_weights, position_bias
764
+
765
+ @staticmethod
766
+ def _append_prev_key_padding_mask(
767
+ key_padding_mask: Optional[Tensor],
768
+ prev_key_padding_mask: Optional[Tensor],
769
+ batch_size: int,
770
+ src_len: int,
771
+ static_kv: bool,
772
+ ) -> Optional[Tensor]:
773
+ # saved key padding masks have shape (bsz, seq_len)
774
+ if prev_key_padding_mask is not None and static_kv:
775
+ new_key_padding_mask = prev_key_padding_mask
776
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
777
+ new_key_padding_mask = torch.cat(
778
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
779
+ )
780
+ # During incremental decoding, as the padding token enters and
781
+ # leaves the frame, there will be a time when prev or current
782
+ # is None
783
+ elif prev_key_padding_mask is not None:
784
+ if src_len > prev_key_padding_mask.size(1):
785
+ filler = torch.zeros(
786
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
787
+ device=prev_key_padding_mask.device,
788
+ )
789
+ new_key_padding_mask = torch.cat(
790
+ [prev_key_padding_mask.float(), filler.float()], dim=1
791
+ )
792
+ else:
793
+ new_key_padding_mask = prev_key_padding_mask.float()
794
+ elif key_padding_mask is not None:
795
+ if src_len > key_padding_mask.size(1):
796
+ filler = torch.zeros(
797
+ (batch_size, src_len - key_padding_mask.size(1)),
798
+ device=key_padding_mask.device,
799
+ )
800
+ new_key_padding_mask = torch.cat(
801
+ [filler.float(), key_padding_mask.float()], dim=1
802
+ )
803
+ else:
804
+ new_key_padding_mask = key_padding_mask.float()
805
+ else:
806
+ new_key_padding_mask = prev_key_padding_mask
807
+ return new_key_padding_mask
808
+
809
+ def _get_input_buffer(
810
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
811
+ ) -> Dict[str, Optional[Tensor]]:
812
+ result = self.get_incremental_state(incremental_state, "attn_state")
813
+ if result is not None:
814
+ return result
815
+ else:
816
+ empty_result: Dict[str, Optional[Tensor]] = {}
817
+ return empty_result
818
+
819
+ def _set_input_buffer(
820
+ self,
821
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
822
+ buffer: Dict[str, Optional[Tensor]],
823
+ ):
824
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
825
+
826
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
827
+ return attn_weights
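A hedged usage sketch for the `MultiheadAttention` defined above, with the bucketed relative position bias enabled; the input follows the `(time, batch, channel)` convention from the `forward` docstring, and the import path is assumed from this upload's layout.

import torch
from slam_llm.models.wavlm.modules import MultiheadAttention  # path assumed from this repo layout

attn = MultiheadAttention(
    embed_dim=768,
    num_heads=12,
    self_attention=True,
    has_relative_attention_bias=True,  # enables the T5-style bucketed relative position bias
)
x = torch.randn(50, 2, 768)            # (time, batch, channel)
out, attn_weights, pos_bias = attn(x, x, x)
print(out.shape, pos_bias.shape)       # torch.Size([50, 2, 768]); bias is (bsz * num_heads, 50, 50)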
slam_llm/policies/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ from slam_llm.policies.mixed_precision import *
5
+ from slam_llm.policies.wrapping import *
6
+ from slam_llm.policies.activation_checkpointing_functions import apply_fsdp_checkpointing
7
+ from slam_llm.policies.anyprecision_optimizer import AnyPrecisionAdamW
slam_llm/policies/activation_checkpointing_functions.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ from functools import partial
5
+
6
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
7
+ checkpoint_wrapper,
8
+ CheckpointImpl,
9
+ apply_activation_checkpointing,
10
+ )
11
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
12
+
13
+ non_reentrant_wrapper = partial(
14
+ checkpoint_wrapper,
15
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
16
+ )
17
+
18
+ check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
19
+
20
+
21
+ def apply_fsdp_checkpointing(model):
22
+ """apply activation checkpointing to model
23
+ returns None as model is updated directly
24
+ """
25
+ print(f"--> applying fsdp activation checkpointing...")
26
+
27
+ apply_activation_checkpointing(
28
+ model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
29
+ )
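A small sketch of calling `apply_fsdp_checkpointing`; the tiny Llama configuration below is an illustrative assumption chosen only so the example runs quickly.

from transformers import LlamaConfig, LlamaForCausalLM
from slam_llm.policies import apply_fsdp_checkpointing

# deliberately tiny, hypothetical model sizes; any model containing LlamaDecoderLayer works
tiny = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=128,
                                    num_hidden_layers=2, num_attention_heads=4,
                                    vocab_size=256))
apply_fsdp_checkpointing(tiny)  # every LlamaDecoderLayer is wrapped with non-reentrant checkpointing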
slam_llm/policies/anyprecision_optimizer.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ # AnyPrecisionAdamW: a flexible precision AdamW optimizer
5
+ # with optional Kahan summation for high precision weight updates.
6
+ # Allows direct control over momentum, variance and auxiliary compensation
7
+ # buffer dtypes.
8
+ # Optional Kahan summation is used to offset precision reduction for
9
+ # the weight updates. This allows full training in BFloat16 (equal or
10
+ # better than FP32 results in many cases) due to high precision weight updates.
11
+
12
+ import torch
13
+ from torch.optim.optimizer import Optimizer
14
+
15
+
16
+ class AnyPrecisionAdamW(Optimizer):
17
+ def __init__(
18
+ self,
19
+ params,
20
+ lr=1e-3,
21
+ betas=(0.9, 0.999),
22
+ eps=1e-8,
23
+ weight_decay=0.0,
24
+ use_kahan_summation=False,
25
+ momentum_dtype=torch.bfloat16,
26
+ variance_dtype=torch.bfloat16,
27
+ compensation_buffer_dtype=torch.bfloat16,
28
+ ):
29
+ """
30
+ Args:
31
+ params (iterable): iterable of parameters to optimize or dicts defining
32
+ parameter groups
33
+ lr (float, optional): learning rate (default: 1e-3)
34
+ betas (Tuple[float, float], optional): coefficients used for computing
35
+ running averages of gradient and its square (default: (0.9, 0.999))
36
+ eps (float, optional): term added to the denominator to improve
37
+ numerical stability (default: 1e-8)
38
+ weight_decay (float, optional): weight decay coefficient (default: 0.0)
39
+
40
+ # Any Precision specific
41
+ use_kahan_summation = creates auxiliary buffer to ensure high precision
42
+ model param updates (default: False)
43
+ momentum_dtype = dtype for momentum (default: BFloat16)
44
+ variance_dtype = dtype for uncentered variance (default: BFloat16)
45
+ compensation_buffer_dtype = dtype for Kahan summation
46
+ buffer (default: BFloat16)
47
+
48
+ # Usage
49
+ This optimizer implements optimizer states, and Kahan summation
50
+ for high precision updates, all in user controlled dtypes.
51
+ Defaults are momentum and variance in BF16.
52
+ This can be run in FSDP mixed precision, amp, or full precision,
53
+ depending on what training pipeline you wish to work with.
54
+
55
+ Setting to use_kahan_summation = False, and changing momentum and
56
+ variance dtypes to FP32, reverts this to a standard AdamW optimizer.
57
+
58
+ """
59
+ defaults = dict(
60
+ lr=lr,
61
+ betas=betas,
62
+ eps=eps,
63
+ weight_decay=weight_decay,
64
+ use_kahan_summation=use_kahan_summation,
65
+ momentum_dtype=momentum_dtype,
66
+ variance_dtype=variance_dtype,
67
+ compensation_buffer_dtype=compensation_buffer_dtype,
68
+ )
69
+
70
+ super().__init__(params, defaults)
71
+
72
+ @torch.no_grad()
73
+ def step(self, closure=None):
74
+ """Performs a single optimization step.
75
+ Args:
76
+ closure (callable, optional): A closure that reevaluates the model
77
+ and returns the loss.
78
+ """
79
+
80
+ if closure is not None:
81
+ with torch.enable_grad():
82
+ # to fix linter, we do not keep the returned loss for use atm.
83
+ closure()
84
+
85
+ for group in self.param_groups:
86
+
87
+ beta1, beta2 = group["betas"]
88
+ lr = group["lr"]
89
+ weight_decay = group["weight_decay"]
90
+ eps = group["eps"]
91
+ use_kahan_summation = group["use_kahan_summation"]
92
+
93
+ momentum_dtype = group["momentum_dtype"]
94
+ variance_dtype = group["variance_dtype"]
95
+ compensation_buffer_dtype = group["compensation_buffer_dtype"]
96
+
97
+ for p in group["params"]:
98
+ if p.grad is None:
99
+ continue
100
+
101
+ if p.grad.is_sparse:
102
+ raise RuntimeError(
103
+ "AnyPrecisionAdamW does not support sparse gradients"
104
+ )
105
+
106
+ state = self.state[p]
107
+
108
+ # State initialization
109
+ if len(state) == 0:
110
+
111
+ state["step"] = torch.tensor(0.0)
112
+
113
+ # momentum - EMA of gradient values
114
+ state["exp_avg"] = torch.zeros_like(
115
+ p,
116
+ dtype=momentum_dtype,
117
+ )
118
+
119
+ # variance uncentered - EMA of squared gradient values
120
+ state["exp_avg_sq"] = torch.zeros_like(
121
+ p,
122
+ dtype=variance_dtype,
123
+ )
124
+
125
+ # optional Kahan summation - accumulated error tracker
126
+ if use_kahan_summation:
127
+ state["compensation"] = torch.zeros_like(
128
+ p,
129
+ dtype=compensation_buffer_dtype,
130
+ )
131
+
132
+ # main processing -------------------------
133
+
134
+ # update the steps for each param group update
135
+ state["step"] += 1
136
+ step = state["step"]
137
+
138
+ exp_avg = state["exp_avg"]
139
+ exp_avg_sq = state["exp_avg_sq"]
140
+
141
+ grad = p.grad
142
+
143
+ # weight decay, AdamW style
144
+ if weight_decay:
145
+ p.data.mul_(1 - lr * weight_decay)
146
+
147
+ # update momentum
148
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
149
+
150
+ # update uncentered variance
151
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
152
+
153
+ # adjust using bias1
154
+ bias_correction1 = 1 - beta1**step
155
+
156
+ step_size = lr / bias_correction1
157
+
158
+ # adjust using bias2
159
+ denom_correction = (1 - beta2**step) ** 0.5 # avoids math import
160
+
161
+ centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(
162
+ eps, alpha=1
163
+ )
164
+
165
+ # lr update to compensation
166
+ if use_kahan_summation:
167
+ compensation = state["compensation"]
168
+
169
+ compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)
170
+
171
+ # update weights with compensation (Kahan summation)
172
+ # save error back to compensation for next iteration
173
+ temp_buffer = p.detach().clone()
174
+ p.data.add_(compensation)
175
+ compensation.add_(temp_buffer.sub_(p.data))
176
+
177
+ else:
178
+ # usual AdamW updates
179
+ p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)
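A short, hedged sketch of driving `AnyPrecisionAdamW` on a toy module (the model and shapes are illustrative assumptions):

import torch
from slam_llm.policies import AnyPrecisionAdamW  # re-exported by slam_llm/policies/__init__.py above

model = torch.nn.Linear(16, 4)                   # hypothetical toy model
opt = AnyPrecisionAdamW(model.parameters(), lr=1e-3,
                        use_kahan_summation=True)  # keep a bf16 compensation buffer for the update
loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
opt.step()       # momentum/variance live in bf16; the Kahan buffer absorbs the rounding error
opt.zero_grad()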
slam_llm/policies/mixed_precision.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import torch
5
+
6
+ from torch.distributed.fsdp import (
7
+ MixedPrecision,
8
+ )
9
+
10
+ # requires grad scaler in main loop
11
+ fpSixteen = MixedPrecision(
12
+ param_dtype=torch.float16,
13
+ # Gradient communication precision.
14
+ reduce_dtype=torch.float16,
15
+ # Buffer precision.
16
+ buffer_dtype=torch.float16,
17
+ )
18
+
19
+ bfSixteen = MixedPrecision(
20
+ param_dtype=torch.bfloat16,
21
+ # Gradient communication precision.
22
+ reduce_dtype=torch.bfloat16,
23
+ # Buffer precision.
24
+ buffer_dtype=torch.bfloat16,
25
+ cast_forward_inputs=True,
26
+ )
27
+
28
+ bfSixteen_mixed = MixedPrecision(
29
+ param_dtype=torch.float32,
30
+ reduce_dtype=torch.bfloat16,
31
+ buffer_dtype=torch.bfloat16,
32
+ )
33
+
34
+ fp32_policy = MixedPrecision(
35
+ param_dtype=torch.float32,
36
+ reduce_dtype=torch.float32,
37
+ buffer_dtype=torch.float32,
38
+ )
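A hedged sketch of the "requires grad scaler in main loop" note attached to `fpSixteen` above (the bf16 policies need no scaler); `model`, `optimizer`, and `batch` are placeholders for an FSDP setup using `mixed_precision=fpSixteen`.

from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler()
# inside the training loop (placeholders, not from this upload):
#     loss = model(**batch).loss
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()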
slam_llm/policies/wrapping.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import functools
5
+
6
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
7
+ from torch.distributed.fsdp.wrap import (
8
+ transformer_auto_wrap_policy,
9
+ size_based_auto_wrap_policy,
10
+ )
11
+
12
+
13
+ def get_size_policy(min_params=1e8):
14
+ num_wrap_policy = functools.partial(
15
+ size_based_auto_wrap_policy, min_num_params=min_params
16
+ )
17
+ return num_wrap_policy
18
+
19
+
20
+ def get_llama_wrapper():
21
+ """we register our main layer class and use the fsdp transformer wrapping policy
22
+ ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers
23
+ """
24
+ # ==== use new transformer wrapper
25
+
26
+ llama_auto_wrap_policy = functools.partial(
27
+ transformer_auto_wrap_policy,
28
+ transformer_layer_cls={
29
+ LlamaDecoderLayer,
30
+ },
31
+ )
32
+
33
+ return llama_auto_wrap_policy
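A hedged sketch of how the wrapping and mixed-precision policies above are typically combined when constructing FSDP; it assumes `torch.distributed` is already initialized and that `model` is a Hugging Face Llama model (both placeholders).

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from slam_llm.policies import get_llama_wrapper, bfSixteen

sharded_model = FSDP(
    model,                                  # placeholder: a transformers LlamaForCausalLM
    auto_wrap_policy=get_llama_wrapper(),   # each LlamaDecoderLayer becomes its own FSDP unit
    mixed_precision=bfSixteen,              # bf16 params/grads/buffers; forward inputs are cast
)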
slam_llm/utils/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ from slam_llm.utils.memory_utils import MemoryTrace
5
+ from slam_llm.utils.dataset_utils import *
6
+ from slam_llm.utils.fsdp_utils import fsdp_auto_wrap_policy
7
+ from slam_llm.utils.train_utils import *