import torch
from speechbrain.inference.interfaces import Pretrained


class CustomEncoderWav2vec2Classifier(Pretrained):
"""A ready-to-use class for utterance-level classification (e.g, speaker-id,
language-id, emotion recognition, keyword spotting, etc).
The class assumes that an self-supervised encoder like wav2vec2/hubert and a classifier model
are defined in the yaml file. If you want to
convert the predicted index into a corresponding text label, please
provide the path of the label_encoder in a variable called 'lab_encoder_file'
within the yaml.
The class can be used either to run only the encoder (encode_batch()) to
extract embeddings or to run a classification step (classify_batch()).
```
Example
-------
>>> import torchaudio
>>> from speechbrain.pretrained import EncoderClassifier
>>> # Model is downloaded from the speechbrain HuggingFace repo
>>> tmpdir = getfixture("tmpdir")
>>> classifier = EncoderClassifier.from_hparams(
... source="speechbrain/spkrec-ecapa-voxceleb",
... savedir=tmpdir,
... )
>>> # Compute embeddings
>>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
>>> embeddings = classifier.encode_batch(signal)
>>> # Classification
>>> prediction = classifier .classify_batch(signal)
"""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encodes the input audio into a single vector embedding.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = <this>.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. Make sure the sample rate is fs=16000 Hz.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and the others len(waveform) / max_length.
            Used for ignoring padding.
        normalize : bool
            If True, the embeddings are normalized with the statistics
            contained in mean_var_norm_emb.

        Returns
        -------
        torch.Tensor
            The encoded batch
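
        Example
        -------
        >>> # A minimal sketch: assumes `classifier` was already loaded with
        >>> # CustomEncoderWav2vec2Classifier.from_hparams (see class example).
        >>> signal = torch.rand([1, 16000])  # one second of dummy 16 kHz audio
        >>> embeddings = classifier.encode_batch(signal)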
"""
# Manage single waveforms in input
if len(wavs.shape) == 1:
wavs = wavs.unsqueeze(0)
# Assign full length if wav_lens is not assigned
if wav_lens is None:
wav_lens = torch.ones(wavs.shape[0], device=self.device)
# Storing waveform in the specified device
wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
wavs = wavs.float()
# Computing features and embeddings
outputs = self.mods.wav2vec2(wavs)
# last dim will be used for AdaptativeAVG pool
outputs = self.mods.avg_pool(outputs, wav_lens)
outputs = outputs.view(outputs.shape[0], -1)
return outputs
    def classify_batch(self, wavs, wav_lens=None):
        """Performs classification on top of the encoded features.

        It returns the posterior probabilities, the index of the best class
        and, if the label encoder is specified, the corresponding text label.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. Make sure the sample rate is fs=16000 Hz.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and the others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        out_prob
            The log posterior probabilities of each class ([batch, N_class])
        score
            The value of the log-posterior for the best class ([batch,])
        index
            The indexes of the best class ([batch,])
        text_lab
            List with the text labels corresponding to the indexes.
            (A label encoder should be provided.)
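
        Example
        -------
        >>> # A minimal sketch: assumes `classifier` was already loaded with
        >>> # CustomEncoderWav2vec2Classifier.from_hparams (see class example).
        >>> signal = torch.rand([1, 16000])  # one second of dummy 16 kHz audio
        >>> out_prob, score, index, text_lab = classifier.classify_batch(signal)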
"""
outputs = self.encode_batch(wavs, wav_lens)
outputs = self.mods.output_mlp(outputs)
out_prob = self.hparams.softmax(outputs)
score, index = torch.max(out_prob, dim=-1)
text_lab = self.hparams.label_encoder.decode_torch(index)
return out_prob, score, index, text_lab
    def classify_file(self, path):
        """Classifies the given audio file into the given set of labels.

        Arguments
        ---------
        path : str
            Path to audio file to classify.

        Returns
        -------
        out_prob
            The log posterior probabilities of each class ([batch, N_class])
        score
            The value of the log-posterior for the best class ([batch,])
        index
            The indexes of the best class ([batch,])
        text_lab
            List with the text labels corresponding to the indexes.
            (A label encoder should be provided.)
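
        Example
        -------
        >>> # A minimal sketch: assumes `classifier` was already loaded with
        >>> # CustomEncoderWav2vec2Classifier.from_hparams (see class example)
        >>> # and that 'audio.wav' is a placeholder path to a 16 kHz recording.
        >>> out_prob, score, index, text_lab = classifier.classify_file("audio.wav")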
"""
waveform = self.load_audio(path)
# Fake a batch:
batch = waveform.unsqueeze(0)
rel_length = torch.tensor([1.0])
outputs = self.encode_batch(batch, rel_length)
outputs = self.mods.output_mlp(outputs).squeeze(1)
out_prob = self.hparams.softmax(outputs)
score, index = torch.max(out_prob, dim=-1)
text_lab = self.hparams.label_encoder.decode_torch(index)
return out_prob, score, index, text_lab
    def forward(self, wavs, wav_lens=None, normalize=False):
        """Runs the encoder and returns the embeddings (see encode_batch())."""
        return self.encode_batch(
            wavs=wavs, wav_lens=wav_lens, normalize=normalize
        )
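

# A minimal usage sketch, not part of the interface itself. The model source
# and audio path below are hypothetical placeholders: point them at a local
# directory (or HuggingFace repo) containing this class's hyperparams.yaml.
if __name__ == "__main__":
    classifier = CustomEncoderWav2vec2Classifier.from_hparams(
        source="path/to/pretrained_model",  # hypothetical placeholder
        savedir="pretrained_model",
    )
    out_prob, score, index, text_lab = classifier.classify_file(
        "path/to/audio.wav"  # hypothetical placeholder, 16 kHz audio
    )
    print(f"Predicted: {text_lab} (log-posterior: {score.item():.3f})")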