Automatic speech recognition
Automatic speech recognition (ASR) converts a speech signal to text. It is an example of a sequence-to-sequence task, going from a sequence of audio inputs to textual outputs. Voice assistants like Siri and Alexa utilize ASR models to assist users.
This guide will show you how to fine-tune Wav2Vec2 on the TIMIT dataset to transcribe audio to text.
See the automatic speech recognition task page for more information about its associated models, datasets, and metrics.
Load TIMIT dataset
Load the TIMIT dataset from the 🤗 Datasets library:
>>> from datasets import load_dataset
>>> timit = load_dataset("timit_asr")
Then take a look at the dataset structure:
>>> timit
DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 4620
    })
    test: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 1680
    })
})
While the dataset contains a lot of helpful information, like dialect_region and sentence_type, you will focus on the audio and text fields in this guide. Remove the other columns:
>>> timit = timit.remove_columns(
...     ["phonetic_detail", "word_detail", "dialect_region", "id", "sentence_type", "speaker_id"]
... )
Take a look at the example again:
>>> timit["train"][0]
{'audio': {'array': array([-2.1362305e-04, 6.1035156e-05, 3.0517578e-05, ...,
-3.0517578e-05, -9.1552734e-05, -6.1035156e-05], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV',
'sampling_rate': 16000},
'file': '/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV',
'text': 'Would such an act of refusal be useful?'}
The audio column contains a 1-dimensional array of the speech signal, along with its path and sampling_rate. You must call the audio column to load and resample the audio file.
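As a quick sanity check on a decoded example, the clip duration follows from the array length and the sampling rate. The sketch below uses a hypothetical stand-in dictionary (not real TIMIT data) with the same array and sampling_rate keys as the audio field shown above. The helper clip_duration_seconds is illustrative only and is not part of 🤗 Datasets.

```python
def clip_duration_seconds(audio: dict) -> float:
    """Return the duration of a decoded audio example in seconds."""
    return len(audio["array"]) / audio["sampling_rate"]


# Hypothetical stand-in for timit["train"][0]["audio"]:
# 32,000 samples of silence at 16 kHz, i.e. a 2-second clip.
example_audio = {"array": [0.0] * 32000, "sampling_rate": 16000}

duration = clip_duration_seconds(example_audio)
print(duration)  # 2.0
```

With the real dataset, you would pass timit["train"][0]["audio"] instead of the stand-in.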
Preprocess
Load the Wav2Vec2 processor to process the audio signal and transcribed text:
>>> from transformers import AutoProcessor
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
The preprocessing function needs to:
- Call the audio column to load and resample the audio file.
- Extract the input_values from the audio file.
- Typically, when you call the processor, you call the feature extractor. Since you also want to tokenize text, instruct the processor to call the tokenizer instead with a context manager.
>>> def prepare_dataset(batch):
...     audio = batch["audio"]
...     batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
...     batch["input_length"] = len(batch["input_values"])
...     with processor.as_target_processor():
...         batch["labels"] = processor(batch["text"]).input_ids
...     return batch
Use the 🤗 Datasets map function to apply the preprocessing function over the entire dataset. You can speed up the map function by increasing the number of processes with num_proc. Remove the columns you don’t need:
>>> timit = timit.map(prepare_dataset, remove_columns=timit.column_names["train"], num_proc=4)
🤗 Transformers doesn’t have a data collator for automatic speech recognition, so you will need to create one. You can adapt the DataCollatorWithPadding to create a batch of examples for automatic speech recognition. It will also dynamically pad your inputs and labels to the length of the longest element in the batch, so they are a uniform length. While it is possible to pad your text in the tokenizer function by setting padding=True, dynamic padding is more efficient.
Unlike other data collators, this specific data collator needs to apply a different padding method to input_values and labels. You can apply a different padding method with a context manager:
>>> import torch
>>> from dataclasses import dataclass, field
>>> from typing import Any, Dict, List, Optional, Union
>>> @dataclass
... class DataCollatorCTCWithPadding:
...     processor: AutoProcessor
...     padding: Union[bool, str] = True
...
...     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
...         # split inputs and labels since they have to be of different lengths and need
...         # different padding methods
...         input_features = [{"input_values": feature["input_values"]} for feature in features]
...         label_features = [{"input_ids": feature["labels"]} for feature in features]
...         batch = self.processor.pad(
...             input_features,
...             padding=self.padding,
...             return_tensors="pt",
...         )
...         with self.processor.as_target_processor():
...             labels_batch = self.processor.pad(
...                 label_features,
...                 padding=self.padding,
...                 return_tensors="pt",
...             )
...         # replace padding with -100 to ignore loss correctly
...         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
...         batch["labels"] = labels
...         return batch
Create a batch of examples and dynamically pad them with DataCollatorCTCWithPadding:
>>> data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
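To make the effect of the collator concrete, here is a minimal pure-Python sketch of the same idea: inputs and labels are padded separately, each to the longest item in its batch, and label padding is set to -100 so the loss ignores it. The function pad_batch and the toy features are hypothetical. The input padding value of 0.0 is an assumption made for illustration, and the real collator delegates padding to the processor and returns PyTorch tensors rather than lists.

```python
from typing import Dict, List


def pad_batch(
    features: List[Dict[str, list]],
    input_pad: float = 0.0,  # assumed padding value for audio inputs
    label_pad: int = -100,   # label padding that the loss ignores
) -> Dict[str, list]:
    """Pad inputs and labels separately, each to the longest item in the batch."""
    max_input = max(len(f["input_values"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    return {
        "input_values": [
            f["input_values"] + [input_pad] * (max_input - len(f["input_values"]))
            for f in features
        ],
        "labels": [
            f["labels"] + [label_pad] * (max_label - len(f["labels"]))
            for f in features
        ],
    }


# Toy features: note that the longest input and the longest label
# belong to different examples, so each needs its own target length.
toy_features = [
    {"input_values": [0.1, 0.2, 0.3], "labels": [5, 6]},
    {"input_values": [0.4], "labels": [7, 8, 9]},
]

padded = pad_batch(toy_features)
print(padded["input_values"])  # [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0]]
print(padded["labels"])        # [[5, 6, -100], [7, 8, 9]]
```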
Train
Load Wav2Vec2 with AutoModelForCTC. For ctc_loss_reduction, it is often better to use the average instead of the default summation:
>>> from transformers import AutoModelForCTC, TrainingArguments, Trainer
>>> model = AutoModelForCTC.from_pretrained(
...     "facebook/wav2vec2-base",
...     ctc_loss_reduction="mean",
...     pad_token_id=processor.tokenizer.pad_token_id,
... )
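The difference between the two reduction modes can be illustrated with a small sketch. It assumes the semantics documented for PyTorch's CTC loss, where sum adds the per-example losses and mean first divides each loss by its target length and then averages over the batch. The function reduce_ctc_losses and the loss values are hypothetical and only mimic that arithmetic. No actual CTC loss is computed here.

```python
from typing import List


def reduce_ctc_losses(losses: List[float], target_lengths: List[int], reduction: str = "mean") -> float:
    """Combine per-example losses the way the assumed reduction modes do."""
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        # Normalize each loss by its target length, then average over the batch.
        normalized = [loss / length for loss, length in zip(losses, target_lengths)]
        return sum(normalized) / len(normalized)
    raise ValueError(f"Unsupported reduction: {reduction}")


# Hypothetical per-example losses and label lengths.
losses = [40.0, 60.0]
target_lengths = [10, 20]

print(reduce_ctc_losses(losses, target_lengths, "sum"))   # 100.0
print(reduce_ctc_losses(losses, target_lengths, "mean"))  # 3.5

# Doubling the batch doubles the summed loss but leaves the mean unchanged.
print(reduce_ctc_losses(losses * 2, target_lengths * 2, "sum"))   # 200.0
print(reduce_ctc_losses(losses * 2, target_lengths * 2, "mean"))  # 3.5
```

Under these assumptions, the averaged loss keeps roughly the same scale across batch sizes and transcript lengths, which is one reason it can be easier to tune a learning rate against.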
If you aren’t familiar with fine-tuning a model with the Trainer, take a look at the basic fine-tuning tutorial first.
At this point, only three steps remain:
- Define your training hyperparameters in TrainingArguments.
- Pass the training arguments to Trainer along with the model, datasets, tokenizer, and data collator.
- Call train() to fine-tune your model.
>>> training_args = TrainingArguments(
...     output_dir="./results",
...     group_by_length=True,
...     per_device_train_batch_size=16,
...     evaluation_strategy="steps",
...     num_train_epochs=3,
...     fp16=True,
...     gradient_checkpointing=True,
...     learning_rate=1e-4,
...     weight_decay=0.005,
...     save_total_limit=2,
... )
>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=timit["train"],
...     eval_dataset=timit["test"],
...     tokenizer=processor.feature_extractor,
...     data_collator=data_collator,
... )
>>> trainer.train()
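The training arguments above set group_by_length=True, which batches examples of similar length together so that dynamic padding wastes less computation. The sketch below is a simplified illustration of why this helps and is not the sampler the Trainer actually uses. The function padded_total and the lengths are hypothetical.

```python
from typing import List


def padded_total(lengths: List[int], batch_size: int) -> int:
    """Total number of padded positions when each batch is padded to its longest item."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        chunk = lengths[start:start + batch_size]
        total += max(chunk) * len(chunk)
    return total


# Hypothetical input lengths for six examples, mixing short and long clips.
lengths = [10, 100, 12, 95, 11, 98]

unsorted_cost = padded_total(lengths, batch_size=3)
grouped_cost = padded_total(sorted(lengths), batch_size=3)

print(sum(lengths))   # 326 positions of real data
print(unsorted_cost)  # 594 positions after padding mixed-length batches
print(grouped_cost)   # 336 positions after padding length-grouped batches
```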