return_timestamps="word" does not work for this model

#3 opened by justkr

Hello,

I would like to use marianbasti/distil-whisper-large-v3-es with return_timestamps="word" during inference, but unfortunately I receive an error: IndexError: list index out of range.
A similar error was noticed and corrected for whisper-large-v3: https://github.com/huggingface/transformers/issues/31683

I am now on transformers==4.47.0, and return_timestamps="word" works for both whisper-large-v3 and distil-whisper/distil-large-v3, but not for marianbasti/distil-whisper-large-v3-es.

My script:

from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
pipeline,
)
import torch

model_id = "marianbasti/distil-whisper-large-v3-es"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
use_safetensors=True,
).to("cpu")

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch.float32,
device="cpu",
model_kwargs={"attn_implementation": 'sdpa'},
generate_kwargs={"max_new_tokens": 128},
chunk_length_s=25,
batch_size=16,
)

results = pipe("audio.flac", return_timestamps="word")
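
For reference, on a model where word timestamps do work (e.g. distil-whisper/distil-large-v3), the same call returns a dict with a "chunks" list holding one entry per word. The timestamps below are invented values, just to illustrate the shape:

# Illustrative output (values are made up):
# {
#     "text": " Hola mundo",
#     "chunks": [
#         {"text": " Hola", "timestamp": (0.0, 0.42)},
#         {"text": " mundo", "timestamp": (0.42, 0.88)},
#     ],
# }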

Please let me know if you have the same problem.
Thanks!

Hi there!

Could you provide more of the stack trace? Have you tried the pipeline settings described in the model card?
I haven't used word-level timestamps myself yet.

Cheers!

Hello,

I have tested the code from the model card, but with return_timestamps="word":

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "marianbasti/distil-whisper-large-v3-es"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample, return_timestamps="word")

The error is the same:

IndexError Traceback (most recent call last)
Cell In[68], line 23
21 dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
22 sample = dataset[0]["audio"]
---> 23 result = pipe(sample, return_timestamps="word")

File /anaconda/envs/myenv_esp/lib/python3.10/site-packages/transformers/pipelines/automatic_speech_recognition.py:283, in AutomaticSpeechRecognitionPipeline.__call__(self, inputs, **kwargs)
222 def __call__(
223 self,
224 inputs: Union[np.ndarray, bytes, str],
225 **kwargs,
226 ):
227 """
228 Transcribe the audio sequence(s) given as inputs to text. See the [AutomaticSpeechRecognitionPipeline]
229 documentation for more information.
(...)
281 "".join(chunk["text"] for chunk in output["chunks"]).
282 """
--> 283 return super().__call__(inputs, **kwargs)

File /anaconda/envs/myenv_esp/lib/python3.10/site-packages/transformers/pipelines/base.py:1293, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
1291 return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
1292 elif self.framework == "pt" and isinstance(self, ChunkPipeline):
-> 1293 return next(
1294 iter(
1295 self.get_iterator(
1296 [inputs], num_workers, batch_size, preprocess_params, forward_params, postprocess_params
1297 )
1298 )
1299 )
1300 else:
1301 return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

File /anaconda/envs/myenv_esp/lib/python3.10/site-packages/transformers/pipelines/pt_utils.py:124, in PipelineIterator.__next__(self)
121 return self.loader_batch_item()
123 # We're out of items within a batch
--> 124 item = next(self.iterator)
125 processed = self.infer(item, **self.params)
126 # We now have a batch of "inferred things".

File /anaconda/envs/myenv_esp/lib/python3.10/site-packages/transformers/pipelines/pt_utils.py:269, in PipelinePackIterator.__next__(self)
266 return accumulator
268 while not is_last:
--> 269 processed = self.infer(next(self.iterator), **self.params)
270 if self.loader_batch_size is not None:
271 if isinstance(processed, torch.Tensor):

File /anaconda/envs/myenv_esp/lib/python3.10/site-packages/transformers/pipelines/base.py:1208, in Pipeline.forward(self, model_inputs, **forward_params)
1206 with inference_context():
1207 model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
-> 1208 model_outputs = self._forward(model_inputs, **forward_params)
1209 model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
1210 else:

File /anaconda/envs/myenv_esp/lib/python3.10/site-packages/transformers/pipelines/automatic_speech_recognition.py:515, in AutomaticSpeechRecognitionPipeline._forward(self, model_inputs, return_timestamps, **generate_kwargs)
512 if "generation_config" not in generate_kwargs:
513 generate_kwargs["generation_config"] = self.generation_config
--> 515 tokens = self.model.generate(
516 inputs=inputs,
517 attention_mask=attention_mask,
518 **generate_kwargs,
519 )
520 # whisper longform generation stores timestamps in "segments"
521 if return_timestamps == "word" and self.type == "seq2seq_whisper":

File /anaconda/envs/myenv_esp/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:688, in WhisperGenerationMixin.generate(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, prompt_condition_type, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, time_precision_features, return_token_timestamps, return_segments, return_dict_in_generate, **kwargs)
679 proc.set_begin_index(decoder_input_ids.shape[-1])
681 # 6.6 Run generate with fallback
682 (
683 seek_sequences,
684 seek_outputs,
685 should_skip,
686 do_condition_on_prev_tokens,
687 model_output_type,
--> 688 ) = self.generate_with_fallback(
689 segment_input=segment_input,
690 decoder_input_ids=decoder_input_ids,
691 cur_bsz=cur_bsz,
692 batch_idx_map=batch_idx_map,
693 seek=seek,
694 num_segment_frames=num_segment_frames,
695 max_frames=max_frames,
696 temperatures=temperatures,
697 generation_config=generation_config,
698 logits_processor=logits_processor,
699 stopping_criteria=stopping_criteria,
700 prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
701 synced_gpus=synced_gpus,
702 return_token_timestamps=return_token_timestamps,
703 do_condition_on_prev_tokens=do_condition_on_prev_tokens,
704 is_shortform=is_shortform,
705 batch_size=batch_size,
706 attention_mask=attention_mask,
707 kwargs=kwargs,
708 )
710 # 6.7 In every generated sequence, split by timestamp tokens and extract segments
711 for i, seek_sequence in enumerate(seek_sequences):

File /anaconda/envs/myenv_esp/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:867, in WhisperGenerationMixin.generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, batch_idx_map, seek, num_segment_frames, max_frames, temperatures, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_token_timestamps, do_condition_on_prev_tokens, is_shortform, batch_size, attention_mask, kwargs)
864 model_output_type = type(seek_outputs)
866 # post-process sequence tokens and outputs to be in list form
--> 867 seek_sequences, seek_outputs = self._postprocess_outputs(
868 seek_outputs=seek_outputs,
869 decoder_input_ids=decoder_input_ids,
870 return_token_timestamps=return_token_timestamps,
871 generation_config=generation_config,
872 is_shortform=is_shortform,
873 )
875 if cur_bsz < batch_size:
876 seek_sequences = seek_sequences[:cur_bsz]

File /anaconda/envs/myenv_esp/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:968, in WhisperGenerationMixin._postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform)
966 if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
967 num_frames = getattr(generation_config, "num_frames", None)
--> 968 seek_outputs["token_timestamps"] = self._extract_token_timestamps(
969 seek_outputs,
970 generation_config.alignment_heads,
971 num_frames=num_frames,
972 num_input_ids=decoder_input_ids.shape[-1],
973 )
974 seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
976 seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]

File /anaconda/envs/myenv_esp/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:195, in WhisperGenerationMixin._extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision, num_frames, num_input_ids)
191 cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
193 # Select specific cross-attention layers and heads. This is a tensor
194 # of shape (batch size, num selected, output length, input length).
--> 195 weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
196 weights = weights.permute([1, 0, 2, 3])
198 weight_length = None

File /anaconda/envs/myenv_esp/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:195, in <listcomp>(.0)
191 cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
193 # Select specific cross-attention layers and heads. This is a tensor
194 # of shape (batch size, num selected, output length, input length).
--> 195 weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
196 weights = weights.permute([1, 0, 2, 3])
198 weight_length = None

IndexError: list index out of range
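
For what it's worth, the failing line stacks cross_attentions[l][:, h] for every (l, h) pair in generation_config.alignment_heads, and cross_attentions holds one tensor per decoder layer. So my guess (unverified) is that this model's alignment_heads reference decoder layers that the distilled model does not have. A minimal diagnostic sketch, assuming that is the cause:

from transformers import AutoConfig, GenerationConfig

model_id = "marianbasti/distil-whisper-large-v3-es"
num_layers = AutoConfig.from_pretrained(model_id).decoder_layers
gen_config = GenerationConfig.from_pretrained(model_id)
heads = getattr(gen_config, "alignment_heads", None)

print("decoder_layers:", num_layers)
print("alignment_heads:", heads)
if heads is not None:
    # Any (layer, head) pair whose layer index is >= the decoder depth
    # would raise exactly this IndexError in _extract_token_timestamps.
    print("out-of-range:", [(l, h) for l, h in heads if l >= num_layers])

If that prints out-of-range pairs, one possible (untested) workaround would be to borrow the alignment_heads from distil-whisper/distil-large-v3, assuming this fine-tune kept the same layout; even then the timestamp quality is not guaranteed, since those heads were selected for the original model:

donor = GenerationConfig.from_pretrained("distil-whisper/distil-large-v3")
model.generation_config.alignment_heads = donor.alignment_heads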

Kind regards!
