farzadab committed · verified
Commit e3c2ab5 · 1 Parent(s): 5d09428

Upload 4 files

Files changed (3)
  1. processor_config.json +1 -1
  2. ultravox_model.py +85 -48
  3. ultravox_processing.py +188 -75
processor_config.json CHANGED
@@ -5,7 +5,7 @@
   "auto_map": {
     "AutoProcessor": "ultravox_processing.UltravoxProcessor"
   },
-  "encoder_ds_factor": 320,
+  "encoder_ds_factor": 2,
   "processor_class": "UltravoxProcessor",
   "stack_factor": 8
 }
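The drop from 320 to 2 is a change of units, not of the effective token rate: the processor now counts encoder downsampling in mel frames (Whisper's convolutions reduce the ~100 mel frames per second by a factor of 2) instead of in raw 16 kHz samples (320 samples per encoder frame). A minimal sketch of the resulting token-count arithmetic, assuming Whisper-style features; the helper name below is illustrative and not part of the repo:

import math

def expected_audio_token_len(num_mel_frames: int, encoder_ds_factor: int = 2, stack_factor: int = 8) -> int:
    # Mirrors the new processor logic: ceil(audio_lens / (encoder_ds_factor * stack_factor)),
    # where audio_lens is measured in mel frames rather than raw samples.
    return math.ceil(num_mel_frames / (encoder_ds_factor * stack_factor))

# 10 s of 16 kHz audio -> ~1000 mel frames -> 63 audio tokens,
# matching what the old sample-based formula (320 samples per encoder frame) produced.
print(expected_audio_token_len(1000))  # 63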
ultravox_model.py CHANGED
@@ -1,6 +1,6 @@
 import logging
 import re
-from typing import Any, Dict, Optional, Set, Tuple, Union
+from typing import Any, Dict, Generator, Optional, Set, Tuple, Union

 import peft
 import torch
@@ -10,6 +10,7 @@ import transformers
 import transformers.activations
 import transformers.modeling_outputs
 import transformers.models
+from transformers.generation.utils import GenerationMixin
 from transformers.models.whisper import modeling_whisper as whisper

 # We must use relative import in this directory to allow uploading to HF Hub
@@ -19,7 +20,7 @@ from .ultravox_config import LossFunction
 from .ultravox_config import UltravoxConfig


-class UltravoxModel(transformers.LlamaPreTrainedModel):
+class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     """
     The Ultravox model which consists of an audio encoder and a language model.

@@ -37,6 +38,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
     config: UltravoxConfig  # for type hinting
     # Usually we load encoder and LLM weights from a pretrained model separately, so they are allowed to be missing
     _keys_to_ignore_on_load_missing = ["audio_tower.*", "language_model.*"]
+    # Since we have kwargs in forward, we need to set this to False, otherwise grad_accum_steps will cause incorrect train loss to be reported
+    # see https://github.com/huggingface/transformers/issues/35856 and https://github.com/huggingface/trl/pull/2615/files
+    accepts_loss_kwargs = False

     def __init__(self, config: UltravoxConfig):
         super().__init__(config)
@@ -46,15 +50,16 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         self.vocab_size = config.vocab_size

         self.audio_tower = self._create_audio_tower(config)
+        self.audio_tower_context_length: Optional[int] = None
+        self.audio_tower_context_length = self.audio_tower.max_context_length
+
         self.multi_modal_projector = self._create_multi_modal_projector(config)
         self.language_model = self._create_language_model(config)

         # Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
         # FSDP throws an error if some of the layer types are not found in the model.
-        # This would be something like ["LlamaDecoderLayer", "WhisperEncoderLayer"]
-        self._no_split_modules = (self.language_model._no_split_modules or []) + (
-            self.audio_tower._no_split_modules or []
-        )
+        # This would be something like ["LlamaDecoderLayer"] as we don't split audio encoder layers.
+        self._no_split_modules = self.language_model._no_split_modules

         self.loss_config = LossConfig()
         self.post_init()
@@ -141,6 +146,24 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         )
         return {"loss": kl_loss}

+    def _audio_iter(
+        self, audio_batch_size: torch.Tensor
+    ) -> Generator[Tuple[int, int], None, None]:
+        """
+        Iterate over the audio batch size and yield the batch index and audio index of each audio item.
+
+        Args:
+            audio_batch_size: A tensor of shape (B,) where B is the batch size.
+
+        Returns:
+            A generator that yields a tuple of (batch index, audio index) for each audio item.
+        """
+        audio_index = 0
+        for i_b, batch_count in enumerate(audio_batch_size):
+            for _ in range(batch_count):
+                yield i_b, audio_index
+                audio_index += 1
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -149,8 +172,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         audio_token_start_idx: Optional[torch.Tensor] = None,
-        audio_len: Optional[torch.Tensor] = None,
+        audio_lens: Optional[torch.Tensor] = None,
         audio_token_len: Optional[torch.Tensor] = None,
+        audio_batch_size: Optional[torch.Tensor] = None,
         past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
         # the alt_* fields are needed for KL divergence loss
         alt_input_ids: Optional[torch.Tensor] = None,
@@ -181,29 +205,37 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         # B x T -> B x T x D
         inputs_embeds = self.get_input_embeddings().forward(input_ids)

-        if audio_values is not None:
+        if audio_values is not None and len(audio_values) > 0:
             assert (
-                audio_token_start_idx is not None and audio_token_len is not None
-            ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
+                audio_token_start_idx is not None
+                and audio_token_len is not None
+                and audio_lens is not None
+                and audio_batch_size is not None
+            ), "audio_token_start_idx/audio_token_len/audio_lens must be provided if audio_values are provided."
             assert (
-                len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
-            ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."
-
-            # B x A/3200 x D
+                len(audio_token_start_idx)
+                == len(audio_token_len)
+                == len(audio_lens)
+                == len(audio_values)
+            ), "audio_token_start_idx/audio_token_len/audio_lens/audio_values must have the same batch size."
+            assert len(audio_batch_size) == len(
+                inputs_embeds
+            ), "audio_batch_size and inputs_embeds must have the same batch size."
+
+            # B x A/3200 x (D=max-audio-length-in-batch)
             audio_tower_output = self.audio_tower.forward(
                 audio_values.to(self.audio_tower.dtype),
-                audio_len=audio_len,
+                audio_len=audio_lens,
             ).last_hidden_state
             audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
-
             audio_embeds = self.multi_modal_projector.forward(audio_tower_output)

             # combine audio and text embeddings
-            for i, (audio, start, length) in enumerate(
-                zip(audio_embeds, audio_token_start_idx, audio_token_len)
-            ):
-                length = min(length, audio.shape[0])
-                inputs_embeds[i, start : start + length] = audio[:length]
+            for i_b, i_a in self._audio_iter(audio_batch_size):
+                start_idx = audio_token_start_idx[i_a]
+                token_len = audio_token_len[i_a]
+                item_embedding = audio_embeds[i_a][:token_len]
+                inputs_embeds[i_b][start_idx : start_idx + token_len] = item_embedding

         lm_output = self.language_model.forward(
             inputs_embeds=inputs_embeds,
@@ -238,7 +270,8 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         audio_values: Optional[torch.FloatTensor] = None,
         audio_token_start_idx: Optional[torch.Tensor] = None,
         audio_token_len: Optional[torch.Tensor] = None,
-        audio_len: Optional[torch.Tensor] = None,
+        audio_lens: Optional[torch.Tensor] = None,
+        audio_batch_size: Optional[torch.Tensor] = None,
         past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -267,7 +300,8 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
                 audio_token_start_idx - prefill_start_idx
             )
             model_input["audio_token_len"] = audio_token_len
-            model_input["audio_len"] = audio_len
+            model_input["audio_batch_size"] = audio_batch_size
+            model_input["audio_lens"] = audio_lens

         return model_input

@@ -284,7 +318,7 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         cls, config: UltravoxConfig
     ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
         if config.audio_model_id is not None:
-            if "whisper" in config.audio_model_id is not None:
+            if "whisper" in config.audio_model_id.lower():
                 audio_tower = ModifiedWhisperEncoder.from_pretrained(
                     config.audio_model_id, torch_dtype=config.torch_dtype
                 )
@@ -300,7 +334,7 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
                     config.audio_model_id, torch_dtype=config.torch_dtype
                 )
         else:
-            if "whisper" in config.audio_config._name_or_path:
+            if "whisper" in config.audio_config._name_or_path.lower():
                 audio_tower = ModifiedWhisperEncoder(config.audio_config)
                 audio_tower.init_latency_mask(
                     config.audio_latency_block_size, dtype=config.torch_dtype
@@ -393,13 +427,17 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         if state_dict is None:
             state_dict = super().state_dict()

-        named_params = dict(self.named_parameters())
+        trainable_params = {k for k, v in self.named_parameters() if v.requires_grad}
+        # normalize the keys to match the original model
+        # Example: audio_tower.base_model.model.layers.0._fsdp_wrapped_module.self_attn.k_proj.lora_B.default.weight
+        trainable_params = {
+            k.replace("_fsdp_wrapped_module.", "") for k in trainable_params
+        }

         state_dict = {
             k: v
             for k, v in state_dict.items()
-            if k in self.keep_params
-            or (k in named_params and named_params[k].requires_grad)
+            if k in self.keep_params or k in trainable_params
         }

         return state_dict
@@ -445,7 +483,7 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):

 # TODO: refactor common parts to a shared module
 def is_cache_empty(
-    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]]
+    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]],
 ) -> bool:
     """
     Check if the cache is empty.
@@ -481,12 +519,8 @@ def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:

 class StackAudioFrames(nn.Module):
     """
-    Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`.
-
-    The number of output frames will be `ceil(T / stack_factor) + 1` where `T` is the number of input frames.
-    NOTE: the extra +1 is intentional: in case the number of audio tokens are over-estimated by the processor,
-    we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings.
-    In most cases this extra padding will get removed in the model's forward function so it has no effect.
+    Stack the audio embedding frames to reduce the sequence length by a factor
+    of `stack_factor`.
     """

     def __init__(self, stack_factor: int = 8):
@@ -496,7 +530,7 @@ class StackAudioFrames(nn.Module):
     def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
         B, T, C = audio_embeds.shape
         T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
-        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor))
+        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
         B, T, C = audio_embeds.shape
         audio_embeds = audio_embeds.view(
             B, T // self.stack_factor, C * self.stack_factor
@@ -568,17 +602,25 @@ class ModifiedWhisperEncoder(
     base_model_prefix = "model.encoder"
     _no_split_modules = ["WhisperEncoderLayer"]

-    def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
-        if audio_latency_block_size is None:
-            self.audio_streaming_mask = None
-            return
+    def __init__(self, config: transformers.WhisperConfig):
+        super().__init__(config)
+        self.config.is_decoder = False

-        # maximum sequence length
-        max_seqlen = (
+    @property
+    def max_context_length(self):
+        return (
             self.config.max_source_positions
             * self.conv1.stride[0]
             * self.conv2.stride[0]
         )
+
+    def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
+        if audio_latency_block_size is None:
+            self.audio_streaming_mask = None
+            return
+
+        # Use max_context_length directly in the calculation
+        max_seqlen = self.max_context_length
         assert (
             max_seqlen > 0
         ), f"maximum sequence length must be positive, got {max_seqlen}"
@@ -610,11 +652,7 @@ class ModifiedWhisperEncoder(
         output_hidden_states=None,
         return_dict=None,
     ):
-        expected_seq_length = (
-            self.config.max_source_positions
-            * self.conv1.stride[0]
-            * self.conv2.stride[0]
-        )
+        expected_seq_length = self.max_context_length
         if input_features.shape[-1] > expected_seq_length:
             raise ValueError(
                 f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
@@ -665,7 +703,6 @@ class ModifiedWhisperEncoder(
             attention_mask = self.get_extended_attention_mask(
                 attention_mask,
                 None,
-                device=hidden_states.device,
                 dtype=hidden_states.dtype,
             )

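For reference, the new splicing logic in forward() can be reduced to the following self-contained sketch: audio_batch_size maps the flat list of audio chunks back to batch rows (the same iteration order as _audio_iter), and each chunk's projected embeddings overwrite the corresponding run of placeholder token embeddings. Shapes and index values below are toy examples, not taken from the model.

import torch

def splice_audio_embeds(inputs_embeds, audio_embeds, audio_token_start_idx, audio_token_len, audio_batch_size):
    # Chunks are laid out batch-row by batch-row, audio_batch_size[i] chunks per row.
    audio_index = 0
    for i_b, num_chunks in enumerate(audio_batch_size.tolist()):
        for _ in range(num_chunks):
            start = int(audio_token_start_idx[audio_index])
            length = int(audio_token_len[audio_index])
            inputs_embeds[i_b, start : start + length] = audio_embeds[audio_index][:length]
            audio_index += 1
    return inputs_embeds

# Toy batch: sample 0 carries two audio chunks, sample 1 carries one.
B, T, D = 2, 10, 4
out = splice_audio_embeds(
    torch.zeros(B, T, D),
    torch.ones(3, 3, D),  # 3 chunks, up to 3 audio tokens each
    audio_token_start_idx=torch.tensor([1, 4, 2]),
    audio_token_len=torch.tensor([3, 2, 3]),
    audio_batch_size=torch.tensor([2, 1]),
)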
ultravox_processing.py CHANGED
@@ -1,5 +1,5 @@
1
  import dataclasses
2
- from typing import Optional, Union
3
 
4
  import numpy as np
5
  import torch
@@ -15,7 +15,13 @@ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
15
  include_alt_fields: bool = False
16
 
17
  def __call__(self, features, *args, **kwargs):
18
- audio_values = [f.pop("audio_values", None) for f in features]
 
 
 
 
 
 
19
  if self.include_alt_fields:
20
  # these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method
21
  alt_features = [
@@ -34,8 +40,12 @@ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
34
  batch["alt_attention_mask"] = alt_batch["attention_mask"]
35
  batch["alt_labels"] = alt_batch["labels"]
36
 
 
 
 
 
37
  # Pad the last dimension of all audio_values to the same length, with 0s on the right.
38
- if audio_values and audio_values[0] is not None:
39
  max_len = max([x.shape[-1] for x in audio_values])
40
  batch["audio_values"] = torch.stack(
41
  [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
@@ -45,10 +55,12 @@ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
45
  [f["input_ids"].shape[-1] for f in features]
46
  )
47
  displacement = batch["input_ids"].shape[-1] - input_ids_lens
 
 
 
48
  batch["audio_token_start_idx"] += displacement.to(
49
  batch["audio_token_start_idx"].device
50
  )
51
-
52
  return batch
53
 
54
 
@@ -62,11 +74,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
62
  """
63
 
64
  attributes = ["audio_processor", "tokenizer"]
65
- audio_processor_class = (
66
- "Wav2Vec2Processor",
67
- "SeamlessM4TFeatureExtractor",
68
- "WhisperProcessor",
69
- )
70
  tokenizer_class = (
71
  "PreTrainedTokenizer",
72
  "PreTrainedTokenizerFast",
@@ -80,27 +88,32 @@ class UltravoxProcessor(transformers.ProcessorMixin):
80
  audio_processor=None,
81
  tokenizer=None,
82
  audio_padding: str = "longest",
83
- encoder_ds_factor: int = 320,
84
  stack_factor: int = 8,
85
  audio_placeholder: str = "<|audio|>",
 
 
86
  ):
87
  """
88
  Args:
89
  audio_processor: The audio processor for the audio encoder.
90
  tokenizer: The tokenizer for the language model.
91
  audio_padding: The padding strategy for the audio encoder.
92
- encoder_ds_factor: The downsample factor of the audio encoder.
93
  stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
 
94
  audio_placeholder: The placeholder for the audio in the text.
 
95
  """
96
  self.audio_padding = audio_padding
97
  self.encoder_ds_factor = encoder_ds_factor
98
  self.stack_factor = stack_factor
99
  self.audio_placeholder = audio_placeholder
100
- self.audio_token_replacement = tokenizer.eos_token
101
  assert (
102
- self.audio_token_replacement is not None
103
  ), "The tokenizer has no EOS token. Cannot recover."
 
 
104
  if tokenizer.pad_token_id is None:
105
  tokenizer.pad_token_id = tokenizer.eos_token_id
106
 
@@ -114,7 +127,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
114
  audio_processor = transformers.AutoProcessor.from_pretrained(
115
  config.audio_model_id
116
  or config.audio_config._name_or_path
117
- or "facebook/wav2vec2-base-960h"
118
  )
119
 
120
  tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -129,30 +142,100 @@ class UltravoxProcessor(transformers.ProcessorMixin):
129
  stack_factor=config.stack_factor,
130
  )
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  def __call__(
133
  self,
134
  text: Optional[str] = None,
135
  audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
 
 
 
 
 
136
  sampling_rate: Optional[int] = None,
137
  return_tensors: Optional[
138
  Union[str, transformers.TensorType]
139
  ] = transformers.TensorType.PYTORCH,
 
140
  **kwargs,
141
  ) -> transformers.BatchFeature:
142
  """
143
  Main method to prepare for the model one text sequence and audio. This method forwards the `text`
144
  and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
145
  the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
146
- audio processor's [`~Wav2Vec2Processor.__call__`] if `audio` is not `None`. Please refer to the docstring
147
  of the above two methods for more information.
148
 
149
  Args:
150
  text (`str`, `List[str]`):
151
  The sequence to be encoded. Sequence can be a string or (pretokenized string).
152
  audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
153
- The audio to be prepared. Audio can be NumPy array or PyTorch tensor. In case of a
154
- NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the
155
- sample length of the audio.
156
  sampling_rate (`int`, *optional*, defaults to 16000):
157
  Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
158
  you are doing.
@@ -176,75 +259,105 @@ class UltravoxProcessor(transformers.ProcessorMixin):
176
  Returned when `audio` is not `None`.
177
  - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
178
  """
179
- # TODO: Add support for multiple audio and text inputs.
 
 
 
 
 
 
 
180
  data = {}
181
- audio_embed_frames = 0
182
- if audio is not None and len(audio) > 0:
183
- if self.audio_padding == "max_length":
184
- # 30 seconds is the expected length for Whisper
185
- assert sampling_rate is not None, "Sampling rate must be provided."
186
- audio_len = 30 * sampling_rate
187
- else:
188
- audio_len = audio.shape[-1]
189
- # It's guaranteed that the number of frames is less than or equal to this amount.
190
- # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
191
- # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
192
- nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
193
- audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
194
- data["audio_token_len"] = [audio_embed_frames]
195
 
196
  # Main audio processing. The processor is model-specific.
197
- x = self.audio_processor(
198
- audio,
199
  sampling_rate=sampling_rate,
200
  padding="longest",
201
- max_length=audio_len,
 
202
  return_attention_mask=True,
203
  **kwargs,
204
  )
205
- if "input_features" in x:
206
- data["audio_values"] = x.input_features
207
- else:
208
- data["audio_values"] = x.input_values
209
-
210
- # data["audio_len"] is the number of frames in the audio, used for creating attention masks in whisper encoder
211
- if (
212
- self.audio_padding == "max_length"
213
- ): # audio is padded to max length, so we rely on the attention mask to determine audio_len
214
- data["audio_len"] = (
215
- x.attention_mask.sum(-1) - 1
216
- ) # Whisper attention mask includes an extra 1 at the end that needs to be subtracted
217
- else: # audio is not padded, so we can directly use the audio length
218
- data["audio_len"] = [torch.as_tensor(data["audio_values"]).shape[-1]]
219
 
220
- if text is not None:
221
- assert isinstance(
222
- text, str
223
- ), "Text must be a string. Batch mode not supported yet."
224
- if self.audio_placeholder in text:
225
- if "audio_token_len" not in data:
226
- raise ValueError(
227
- f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
228
- )
229
-
230
- start_idx = len(
231
- self.tokenizer.encode(
232
- text[: text.index(self.audio_placeholder)],
233
- add_special_tokens=False,
234
- )
235
- )
236
- data["audio_token_start_idx"] = [start_idx]
237
-
238
- # Replace the audio placeholder with the audio token.
239
- # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
240
- # where the number of </s> is the number of audio frames.
241
- text = text.replace(
242
- self.audio_placeholder,
243
- self.audio_token_replacement * audio_embed_frames,
244
  )
 
 
 
 
 
 
 
 
 
 
245
 
246
  # Special tokens like BOS should already have been added by the caller.
247
- data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
  return transformers.BatchFeature(data=data, tensor_type=return_tensors)
250
 
 
1
  import dataclasses
2
+ from typing import Any, Dict, List, Optional, Union
3
 
4
  import numpy as np
5
  import torch
 
15
  include_alt_fields: bool = False
16
 
17
  def __call__(self, features, *args, **kwargs):
18
+ audio_values = [x for f in features for x in f.pop("audio_values", [])]
19
+ audio_lens = [x for f in features for x in f.pop("audio_lens", [])]
20
+ audio_token_len = [x for f in features for x in f.pop("audio_token_len", [])]
21
+ audio_token_start_idx = [
22
+ x for f in features for x in f.pop("audio_token_start_idx", [])
23
+ ]
24
+
25
  if self.include_alt_fields:
26
  # these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method
27
  alt_features = [
 
40
  batch["alt_attention_mask"] = alt_batch["attention_mask"]
41
  batch["alt_labels"] = alt_batch["labels"]
42
 
43
+ batch["audio_token_start_idx"] = torch.stack(audio_token_start_idx)
44
+ batch["audio_lens"] = torch.stack(audio_lens)
45
+ batch["audio_token_len"] = torch.stack(audio_token_len)
46
+
47
  # Pad the last dimension of all audio_values to the same length, with 0s on the right.
48
+ if audio_values:
49
  max_len = max([x.shape[-1] for x in audio_values])
50
  batch["audio_values"] = torch.stack(
51
  [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
 
55
  [f["input_ids"].shape[-1] for f in features]
56
  )
57
  displacement = batch["input_ids"].shape[-1] - input_ids_lens
58
+ displacement = displacement.repeat_interleave(
59
+ batch["audio_batch_size"].squeeze(-1)
60
+ )
61
  batch["audio_token_start_idx"] += displacement.to(
62
  batch["audio_token_start_idx"].device
63
  )
 
64
  return batch
65
 
66
 
 
74
  """
75
 
76
  attributes = ["audio_processor", "tokenizer"]
77
+ audio_processor_class = ("WhisperProcessor",)
 
 
 
 
78
  tokenizer_class = (
79
  "PreTrainedTokenizer",
80
  "PreTrainedTokenizerFast",
 
88
  audio_processor=None,
89
  tokenizer=None,
90
  audio_padding: str = "longest",
91
+ encoder_ds_factor: int = 2,
92
  stack_factor: int = 8,
93
  audio_placeholder: str = "<|audio|>",
94
+ # Defaults to whisper encoder context size
95
+ audio_context_size: Optional[int] = 3000,
96
  ):
97
  """
98
  Args:
99
  audio_processor: The audio processor for the audio encoder.
100
  tokenizer: The tokenizer for the language model.
101
  audio_padding: The padding strategy for the audio encoder.
 
102
  stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
103
+ encoder_ds_factor: The downsampling factor of the audio encoder.
104
  audio_placeholder: The placeholder for the audio in the text.
105
+ audio_context_size: The maximum number of frames that the audio encoder can handle.
106
  """
107
  self.audio_padding = audio_padding
108
  self.encoder_ds_factor = encoder_ds_factor
109
  self.stack_factor = stack_factor
110
  self.audio_placeholder = audio_placeholder
111
+ self.audio_context_size = audio_context_size
112
  assert (
113
+ tokenizer.eos_token is not None
114
  ), "The tokenizer has no EOS token. Cannot recover."
115
+ self.vocab = tokenizer.get_vocab()
116
+ self.audio_token_replacement = tokenizer.eos_token
117
  if tokenizer.pad_token_id is None:
118
  tokenizer.pad_token_id = tokenizer.eos_token_id
119
 
 
127
  audio_processor = transformers.AutoProcessor.from_pretrained(
128
  config.audio_model_id
129
  or config.audio_config._name_or_path
130
+ or "openai/whisper-tiny"
131
  )
132
 
133
  tokenizer = transformers.AutoTokenizer.from_pretrained(
 
142
  stack_factor=config.stack_factor,
143
  )
144
 
145
+ def _chunk_and_pad_audio(
146
+ self,
147
+ audio_values: torch.Tensor,
148
+ audio_lens: torch.Tensor,
149
+ include_audio_num_chunks: bool = False,
150
+ ) -> Dict[str, Any]:
151
+ """
152
+ Processes the audio batch by chunking any items in the batch according to the audio_context_size,
153
+ padding the last chunk if needed, and returns a dictionary with updated audio data.
154
+
155
+ Args:
156
+ audio_values (torch.Tensor): A tensor of audio values (e.g., in B, D, T format).
157
+ audio_lens (torch.Tensor): A tensor of audio lengths.
158
+
159
+ Returns:
160
+ Dict[str, Any]: Dictionary with the following keys:
161
+ - "audio_values": The concatenated audio tensor after chunking and padding.
162
+ - "audio_lens": Tensor of lengths for each chunk.
163
+ - "audio_is_continuation": Tensor of booleans indicating if the chunk is a continuation of the previous chunk.
164
+ - "audio_batch_size": A Tensor with one integer representing the number of chunks.
165
+
166
+ """
167
+ chunked_audio_values: List[torch.Tensor] = []
168
+ chunked_audio_lens: List[int] = []
169
+ is_continuation_list: List[bool] = []
170
+ num_chunks: List[int] = []
171
+ context_size = self.audio_context_size or audio_values.shape[-1]
172
+
173
+ for i in range(audio_values.shape[0]): # iterate over the batch
174
+ num_chunks.append(int(np.ceil(audio_lens[i] / context_size)))
175
+ for offset in range(0, audio_lens[i], context_size):
176
+ is_continuation = offset > 0
177
+ chunk = audio_values[i, :, offset : offset + context_size]
178
+ if is_continuation and chunk.shape[-1] < context_size:
179
+ # N.B. We only need to pad continuation chunks. If none of the samples require chunking, the
180
+ # batch might not (need to) be padded all the way to the audio_context_size, in which case
181
+ # we've already included the padding above. On the other hand, if we have any continuation
182
+ # chunks we know that the batch needs to be padded to audio_context_size because that's what
183
+ # we're slicing to.
184
+ chunk = F.pad(chunk, (0, context_size - chunk.shape[-1]))
185
+ chunked_audio_values.append(chunk)
186
+ chunked_audio_lens.append(
187
+ min(int(audio_lens[i].item()) - offset, context_size)
188
+ )
189
+ is_continuation_list.append(is_continuation)
190
+
191
+ data = {
192
+ "audio_values": torch.stack(chunked_audio_values, dim=0),
193
+ "audio_lens": torch.tensor(
194
+ chunked_audio_lens, dtype=torch.int64, device=audio_values.device
195
+ ),
196
+ "audio_is_continuation": torch.tensor(
197
+ is_continuation_list, dtype=torch.bool, device=audio_values.device
198
+ ),
199
+ "audio_batch_size": torch.tensor(
200
+ [len(chunked_audio_values)], device=audio_values.device
201
+ ),
202
+ }
203
+ if include_audio_num_chunks:
204
+ data["audio_num_chunks"] = torch.tensor(
205
+ num_chunks, dtype=torch.int64, device=audio_values.device
206
+ )
207
+ return data
208
+
209
  def __call__(
210
  self,
211
  text: Optional[str] = None,
212
  audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
213
+ audios: Optional[
214
+ Union[
215
+ List[Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
216
+ ]
217
+ ] = None,
218
  sampling_rate: Optional[int] = None,
219
  return_tensors: Optional[
220
  Union[str, transformers.TensorType]
221
  ] = transformers.TensorType.PYTORCH,
222
+ include_audio_num_chunks: bool = False,
223
  **kwargs,
224
  ) -> transformers.BatchFeature:
225
  """
226
  Main method to prepare for the model one text sequence and audio. This method forwards the `text`
227
  and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
228
  the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
229
+ audio processor's [`~WhisperProcessor.__call__`] if `audio` is not `None`. Please refer to the docstring
230
  of the above two methods for more information.
231
 
232
  Args:
233
  text (`str`, `List[str]`):
234
  The sequence to be encoded. Sequence can be a string or (pretokenized string).
235
  audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
236
+ The audio to be prepared. Audio can be a single-channel (1-dimensional) NumPy array or PyTorch tensor.
237
+ audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
238
+ A list or two dimensional array of audio to be prepared.
239
  sampling_rate (`int`, *optional*, defaults to 16000):
240
  Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
241
  you are doing.
 
259
  Returned when `audio` is not `None`.
260
  - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
261
  """
262
+ # TODO: Add support for multiple text inputs.
263
+ if audio is not None and audios is not None:
264
+ raise ValueError("Only one of `audio` or `audios` should be provided.")
265
+ elif audio is not None:
266
+ audios = audio if isinstance(audio, list) or audio.ndim == 2 else [audio]
267
+ elif audios is None:
268
+ audios = []
269
+
270
  data = {}
271
+ audio_is_continuation = []
272
+ if len(audios) > 0:
273
+ audios = [x.numpy() if isinstance(x, torch.Tensor) else x for x in audios]
274
+
275
+ # Pad out each audio to at least 2 hops (the minimum required by the processor).
276
+ hop_length = self.audio_processor.feature_extractor.hop_length
277
+ audios = [
278
+ (
279
+ np.pad(x, (0, 2 * hop_length - len(x)), mode="constant")
280
+ if len(x) < 2 * hop_length
281
+ else x
282
+ )
283
+ for x in audios
284
+ ]
285
 
286
  # Main audio processing. The processor is model-specific.
287
+ x: transformers.BatchFeature = self.audio_processor(
288
+ audios,
289
  sampling_rate=sampling_rate,
290
  padding="longest",
291
+ pad_to_multiple_of=hop_length, # The attention mask effectively gets padded to the hop length, so pad the audio to be consistent.
292
+ truncation=False,
293
  return_attention_mask=True,
294
  **kwargs,
295
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
+ data.update(
298
+ self._chunk_and_pad_audio(
299
+ audio_values=torch.as_tensor(
300
+ x.input_features if "input_features" in x else x.input_values
301
+ ),
302
+ audio_lens=torch.as_tensor(x.attention_mask).sum(-1),
303
+ include_audio_num_chunks=include_audio_num_chunks,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  )
305
+ )
306
+
307
+ audio_is_continuation = data.pop("audio_is_continuation")
308
+ data["audio_token_len"] = torch.ceil(
309
+ data["audio_lens"] / (self.encoder_ds_factor * self.stack_factor)
310
+ ).to(dtype=torch.int)
311
+
312
+ if text is not None:
313
+ if not isinstance(text, str):
314
+ raise ValueError("Text must be a string. Batch mode not supported yet.")
315
 
316
  # Special tokens like BOS should already have been added by the caller.
317
+ tokenized_parts = self.tokenizer(
318
+ text.split(
319
+ "<|audio|>" # The placeholder isn't part of the vocabulary, so split the text around it.
320
+ ),
321
+ add_special_tokens=False,
322
+ **kwargs,
323
+ )
324
+
325
+ audio_token_start_idx = []
326
+ placeholder_index = -1
327
+ split_input_ids = tokenized_parts["input_ids"]
328
+ input_ids: List[int] = []
329
+
330
+ audio_token_replacement_token_id = self.vocab[self.audio_token_replacement]
331
+
332
+ for i, token_len in enumerate(data.get("audio_token_len", [])):
333
+ if not audio_is_continuation[i]:
334
+ placeholder_index += 1
335
+ if placeholder_index >= len(split_input_ids):
336
+ raise ValueError(
337
+ f"Text contains too few audio placeholders. (Expected {len(audios)} placeholders)"
338
+ )
339
+
340
+ input_ids.extend(split_input_ids[placeholder_index])
341
+
342
+ audio_token_start_idx.append(len(input_ids))
343
+
344
+ input_ids.extend([audio_token_replacement_token_id] * token_len)
345
+
346
+ # Include any tokens after the last audio.
347
+ placeholder_index += 1
348
+ if placeholder_index != len(split_input_ids) - 1:
349
+ raise ValueError(
350
+ f"Text contains too many audio placeholders. (Expected {len(audios)} placeholders)"
351
+ )
352
+ input_ids.extend(split_input_ids[placeholder_index])
353
+
354
+ if "audio_token_len" in data:
355
+ data["audio_token_start_idx"] = torch.as_tensor(audio_token_start_idx)
356
+
357
+ data["input_ids"] = [input_ids]
358
+ data["attention_mask"] = [[1] * len(input_ids)]
359
+
360
+ # Ensure that there are no audio placeholders after the last audio.
361
 
362
  return transformers.BatchFeature(data=data, tensor_type=return_tensors)
363