brdhaker3 committed
Commit 3342e5f · verified · 1 Parent(s): 052481b

Delete 1234

1234/app.py DELETED
@@ -1,464 +0,0 @@
- #!/usr/bin/env python3
- import sys
- import torch
- import logging
- import gradio as gr
- import speechbrain as sb
- from pathlib import Path
- import os
- import torchaudio
- from hyperpyyaml import load_hyperpyyaml
- from speechbrain.tokenizers.SentencePiece import SentencePiece
- from speechbrain.utils.data_utils import undo_padding
- from speechbrain.utils.distributed import run_on_main
-
- """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
- The system employs a wav2vec2 encoder and a CTC decoder.
- Decoding is performed with greedy decoding (will be extended to beam search).
-
- To run this recipe, do the following:
- > python train_with_wav2vec2.py hparams/train_with_wav2vec2.yaml
-
- With the default hyperparameters, the system employs a pretrained wav2vec2 encoder.
- The wav2vec2 model is pretrained following the model given in the hparams file.
- It may be dependent on the language.
-
- The neural network is trained with CTC on sub-word units estimated with
- Byte-Pair Encoding (BPE).
-
- The experiment file is flexible enough to support a large variety of
- different systems. By properly changing the parameter files, you can try
- different encoders, decoders, tokens (e.g., characters instead of BPE),
- training languages (all CommonVoice languages), and many
- other possible variations.
-
- Authors
-  * Titouan Parcollet 2021
- """
-
- logger = logging.getLogger(__name__)
-
-
- # Define training procedure
- class ASR(sb.core.Brain):
-     def compute_forward(self, batch, stage):
-         """Forward computations from the waveform batches to the output probabilities."""
-         batch = batch.to(self.device)
-         wavs, wav_lens = batch.sig
-         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
-         if stage == sb.Stage.TRAIN:
-             if hasattr(self.hparams, "augmentation"):
-                 wavs = self.hparams.augmentation(wavs, wav_lens)
-
-         # Forward pass
-         feats = self.modules.wav2vec2(wavs, wav_lens)
-         x = self.modules.enc(feats)
-         logits = self.modules.ctc_lin(x)
-         p_ctc = self.hparams.log_softmax(logits)
-
-         return p_ctc, wav_lens
-
-     def treat_wav(self, sig):
-         """Transcribe a single waveform with the module-level pyctcdecode decoder."""
-         feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
-         feats = self.modules.enc(feats)
-         logits = self.modules.ctc_lin(feats)
-         p_ctc = self.hparams.log_softmax(logits)
-         predicted_words = []
-         for logs in p_ctc:
-             text = decoder.decode(logs.detach().cpu().numpy())
-             predicted_words.append(text.split(" "))
-         return " ".join(predicted_words[0])
-
-     def compute_objectives(self, predictions, batch, stage):
-         """Computes the loss (CTC) given predictions and targets."""
-         p_ctc, wav_lens = predictions
-
-         ids = batch.id
-         tokens, tokens_lens = batch.tokens
-
-         loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
-
-         if stage != sb.Stage.TRAIN:
-             predicted_tokens = sb.decoders.ctc_greedy_decode(
-                 p_ctc, wav_lens, blank_id=self.hparams.blank_index
-             )
-             # Decode token terms to words
-             if self.hparams.use_language_modelling:
-                 predicted_words = []
-                 for logs in p_ctc:
-                     text = decoder.decode(logs.detach().cpu().numpy())
-                     predicted_words.append(text.split(" "))
-             else:
-                 predicted_words = [
-                     "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
-                     for utt_seq in predicted_tokens
-                 ]
-             # Convert indices to words
-             target_words = [wrd.split(" ") for wrd in batch.wrd]
-
-             self.wer_metric.append(ids, predicted_words, target_words)
-             self.cer_metric.append(ids, predicted_words, target_words)
-
-         return loss
-
-     def fit_batch(self, batch):
-         """Train the parameters given a single batch in input"""
-         should_step = self.step % self.grad_accumulation_factor == 0
-         # Managing automatic mixed precision
-         # TOFIX: CTC fine-tuning currently is unstable
-         # This is certainly due to CTC being done in fp16 instead of fp32
-         if self.auto_mix_prec:
-             with torch.cuda.amp.autocast():
-                 with self.no_sync():
-                     outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                 loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-             with self.no_sync(not should_step):
-                 self.scaler.scale(
-                     loss / self.grad_accumulation_factor
-                 ).backward()
-             if should_step:
-                 if not self.hparams.wav2vec2.freeze:
-                     self.scaler.unscale_(self.wav2vec_optimizer)
-                 self.scaler.unscale_(self.model_optimizer)
-                 if self.check_gradients(loss):
-                     if not self.hparams.wav2vec2.freeze:
-                         if self.optimizer_step >= self.hparams.warmup_steps:
-                             self.scaler.step(self.wav2vec_optimizer)
-                     self.scaler.step(self.model_optimizer)
-                     self.scaler.update()
-                 self.zero_grad()
-                 self.optimizer_step += 1
-         else:
-             # This is mandatory because HF models have a weird behavior with DDP
-             # on the forward pass
-             with self.no_sync():
-                 outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-
-             loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-
-             with self.no_sync(not should_step):
-                 (loss / self.grad_accumulation_factor).backward()
-             if should_step:
-                 if self.check_gradients(loss):
-                     if not self.hparams.wav2vec2.freeze:
-                         if self.optimizer_step >= self.hparams.warmup_steps:
-                             self.wav2vec_optimizer.step()
-                     self.model_optimizer.step()
-                 self.zero_grad()
-                 self.optimizer_step += 1
-
-         self.on_fit_batch_end(batch, outputs, loss, should_step)
-         return loss.detach().cpu()
-
-     def evaluate_batch(self, batch, stage):
-         """Computations needed for validation/test batches"""
-         predictions = self.compute_forward(batch, stage=stage)
-         with torch.no_grad():
-             loss = self.compute_objectives(predictions, batch, stage=stage)
-         return loss.detach()
-
-     def on_stage_start(self, stage, epoch):
-         """Gets called at the beginning of each epoch"""
-         if stage != sb.Stage.TRAIN:
-             self.cer_metric = self.hparams.cer_computer()
-             self.wer_metric = self.hparams.error_rate_computer()
-
-     def on_stage_end(self, stage, stage_loss, epoch):
-         """Gets called at the end of an epoch."""
-         # Compute/store important stats
-         stage_stats = {"loss": stage_loss}
-         if stage == sb.Stage.TRAIN:
-             self.train_stats = stage_stats
-         else:
-             stage_stats["CER"] = self.cer_metric.summarize("error_rate")
-             stage_stats["WER"] = self.wer_metric.summarize("error_rate")
-
-         # Perform end-of-iteration things, like annealing, logging, etc.
-         if stage == sb.Stage.VALID:
-             old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
-                 stage_stats["loss"]
-             )
-             old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
-                 stage_stats["loss"]
-             )
-             sb.nnet.schedulers.update_learning_rate(
-                 self.model_optimizer, new_lr_model
-             )
-             if not self.hparams.wav2vec2.freeze:
-                 sb.nnet.schedulers.update_learning_rate(
-                     self.wav2vec_optimizer, new_lr_wav2vec
-                 )
-             self.hparams.train_logger.log_stats(
-                 stats_meta={
-                     "epoch": epoch,
-                     "lr_model": old_lr_model,
-                     "lr_wav2vec": old_lr_wav2vec,
-                 },
-                 train_stats=self.train_stats,
-                 valid_stats=stage_stats,
-             )
-             self.checkpointer.save_and_keep_only(
-                 meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
-             )
-         elif stage == sb.Stage.TEST:
-             self.hparams.train_logger.log_stats(
-                 stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
-                 test_stats=stage_stats,
-             )
-             with open(self.hparams.wer_file, "w") as w:
-                 self.wer_metric.write_stats(w)
-
-     def init_optimizers(self):
-         "Initializes the wav2vec2 optimizer and model optimizer"
-         # If the wav2vec encoder is unfrozen, we create the optimizer
-         if not self.hparams.wav2vec2.freeze:
-             self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
-                 self.modules.wav2vec2.parameters()
-             )
-             if self.checkpointer is not None:
-                 self.checkpointer.add_recoverable(
-                     "wav2vec_opt", self.wav2vec_optimizer
-                 )
-
-         self.model_optimizer = self.hparams.model_opt_class(
-             self.hparams.model.parameters()
-         )
-
-         if self.checkpointer is not None:
-             self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
-
-     def zero_grad(self, set_to_none=False):
-         if not self.hparams.wav2vec2.freeze:
-             self.wav2vec_optimizer.zero_grad(set_to_none)
-         self.model_optimizer.zero_grad(set_to_none)
-
-
- # Define custom data procedure
- def dataio_prepare(hparams):
-     """This function prepares the datasets to be used in the brain class.
-     It also defines the data processing pipeline through user-defined functions."""
-
-     # 1. Define datasets
-     data_folder = hparams["data_folder"]
-
-     train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
-         csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
-     )
-
-     if hparams["sorting"] == "ascending":
-         # we sort training data to speed up training and get better results.
-         train_data = train_data.filtered_sorted(
-             sort_key="duration",
-             key_max_value={"duration": hparams["avoid_if_longer_than"]},
-         )
-         # when sorting, do not shuffle in the dataloader! otherwise it is pointless
-         hparams["dataloader_options"]["shuffle"] = False
-
-     elif hparams["sorting"] == "descending":
-         train_data = train_data.filtered_sorted(
-             sort_key="duration",
-             reverse=True,
-             key_max_value={"duration": hparams["avoid_if_longer_than"]},
-         )
-         # when sorting, do not shuffle in the dataloader! otherwise it is pointless
-         hparams["dataloader_options"]["shuffle"] = False
-
-     elif hparams["sorting"] == "random":
-         pass
-
-     else:
-         raise NotImplementedError(
-             "sorting must be random, ascending or descending"
-         )
-
-     valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
-         csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
-     )
-     # We also sort the validation data so it is faster to validate
-     valid_data = valid_data.filtered_sorted(sort_key="duration")
-     test_datasets = {}
-     for csv_file in hparams["test_csv"]:
-         name = Path(csv_file).stem
-         test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
-             csv_path=csv_file, replacements={"data_root": data_folder}
-         )
-         test_datasets[name] = test_datasets[name].filtered_sorted(
-             sort_key="duration"
-         )
-
-     datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
-
-     # 2. Define audio pipeline:
-     @sb.utils.data_pipeline.takes("wav")
-     @sb.utils.data_pipeline.provides("sig")
-     def audio_pipeline(wav):
-         info = torchaudio.info(wav)
-         sig = sb.dataio.dataio.read_audio(wav)
-         resampled = torchaudio.transforms.Resample(
-             info.sample_rate, hparams["sample_rate"],
-         )(sig)
-         return resampled
-
-     sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
-     label_encoder = sb.dataio.encoder.CTCTextEncoder()
-
-     # 3. Define text pipeline:
-     @sb.utils.data_pipeline.takes("wrd")
-     @sb.utils.data_pipeline.provides(
-         "wrd", "char_list", "tokens_list", "tokens"
-     )
-     def text_pipeline(wrd):
-         yield wrd
-         char_list = list(wrd)
-         yield char_list
-         tokens_list = label_encoder.encode_sequence(char_list)
-         yield tokens_list
-         tokens = torch.LongTensor(tokens_list)
-         yield tokens
-
-     sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
-     lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
-     special_labels = {
-         "blank_label": hparams["blank_index"],
-         "unk_label": hparams["unk_index"],
-     }
-     label_encoder.load_or_create(
-         path=lab_enc_file,
-         from_didatasets=[train_data],
-         output_key="char_list",
-         special_labels=special_labels,
-         sequence_input=True,
-     )
-
-     # 4. Set output:
-     sb.dataio.dataset.set_output_keys(
-         datasets, ["id", "sig", "wrd", "char_list", "tokens"],
-     )
-     return train_data, valid_data, test_datasets, label_encoder
-
-
- # Load hyperparameters file with command-line overrides
- hparams_file, run_opts, overrides = sb.parse_arguments(["train_semi.yaml"])
- with open(hparams_file) as fin:
-     hparams = load_hyperpyyaml(fin, overrides)
-
- # If --distributed_launch then
- # create ddp_group with the right communication protocol
- sb.utils.distributed.ddp_init_group(run_opts)
-
- # Create experiment directory
- sb.create_experiment_directory(
-     experiment_directory=hparams["output_folder"],
-     hyperparams_to_save=hparams_file,
-     overrides=overrides,
- )
-
- # Due to DDP, we do the preparation ONLY on the main python process
- # Defining tokenizer and loading it
- # Create the datasets objects as well as tokenization and encoding :-D
- label_encoder = sb.dataio.encoder.CTCTextEncoder()
-
- lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
- special_labels = {
-     "blank_label": hparams["blank_index"],
-     "unk_label": hparams["unk_index"],
- }
- label_encoder.load_or_create(
-     path=lab_enc_file,
-     from_didatasets=[[]],
-     output_key="char_list",
-     special_labels=special_labels,
-     sequence_input=True,
- )
-
- from pyctcdecode import build_ctcdecoder
-
- ind2lab = label_encoder.ind2lab
- print(ind2lab)
- labels = [ind2lab[x] for x in range(len(ind2lab))]
- # Replace the <blank> token with a blank character, needed for PyCTCdecode
- labels = [""] + labels[1:-1] + ["1"]
- print(labels)
- decoder = build_ctcdecoder(
-     labels,
-     kenlm_model_path=hparams["ngram_lm_path"],  # .arpa or .bin
-     alpha=0.5,  # Default by KenLM
-     beta=1.0,  # Default by KenLM
- )
-
- # Trainer initialization
- run_opts["device"] = "cpu"
- asr_brain = ASR(
-     modules=hparams["modules"],
-     hparams=hparams,
-     run_opts=run_opts,
-     checkpointer=hparams["checkpointer"],
- )
-
- # Adding objects to trainer.
- asr_brain.tokenizer = label_encoder
- asr_brain.checkpointer.recover_if_possible(device="cpu")
- asr_brain.modules.eval()
-
- description = """This is a SpeechBrain-based Automatic Speech Recognition (ASR) model for Tunisian Arabic. It outputs Tunisian Arabic transcriptions written in Arabic characters.
- The model outputs transcriptions in the Arabic alphabet only and performs poorly on sentences containing foreign words. If you do need code-switching in your transcripts, i.e. foreign words rendered in the Latin alphabet, you are better served by the code-switched model, available in another space from the same author. (https://huggingface.co/SalahZa/Code_Switched_Tunisian_Speech_Recognition)
-
- Inference runs on CPU to keep this space free, which leads to quite long running times on long sequences. If you want to transcribe long sequences for your project or research, use the model directly from its page; instructions for inference on a test set are provided there. (https://huggingface.co/SalahZa/Tunisian_Automatic_Speech_Recognition). If you need help, feel free to drop an email here: [email protected]
-
- Authors:
- * [Salah Zaiem](https://fr.linkedin.com/in/salah-zaiem)
- * [Ahmed Amine Ben Aballah](https://www.linkedin.com/in/aabenz/)
- * [Ata Kaboudi](https://www.linkedin.com/in/ata-kaboudi-63365b1a8)
- * [Amir Kanoun](https://tn.linkedin.com/in/ahmed-amir-kanoun)
-
- More in-depth details and insights are available in a released preprint. Please find the paper [here](https://arxiv.org/abs/2309.11327).
- If you use or refer to this model, please cite:
-
- ```
- @misc{abdallah2023leveraging,
-       title={Leveraging Data Collection and Unsupervised Learning for Code-switched Tunisian Arabic Automatic Speech Recognition},
-       author={Ahmed Amine Ben Abdallah and Ata Kabboudi and Amir Kanoun and Salah Zaiem},
-       year={2023},
-       eprint={2309.11327},
-       archivePrefix={arXiv},
-       primaryClass={eess.AS}
- }
- ```
- """
- title = "Tunisian Speech Recognition"
-
-
- def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
-     """Gradio callback: pick the recorded or uploaded file and transcribe it."""
-     if (file_mic is not None) and (file_upload is not None):
-         warn_output = "WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
-         wav = file_mic
-     elif (file_mic is None) and (file_upload is None):
-         return "ERROR: You have to either use the microphone or upload an audio file"
-     elif file_mic is not None:
-         wav = file_mic
-     else:
-         wav = file_upload
-     info = torchaudio.info(wav)
-     sr = info.sample_rate
-     sig = sb.dataio.dataio.read_audio(wav)
-     if len(sig.shape) > 1:
-         sig = torch.mean(sig, dim=1)
-     sig = torch.unsqueeze(sig, 0)
-     tensor_wav = sig.to(device)
-     resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)
-     sentence = asr.treat_wav(resampled)
-     return sentence
-
-
- gr.Interface(
-     title=title,
-     description=description,
-     fn=treat_wav_file,
-     inputs=[
-         gr.Audio(source="microphone", type="filepath", label="record", optional=True),
-         gr.Audio(source="upload", type="filepath", label="filein", optional=True),
-     ],
-     outputs="text",
- ).launch()
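Note on the decoding step above: the app builds its decoder once at module level with pyctcdecode's `build_ctcdecoder`, then `treat_wav` calls `decoder.decode` on each utterance's per-frame CTC log-probabilities. Below is a minimal, self-contained sketch of that call pattern; the toy alphabet and probabilities are illustrative, and no KenLM model is passed, so the beam search runs without LM rescoring.

```python
import numpy as np
from pyctcdecode import build_ctcdecoder

# Index 0 is the CTC blank, written as "", exactly as in the app above.
labels = ["", "a", "b", " "]
decoder = build_ctcdecoder(labels)  # kenlm_model_path omitted -> no LM

# Toy per-frame log-probabilities with shape (time, len(labels)).
log_probs = np.log(np.array([
    [0.1, 0.7, 0.1, 0.1],  # "a"
    [0.1, 0.7, 0.1, 0.1],  # "a" (repeated frames collapse)
    [0.7, 0.1, 0.1, 0.1],  # blank separates/ends the run
    [0.1, 0.1, 0.7, 0.1],  # "b"
], dtype=np.float32))

print(decoder.decode(log_probs))  # -> "ab"
```

In the app, `labels` is derived from the label encoder's `ind2lab` table with `<blank>` swapped for the empty string at index 0, which is the convention pyctcdecode expects.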
 
1234/env.log DELETED
@@ -1,520 +0,0 @@
- SpeechBrain system description
- ==============================
- Python version:
- 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
- ==============================
- Installed Python packages:
- absl-py==1.4.0
- aiohttp==3.9.5
- aiosignal==1.3.1
- alabaster==0.7.16
- albumentations==1.3.1
- altair==4.2.2
- annotated-types==0.7.0
- anyio==3.7.1
- argon2-cffi==23.1.0
- argon2-cffi-bindings==21.2.0
- array_record==0.5.1
- arviz==0.15.1
- astropy==5.3.4
- astunparse==1.6.3
- async-timeout==4.0.3
- atpublic==4.1.0
- attrs==23.2.0
- audioread==3.0.1
- autograd==1.6.2
- Babel==2.15.0
- backcall==0.2.0
- beautifulsoup4==4.12.3
- bidict==0.23.1
- bigframes==1.6.0
- bleach==6.1.0
- blinker==1.4
- blis==0.7.11
- blosc2==2.0.0
- bokeh==3.3.4
- bqplot==0.12.43
- branca==0.7.2
- build==1.2.1
- CacheControl==0.14.0
- cachetools==5.3.3
- catalogue==2.0.10
- certifi==2024.2.2
- cffi==1.16.0
- chardet==5.2.0
- charset-normalizer==3.3.2
- chex==0.1.86
- click==8.1.7
- click-plugins==1.1.1
- cligj==0.7.2
- cloudpathlib==0.16.0
- cloudpickle==2.2.1
- cmake==3.27.9
- cmdstanpy==1.2.2
- colorcet==3.1.0
- colorlover==0.3.0
- colour==0.1.5
- community==1.0.0b1
- confection==0.1.4
- cons==0.4.6
- contextlib2==21.6.0
- contourpy==1.2.1
- cryptography==42.0.7
- cuda-python==12.2.1
- cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.4.1-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=57366e7ef09dc63e0b389aff20df6c37d91e2790065861ee31a4720149f5b694
- cufflinks==0.17.3
- cupy-cuda12x==12.2.0
- cvxopt==1.3.2
- cvxpy==1.3.4
- cycler==0.12.1
- cymem==2.0.8
- Cython==3.0.10
- dask==2023.8.1
- datascience==0.17.6
- datasets==2.19.1
- db-dtypes==1.2.0
- dbus-python==1.2.18
- debugpy==1.6.6
- decorator==4.4.2
- defusedxml==0.7.1
- dill==0.3.8
- distributed==2023.8.1
- distro==1.7.0
- dlib==19.24.4
- dm-tree==0.1.8
- docstring_parser==0.16
- docutils==0.18.1
- dopamine_rl==4.0.9
- duckdb==0.10.2
- earthengine-api==0.1.403
- easydict==1.13
- ecos==2.0.13
- editdistance==0.6.2
- eerepr==0.0.4
- en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
- entrypoints==0.4
- et-xmlfile==1.1.0
- etils==1.7.0
- etuples==0.3.9
- exceptiongroup==1.2.1
- fastai==2.7.15
- fastcore==1.5.38
- fastdownload==0.0.7
- fastjsonschema==2.19.1
- fastprogress==1.0.3
- fastrlock==0.8.2
- filelock==3.14.0
- fiona==1.9.6
- firebase-admin==5.3.0
- Flask==2.2.5
- flatbuffers==24.3.25
- flax==0.8.3
- folium==0.14.0
- fonttools==4.51.0
- frozendict==2.4.4
- frozenlist==1.4.1
- fsspec==2023.6.0
- future==0.18.3
- gast==0.5.4
- gcsfs==2023.6.0
- GDAL==3.6.4
- gdown==5.1.0
- geemap==0.32.1
- gensim==4.3.2
- geocoder==1.38.1
- geographiclib==2.0
- geopandas==0.13.2
- geopy==2.3.0
- gin-config==0.5.0
- glob2==0.7
- google==2.0.3
- google-ai-generativelanguage==0.6.4
- google-api-core==2.11.1
- google-api-python-client==2.84.0
- google-auth==2.27.0
- google-auth-httplib2==0.1.1
- google-auth-oauthlib==1.2.0
- google-cloud-aiplatform==1.52.0
- google-cloud-bigquery==3.21.0
- google-cloud-bigquery-connection==1.12.1
- google-cloud-bigquery-storage==2.25.0
- google-cloud-core==2.3.3
- google-cloud-datastore==2.15.2
- google-cloud-firestore==2.11.1
- google-cloud-functions==1.13.3
- google-cloud-iam==2.15.0
- google-cloud-language==2.13.3
- google-cloud-resource-manager==1.12.3
- google-cloud-storage==2.8.0
- google-cloud-translate==3.11.3
- google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz#sha256=be2293f8ecde760dca63eb67279ba06283ad80f0e13e1c08343f5ef1b33c7116
- google-crc32c==1.5.0
- google-generativeai==0.5.4
- google-pasta==0.2.0
- google-resumable-media==2.7.0
- googleapis-common-protos==1.63.0
- googledrivedownloader==0.4
- graphviz==0.20.3
- greenlet==3.0.3
- grpc-google-iam-v1==0.13.0
- grpcio==1.64.0
- grpcio-status==1.48.2
- gspread==6.0.2
- gspread-dataframe==3.3.1
- gym==0.25.2
- gym-notices==0.0.8
- h5netcdf==1.3.0
- h5py==3.9.0
- holidays==0.49
- holoviews==1.17.1
- html5lib==1.1
- httpimport==1.3.1
- httplib2==0.22.0
- huggingface-hub==0.23.1
- humanize==4.7.0
- hyperopt==0.2.7
- HyperPyYAML==1.2.2
- hypothesis==6.102.6
- ibis-framework==8.0.0
- idna==3.7
- imageio==2.31.6
- imageio-ffmpeg==0.4.9
- imagesize==1.4.1
- imbalanced-learn==0.10.1
- imgaug==0.4.0
- importlib_metadata==7.1.0
- importlib_resources==6.4.0
- imutils==0.5.4
- inflect==7.0.0
- iniconfig==2.0.0
- intel-openmp==2023.2.4
- ipyevents==2.0.2
- ipyfilechooser==0.6.0
- ipykernel==5.5.6
- ipyleaflet==0.18.2
- ipython==7.34.0
- ipython-genutils==0.2.0
- ipython-sql==0.5.0
- ipytree==0.2.2
- ipywidgets==7.7.1
- itsdangerous==2.2.0
- jax==0.4.26
- jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.26+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl#sha256=813cf1fe3e7ca4dbf5327d6e7b4fc8521e92d8bba073ee645ae0d5d036a25750
- jeepney==0.7.1
- jellyfish==1.0.3
- jieba==0.42.1
- Jinja2==3.1.4
- jiwer==3.0.4
- joblib==1.4.2
- jsonpickle==3.0.4
- jsonschema==4.19.2
- jsonschema-specifications==2023.12.1
- jupyter-client==6.1.12
- jupyter-console==6.1.0
- jupyter-server==1.24.0
- jupyter_core==5.7.2
- jupyterlab_pygments==0.3.0
- jupyterlab_widgets==3.0.10
- kaggle==1.6.14
- kagglehub==0.2.5
- kenlm==0.2.0
- keras==2.15.0
- keyring==23.5.0
- kiwisolver==1.4.5
- langcodes==3.4.0
- language_data==1.2.0
- launchpadlib==1.10.16
- lazr.restfulclient==0.14.4
- lazr.uri==1.0.6
- lazy_loader==0.4
- libclang==18.1.1
- librosa==0.10.2.post1
- lightgbm==4.1.0
- linkify-it-py==2.0.3
- llvmlite==0.41.1
- locket==1.0.0
- logical-unification==0.4.6
- lxml==4.9.4
- malloy==2023.1067
- marisa-trie==1.1.1
- Markdown==3.6
- markdown-it-py==3.0.0
- MarkupSafe==2.1.5
- matplotlib==3.7.1
- matplotlib-inline==0.1.7
- matplotlib-venn==0.11.10
- mdit-py-plugins==0.4.1
- mdurl==0.1.2
- miniKanren==1.0.3
- missingno==0.5.2
- mistune==0.8.4
- mizani==0.9.3
- mkl==2023.2.0
- ml-dtypes==0.2.0
- mlxtend==0.22.0
- more-itertools==10.1.0
- moviepy==1.0.3
- mpmath==1.3.0
- msgpack==1.0.8
- multidict==6.0.5
- multipledispatch==1.0.0
- multiprocess==0.70.16
- multitasking==0.0.11
- murmurhash==1.0.10
- music21==9.1.0
- natsort==8.4.0
- nbclassic==1.0.0
- nbclient==0.10.0
- nbconvert==6.5.4
- nbformat==5.10.4
- nest-asyncio==1.6.0
- networkx==3.3
- nibabel==4.0.2
- nltk==3.8.1
- notebook==6.5.5
- notebook_shim==0.2.4
- numba==0.58.1
- numexpr==2.10.0
- numpy==1.25.2
- nvidia-cublas-cu12==12.1.3.1
- nvidia-cuda-cupti-cu12==12.1.105
- nvidia-cuda-nvrtc-cu12==12.1.105
- nvidia-cuda-runtime-cu12==12.1.105
- nvidia-cudnn-cu12==8.9.2.26
- nvidia-cufft-cu12==11.0.2.54
- nvidia-curand-cu12==10.3.2.106
- nvidia-cusolver-cu12==11.4.5.107
- nvidia-cusparse-cu12==12.1.0.106
- nvidia-nccl-cu12==2.20.5
- nvidia-nvjitlink-cu12==12.5.40
- nvidia-nvtx-cu12==12.1.105
- nvtx==0.2.10
- oauth2client==4.1.3
- oauthlib==3.2.2
- opencv-contrib-python==4.8.0.76
- opencv-python==4.8.0.76
- opencv-python-headless==4.9.0.80
- openpyxl==3.1.2
- opt-einsum==3.3.0
- optax==0.2.2
- orbax-checkpoint==0.4.4
- osqp==0.6.2.post8
- packaging==24.0
- pandas==2.0.3
- pandas-datareader==0.10.0
- pandas-gbq==0.19.2
- pandas-stubs==2.0.3.230814
- pandocfilters==1.5.1
- panel==1.3.8
- param==2.1.0
- parso==0.8.4
- parsy==2.1
- partd==1.4.2
- pathlib==1.0.1
- patsy==0.5.6
- peewee==3.17.5
- pexpect==4.9.0
- pickleshare==0.7.5
- Pillow==9.4.0
- pip-tools==6.13.0
- platformdirs==4.2.2
- plotly==5.15.0
- plotnine==0.12.4
- pluggy==1.5.0
- polars==0.20.2
- pooch==1.8.1
- portpicker==1.5.2
- prefetch-generator==1.0.3
- preshed==3.0.9
- prettytable==3.10.0
- proglog==0.1.10
- progressbar2==4.2.0
- prometheus_client==0.20.0
- promise==2.3
- prompt-toolkit==3.0.43
- prophet==1.1.5
- proto-plus==1.23.0
- protobuf==3.20.3
- psutil==5.9.5
- psycopg2==2.9.9
- ptyprocess==0.7.0
- py-cpuinfo==9.0.0
- py4j==0.10.9.7
- pyarrow==14.0.2
- pyarrow-hotfix==0.6
- pyasn1==0.6.0
- pyasn1_modules==0.4.0
- pycocotools==2.0.7
- pycparser==2.22
- pyctcdecode==0.5.0
- pydantic==2.7.1
- pydantic_core==2.18.2
- pydata-google-auth==1.8.2
- pydot==1.4.2
- pydot-ng==2.0.0
- pydotplus==2.0.2
- PyDrive==1.3.1
- PyDrive2==1.6.3
- pyerfa==2.0.1.4
- pygame==2.5.2
- Pygments==2.16.1
- PyGObject==3.42.1
- pygtrie==2.5.0
- PyJWT==2.3.0
- pymc==5.10.4
- pymystem3==0.2.0
- pynvjitlink-cu12==0.2.3
- PyOpenGL==3.1.7
- pyOpenSSL==24.1.0
- pyparsing==3.1.2
- pyperclip==1.8.2
- pyproj==3.6.1
- pyproject_hooks==1.1.0
- pyshp==2.3.1
- PySocks==1.7.1
- pytensor==2.18.6
- pytest==7.4.4
- python-apt @ file:///backend-container/containers/python_apt-0.0.0-cp310-cp310-linux_x86_64.whl#sha256=b209c7165d6061963abe611492f8c91c3bcef4b7a6600f966bab58900c63fefa
- python-box==7.1.1
- python-dateutil==2.8.2
- python-louvain==0.16
- python-slugify==8.0.4
- python-utils==3.8.2
- pytz==2023.4
- pyviz_comms==3.0.2
- PyWavelets==1.6.0
- PyYAML==6.0.1
- pyzmq==24.0.1
- qdldl==0.1.7.post2
- qudida==0.0.4
- rapidfuzz==3.9.1
- ratelim==0.1.6
- referencing==0.35.1
- regex==2023.12.25
- requests==2.31.0
- requests-oauthlib==1.3.1
- requirements-parser==0.9.0
- rich==13.7.1
- rmm-cu12==24.4.0
- rpds-py==0.18.1
- rpy2==3.4.2
- rsa==4.9
- ruamel.yaml==0.18.6
- ruamel.yaml.clib==0.2.8
- safetensors==0.4.3
- scikit-image==0.19.3
- scikit-learn==1.2.2
- scipy==1.11.4
- scooby==0.10.0
- scs==3.2.4.post1
- seaborn==0.13.1
- SecretStorage==3.3.1
- Send2Trash==1.8.3
- sentencepiece==0.1.99
- shapely==2.0.4
- six==1.16.0
- sklearn-pandas==2.2.0
- smart-open==6.4.0
- sniffio==1.3.1
- snowballstemmer==2.2.0
- sortedcontainers==2.4.0
- soundfile==0.12.1
- soupsieve==2.5
- soxr==0.3.7
- spacy==3.7.4
- spacy-legacy==3.0.12
- spacy-loggers==1.0.5
- speechbrain==0.5.16
- Sphinx==5.0.2
- sphinxcontrib-applehelp==1.0.8
- sphinxcontrib-devhelp==1.0.6
- sphinxcontrib-htmlhelp==2.0.5
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.7
- sphinxcontrib-serializinghtml==1.1.10
- SQLAlchemy==2.0.30
- sqlglot==20.11.0
- sqlparse==0.5.0
- srsly==2.4.8
- stanio==0.5.0
- statsmodels==0.14.2
- StrEnum==0.4.15
- sympy==1.12
- tables==3.8.0
- tabulate==0.9.0
- tbb==2021.12.0
- tblib==3.0.0
- tenacity==8.3.0
- tensorboard==2.15.2
- tensorboard-data-server==0.7.2
- tensorflow @ https://storage.googleapis.com/colab-tf-builds-public-09h6ksrfwbb9g9xv/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=a2ec79931350b378c1ef300ca836b52a55751acb71a433582508a07f0de57c42
- tensorflow-datasets==4.9.4
- tensorflow-estimator==2.15.0
- tensorflow-gcs-config==2.15.0
- tensorflow-hub==0.16.1
- tensorflow-io-gcs-filesystem==0.37.0
- tensorflow-metadata==1.15.0
- tensorflow-probability==0.23.0
- tensorstore==0.1.45
- termcolor==2.4.0
- terminado==0.18.1
- text-unidecode==1.3
- textblob==0.17.1
- tf-slim==1.1.0
- tf_keras==2.15.1
- thinc==8.2.3
- threadpoolctl==3.5.0
- tifffile==2024.5.10
- tinycss2==1.3.0
- tokenizers==0.19.1
- toml==0.10.2
- tomli==2.0.1
- toolz==0.12.1
- torch @ https://download.pytorch.org/whl/cu121/torch-2.3.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=0a12aa9aa6bc442dff8823ac8b48d991fd0771562eaa38593f9c8196d65f7007
- torchaudio @ https://download.pytorch.org/whl/cu121/torchaudio-2.3.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=38b49393f8c322dcaa29d19e5acbf5a0b1978cf1b719445ab670f1fb486e3aa6
- torchsummary==1.5.1
- torchtext==0.18.0
- torchvision @ https://download.pytorch.org/whl/cu121/torchvision-0.18.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=13e1b48dc5ce41ccb8100ab3dd26fdf31d8f1e904ecf2865ac524493013d0df5
- tornado==6.3.3
- tqdm==4.66.4
- traitlets==5.7.1
- traittypes==0.2.1
- transformers==4.41.0
- triton==2.3.0
- tweepy==4.14.0
- typer==0.9.4
- types-pytz==2024.1.0.20240417
- types-setuptools==70.0.0.20240523
- typing_extensions==4.11.0
- tzdata==2024.1
- tzlocal==5.2
- uc-micro-py==1.0.3
- uritemplate==4.1.1
- urllib3==2.0.7
- vega-datasets==0.9.0
- wadllib==1.3.6
- wasabi==1.1.2
- wcwidth==0.2.13
- weasel==0.3.4
- webcolors==1.13
- webencodings==0.5.1
- websocket-client==1.8.0
- Werkzeug==3.0.3
- widgetsnbextension==3.6.6
- wordcloud==1.9.3
- wrapt==1.14.1
- xarray==2023.7.0
- xarray-einstats==0.7.0
- xgboost==2.0.3
- xlrd==2.0.1
- xxhash==3.4.1
- xyzservices==2024.4.0
- yarl==1.9.4
- yellowbrick==1.5
- yfinance==0.2.40
- zict==3.0.0
- zipp==3.18.2
- ==============================
- Could not get git revision==============================
- CUDA version:
- 12.1
 
1234/hyperparams.yaml DELETED
@@ -1,179 +0,0 @@
- # Generated 2024-05-27 from:
- # /content/drive/MyDrive/TASR/finetuning.yaml
- # yamllint disable
- # ################################
- # Model: wav2vec2 + DNN + CTC
- # Augmentation: SpecAugment
- # Authors: Titouan Parcollet 2021
- # ################################
-
- # Seed needs to be set at top of yaml, before objects with parameters are made
- seed: 1234
- __set_seed: !!python/object/apply:torch.manual_seed [1234]
- output_folder: semi_wavlm_large_tunisian_ctc/1234
- wer_file: semi_wavlm_large_tunisian_ctc/1234/wer.txt
- save_folder: semi_wavlm_large_tunisian_ctc/1234/save
- train_log: semi_wavlm_large_tunisian_ctc/1234/train_log.txt
-
- # Folder where the wav2vec2 (WavLM) checkpoint is stored
- wav2vec2_folder: semi_wavlm_large_tunisian_ctc/1234/save/wav2vec2_checkpoint
-
- # Data files
- data_folder: /path/to/data # e.g., /localscratch/cv-corpus-5.1-2020-06-22/fr
- train_tsv_file: /path/to/data/train.tsv # Standard CommonVoice .tsv files
- dev_tsv_file: /path/to/data/dev.tsv # Standard CommonVoice .tsv files
- test_tsv_file: /path/to/data/test.tsv # Standard CommonVoice .tsv files
- accented_letters: true
- language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for English
- train_csv: Data/train_wavs/train.csv
- valid_csv: Data/dev_wavs/dev.csv
- test_csv:
- - Data/test_wavs/test.csv
-
- skip_prep: true # Skip data preparation
-
- use_language_modelling: true
- ngram_lm_path: outdomain.arpa
-
- # We remove utterances longer than 10s in the train/dev/test sets as
- # longer sentences certainly correspond to "open microphones".
- avoid_if_longer_than: 10.0
- avoid_if_shorter_than: 1.2
-
- # Training parameters
- number_of_epochs: 20
- lr: 1.0
- lr_wav2vec: 0.0001
- sorting: ascending
- auto_mix_prec: false
- sample_rate: 16000
- ckpt_interval_minutes: 30 # save checkpoint every N min
-
- # With data_parallel batch_size is split into N jobs
- # With DDP batch_size is multiplied by N jobs
- # Must be 6 per GPU to fit 16GB of VRAM
- batch_size: 6
- test_batch_size: 4
-
- dataloader_options:
-     batch_size: 6
-     num_workers: 6
- test_dataloader_options:
-     batch_size: 4
-     num_workers: 6
-
- # BPE parameters
- token_type: char # ["unigram", "bpe", "char"]
- character_coverage: 1.0
-
- # Model parameters
- # activation: !name:torch.nn.LeakyReLU
- wav2vec_output_dim: 1024
- dnn_neurons: 1024
- freeze_wav2vec: false
- freeze_feature_extractor: true
- dropout: 0.15
- warmup_steps: 500 # The wav2vec 2 model isn't updated for this amount of steps
-
- # Outputs
- output_neurons: 40 # BPE size, index(blank/eos/bos) = 0
-
- # Decoding parameters
- # Be sure that the bos and eos index match with the BPEs ones
- blank_index: 0
- unk_index: 1
-
- #
- # Functions and classes
- #
- epoch_counter: &id007 !new:speechbrain.utils.epoch_loop.EpochCounter
-     limit: 20
-
- augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-     sample_rate: 16000
-     speeds: [95, 100, 105]
-
- enc: &id002 !new:speechbrain.nnet.containers.Sequential
-     input_shape: [null, null, 1024]
-     linear1: !name:speechbrain.nnet.linear.Linear
-         n_neurons: 1024
-         bias: true
-     bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
-     activation: !new:torch.nn.LeakyReLU
-     drop: !new:torch.nn.Dropout
-         p: 0.15
-     linear2: !name:speechbrain.nnet.linear.Linear
-         n_neurons: 1024
-         bias: true
-     bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
-     activation2: !new:torch.nn.LeakyReLU
-     drop2: !new:torch.nn.Dropout
-         p: 0.15
-     linear3: !name:speechbrain.nnet.linear.Linear
-         n_neurons: 1024
-         bias: true
-     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
-     activation3: !new:torch.nn.LeakyReLU
-
- wav2vec2: &id001 !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
-     source: wavlm-large/
-     output_norm: false
-     freeze: false
-     freeze_feature_extractor: true
-     save_path: semi_wavlm_large_tunisian_ctc/1234/save/wav2vec2_checkpoint
-
- ctc_lin: &id003 !new:speechbrain.nnet.linear.Linear
-     input_size: 1024
-     n_neurons: 40
-
- log_softmax: !new:speechbrain.nnet.activations.Softmax
-     apply_log: true
-
- ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
-     blank_index: 0
-
- modules:
-     wav2vec2: *id001
-     enc: *id002
-     ctc_lin: *id003
- model: &id004 !new:torch.nn.ModuleList
- - [*id002, *id003]
- model_opt_class: !name:torch.optim.Adadelta
-     lr: 1.0
-     rho: 0.95
-     eps: 1.e-8
-
- wav2vec_opt_class: !name:torch.optim.Adam
-     lr: 0.0001
-
- lr_annealing_model: &id005 !new:speechbrain.nnet.schedulers.NewBobScheduler
-     initial_value: 1.0
-     improvement_threshold: 0.0025
-     annealing_factor: 0.8
-     patient: 0
-
- lr_annealing_wav2vec: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
-     initial_value: 0.0001
-     improvement_threshold: 0.0025
-     annealing_factor: 0.9
-     patient: 0
-
- checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
-     checkpoints_dir: semi_wavlm_large_tunisian_ctc/1234/save
-     recoverables:
-         wav2vec2: *id001
-         model: *id004
-         scheduler_model: *id005
-         scheduler_wav2vec: *id006
-         counter: *id007
- train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
-     save_file: semi_wavlm_large_tunisian_ctc/1234/train_log.txt
-
- error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
-
- cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
-     split_tokens: true
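Note: the `!new:` tags above are HyperPyYAML constructors, so loading this file instantiates the modules, schedulers, and checkpointer directly, while the `&id…`/`*id…` anchors make `modules` and `checkpointer` share the same objects. A minimal sketch of how a script consumes such a file (the override values are hypothetical, and instantiating `wav2vec2` assumes the `wavlm-large/` checkpoint is available locally):

```python
from hyperpyyaml import load_hyperpyyaml

# Hypothetical overrides; any top-level key can be replaced at load time.
overrides = {"number_of_epochs": 5, "batch_size": 4}

with open("hyperparams.yaml") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# The YAML anchors mean these are the same object, not copies:
assert hparams["modules"]["wav2vec2"] is hparams["wav2vec2"]
print(hparams["number_of_epochs"])  # -> 5
```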
 
1234/log.txt DELETED
The diff for this file is too large to render.
 
1234/save/CKPT+2024-05-27+00-52-30+00/CKPT.yaml DELETED
@@ -1,4 +0,0 @@
- # yamllint disable
- WER: 22.820037105751393
- end-of-epoch: true
- unixtime: 1716771150.6156993
 
1234/save/CKPT+2024-05-27+00-52-30+00/brain.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:17a40ee07b922d0dc2b48662b8e4fd2f6032c946ffdd21cee17047c2a4a6feca
- size 51
 
1234/save/CKPT+2024-05-27+00-52-30+00/counter.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:f5ca38f748a1d6eaf726b8a42fb575c3c71f1864a8143301782de13da2d9202b
- size 2
 
1234/save/CKPT+2024-05-27+00-52-30+00/dataloader-TRAIN.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:459535faa370a3b5f8b87203b089623c7aeb9325abf241ec8a685b9c325047a3
- size 3
 
1234/save/CKPT+2024-05-27+00-52-30+00/model.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:766f5e1023404b4364d9f5d4a270b0699af3a15cb47684f33dd4b8098bd000b6
- size 12814875
 
1234/save/CKPT+2024-05-27+00-52-30+00/modelopt.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:59e5f239bff4d0686120f4e4a9c5d54524a622089908b7e3abc8ebbb1b68f113
- size 25576098
 
1234/save/CKPT+2024-05-27+00-52-30+00/scheduler_model.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:94041bc093d9222795932e865c54f3c3a2d48cf44677968ce689d2867af20a92
- size 1152
 
1234/save/CKPT+2024-05-27+00-52-30+00/scheduler_wav2vec.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c8c35e64e340922dc41304565bc0c1429bdf06ca931ed9b0bbecd4d7a3d1547c
- size 1160
 
1234/save/CKPT+2024-05-27+00-52-30+00/wav2vec2.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:ebad96920b7643616c39cf455ad89689b844603bf08c821a55004e4553bf37fa
- size 1262006670
 
1234/save/CKPT+2024-05-27+00-52-30+00/wav2vec_opt.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:af8c6950a6c04fd90519da5f40d9e0d1a3ace6aaa1212be88009c04a6fa46a39
- size 2490362300
 
1234/save/label_encoder.txt DELETED
@@ -1,44 +0,0 @@
- 'ب' => 38
- 'ا' => 1
- 'ه' => 2
- 'ي' => 3
- 'و' => 4
- 'ن' => 5
- 'أ' => 6
- ' ' => 7
- 'م' => 8
- 'ش' => 9
- 'ل' => 10
- 'س' => 11
- 'ت' => 12
- 'د' => 13
- 'ر' => 14
- 'ى' => 15
- 'ح' => 16
- 'ط' => 17
- 'ع' => 18
- 'ك' => 19
- 'ف' => 20
- 'ق' => 21
- 'آ' => 22
- 'ة' => 23
- 'ج' => 24
- 'ض' => 25
- 'ز' => 26
- 'ص' => 27
- 'إ' => 28
- 'ث' => 29
- 'خ' => 30
- 'ڨ' => 31
- 'ذ' => 32
- 'ظ' => 33
- 'ء' => 34
- 'غ' => 35
- 'ئ' => 36
- 'ؤ' => 37
- '<blank>' => 0
- 1 => 39
- ================
- 'starting_index' => 0
- 'unk_label' => 1
- 'blank_label' => '<blank>'
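Note: this file is the character table that `CTCTextEncoder.load_or_create` wrote during training: 38 Arabic characters plus the space, with `<blank>` pinned to index 0 and the unknown label to index 1. A small sketch of reloading and using it, assuming SpeechBrain 0.5.x and the file in the working directory:

```python
import speechbrain as sb

encoder = sb.dataio.encoder.CTCTextEncoder()
encoder.load("label_encoder.txt")  # restores the char -> index table above

chars = list("شد")                     # any string covered by the table
ids = encoder.encode_sequence(chars)   # -> [9, 13] per the mapping above
print(ids)
print("".join(encoder.decode_ndim(ids)))  # round-trips to the input
```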
 
1234/train_log.txt DELETED
@@ -1,14 +0,0 @@
- Epoch loaded: 12 - test loss: 4.76e-01, test CER: 9.62, test WER: 26.35
- epoch: 13, lr_model: 3.28e-01, lr_wav2vec: 5.90e-05 - train loss: 3.66e-01 - valid loss: 3.36e-01, valid CER: 7.90, valid WER: 20.96
- epoch: 14, lr_model: 3.28e-01, lr_wav2vec: 5.90e-05 - train loss: 3.07e-01 - valid loss: 3.48e-01, valid CER: 8.34, valid WER: 23.38
- epoch: 15, lr_model: 2.62e-01, lr_wav2vec: 5.31e-05 - train loss: 2.66e-01 - valid loss: 3.58e-01, valid CER: 7.79, valid WER: 21.89
- epoch: 16, lr_model: 2.10e-01, lr_wav2vec: 4.78e-05 - train loss: 2.41e-01 - valid loss: 3.61e-01, valid CER: 8.09, valid WER: 22.82
- epoch: 17, lr_model: 1.68e-01, lr_wav2vec: 4.30e-05 - train loss: 2.21e-01 - valid loss: 3.75e-01, valid CER: 7.83, valid WER: 22.45
- epoch: 18, lr_model: 1.34e-01, lr_wav2vec: 3.87e-05 - train loss: 2.08e-01 - valid loss: 3.98e-01, valid CER: 8.12, valid WER: 22.82
- epoch: 19, lr_model: 1.07e-01, lr_wav2vec: 3.49e-05 - train loss: 1.94e-01 - valid loss: 4.06e-01, valid CER: 8.05, valid WER: 22.63
- epoch: 20, lr_model: 8.59e-02, lr_wav2vec: 3.14e-05 - train loss: 1.87e-01 - valid loss: 4.21e-01, valid CER: 7.72, valid WER: 22.82
- Epoch loaded: 20 - test loss: 4.38e-01, test CER: 9.18, test WER: 24.78
- Epoch loaded: 20 - test loss: 4.38e-01, test CER: 9.18, test WER: 24.78
- Epoch loaded: 20 - test loss: 4.38e-01, test CER: 9.18, test WER: 24.78
- Epoch loaded: 20 - test loss: 4.38e-01, test CER: 9.18, test WER: 24.78
- Epoch loaded: 20 - test loss: 4.38e-01, test CER: 9.18, test WER: 24.78
 
1234/train_with_wavlm.py DELETED
@@ -1,399 +0,0 @@
- #!/usr/bin/env python3
- import sys
- import torch
- import logging
- import speechbrain as sb
- from pathlib import Path
- import os
- import torchaudio
- from hyperpyyaml import load_hyperpyyaml
- from speechbrain.tokenizers.SentencePiece import SentencePiece
- from speechbrain.utils.data_utils import undo_padding
- from speechbrain.utils.distributed import run_on_main
-
- """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
- The system employs a wav2vec2 encoder and a CTC decoder.
- Decoding is performed with greedy decoding (will be extended to beam search).
-
- To run this recipe, do the following:
- > python train_with_wav2vec2.py hparams/train_with_wav2vec2.yaml
-
- With the default hyperparameters, the system employs a pretrained wav2vec2 encoder.
- The wav2vec2 model is pretrained following the model given in the hparams file.
- It may be dependent on the language.
-
- The neural network is trained with CTC on sub-word units estimated with
- Byte-Pair Encoding (BPE).
-
- The experiment file is flexible enough to support a large variety of
- different systems. By properly changing the parameter files, you can try
- different encoders, decoders, tokens (e.g., characters instead of BPE),
- training languages (all CommonVoice languages), and many
- other possible variations.
-
- Authors
-  * Titouan Parcollet 2021
- """
-
- logger = logging.getLogger(__name__)
-
-
- # Define training procedure
- class ASR(sb.core.Brain):
-     def compute_forward(self, batch, stage):
-         """Forward computations from the waveform batches to the output probabilities."""
-         batch = batch.to(self.device)
-         wavs, wav_lens = batch.sig
-         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
-         if stage == sb.Stage.TRAIN:
-             if hasattr(self.hparams, "augmentation"):
-                 wavs = self.hparams.augmentation(wavs, wav_lens)
-
-         # Forward pass
-         feats = self.modules.wav2vec2(wavs, wav_lens)
-         x = self.modules.enc(feats)
-         logits = self.modules.ctc_lin(x)
-         p_ctc = self.hparams.log_softmax(logits)
-
-         return p_ctc, wav_lens
-
-     def compute_objectives(self, predictions, batch, stage):
-         """Computes the loss (CTC) given predictions and targets."""
-         p_ctc, wav_lens = predictions
-
-         ids = batch.id
-         tokens, tokens_lens = batch.tokens
-
-         loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
-
-         if stage != sb.Stage.TRAIN:
-             predicted_tokens = sb.decoders.ctc_greedy_decode(
-                 p_ctc, wav_lens, blank_id=self.hparams.blank_index
-             )
-             # Decode token terms to words
-             if self.hparams.use_language_modelling:
-                 predicted_words = []
-                 for logs in p_ctc:
-                     text = decoder.decode(logs.detach().cpu().numpy())
-                     predicted_words.append(text.split(" "))
-             else:
-                 predicted_words = [
-                     "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
-                     for utt_seq in predicted_tokens
-                 ]
-             # Convert indices to words
-             target_words = [wrd.split(" ") for wrd in batch.wrd]
-
-             self.wer_metric.append(ids, predicted_words, target_words)
-             self.cer_metric.append(ids, predicted_words, target_words)
-
-         return loss
-
-     def fit_batch(self, batch):
-         """Train the parameters given a single batch in input"""
-         should_step = self.step % self.grad_accumulation_factor == 0
-         # Managing automatic mixed precision
-         # TOFIX: CTC fine-tuning currently is unstable
-         # This is certainly due to CTC being done in fp16 instead of fp32
-         if self.auto_mix_prec:
-             with torch.cuda.amp.autocast():
-                 with self.no_sync():
-                     outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-                 loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-             with self.no_sync(not should_step):
-                 self.scaler.scale(
-                     loss / self.grad_accumulation_factor
-                 ).backward()
-             if should_step:
-                 if not self.hparams.wav2vec2.freeze:
-                     self.scaler.unscale_(self.wav2vec_optimizer)
-                 self.scaler.unscale_(self.model_optimizer)
-                 if self.check_gradients(loss):
-                     if not self.hparams.wav2vec2.freeze:
-                         if self.optimizer_step >= self.hparams.warmup_steps:
-                             self.scaler.step(self.wav2vec_optimizer)
-                     self.scaler.step(self.model_optimizer)
-                     self.scaler.update()
-                 self.zero_grad()
-                 self.optimizer_step += 1
-         else:
-             # This is mandatory because HF models have a weird behavior with DDP
-             # on the forward pass
-             with self.no_sync():
-                 outputs = self.compute_forward(batch, sb.Stage.TRAIN)
-
-             loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-
-             with self.no_sync(not should_step):
-                 (loss / self.grad_accumulation_factor).backward()
-             if should_step:
-                 if self.check_gradients(loss):
-                     if not self.hparams.wav2vec2.freeze:
-                         if self.optimizer_step >= self.hparams.warmup_steps:
-                             self.wav2vec_optimizer.step()
-                     self.model_optimizer.step()
-                 self.zero_grad()
-                 self.optimizer_step += 1
-
-         self.on_fit_batch_end(batch, outputs, loss, should_step)
-         return loss.detach().cpu()
-
-     def evaluate_batch(self, batch, stage):
-         """Computations needed for validation/test batches"""
-         predictions = self.compute_forward(batch, stage=stage)
-         with torch.no_grad():
-             loss = self.compute_objectives(predictions, batch, stage=stage)
-         return loss.detach()
-
-     def on_stage_start(self, stage, epoch):
-         """Gets called at the beginning of each epoch"""
-         if stage != sb.Stage.TRAIN:
-             self.cer_metric = self.hparams.cer_computer()
-             self.wer_metric = self.hparams.error_rate_computer()
-
-     def on_stage_end(self, stage, stage_loss, epoch):
-         """Gets called at the end of an epoch."""
-         # Compute/store important stats
-         stage_stats = {"loss": stage_loss}
-         if stage == sb.Stage.TRAIN:
-             self.train_stats = stage_stats
-         else:
-             stage_stats["CER"] = self.cer_metric.summarize("error_rate")
-             stage_stats["WER"] = self.wer_metric.summarize("error_rate")
-
-         # Perform end-of-iteration things, like annealing, logging, etc.
-         if stage == sb.Stage.VALID:
-             old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
-                 stage_stats["loss"]
-             )
-             old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
-                 stage_stats["loss"]
-             )
-             sb.nnet.schedulers.update_learning_rate(
-                 self.model_optimizer, new_lr_model
-             )
-             if not self.hparams.wav2vec2.freeze:
-                 sb.nnet.schedulers.update_learning_rate(
-                     self.wav2vec_optimizer, new_lr_wav2vec
-                 )
-             self.hparams.train_logger.log_stats(
-                 stats_meta={
-                     "epoch": epoch,
-                     "lr_model": old_lr_model,
-                     "lr_wav2vec": old_lr_wav2vec,
-                 },
-                 train_stats=self.train_stats,
-                 valid_stats=stage_stats,
-             )
-             self.checkpointer.save_and_keep_only(
-                 meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
-             )
-         elif stage == sb.Stage.TEST:
-             self.hparams.train_logger.log_stats(
-                 stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
-                 test_stats=stage_stats,
-             )
-             with open(self.hparams.wer_file, "w") as w:
-                 self.wer_metric.write_stats(w)
-
-     def init_optimizers(self):
-         "Initializes the wav2vec2 optimizer and model optimizer"
-         # If the wav2vec encoder is unfrozen, we create the optimizer
-         if not self.hparams.wav2vec2.freeze:
-             self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
-                 self.modules.wav2vec2.parameters()
-             )
-             if self.checkpointer is not None:
-                 self.checkpointer.add_recoverable(
-                     "wav2vec_opt", self.wav2vec_optimizer
-                 )
-
-         self.model_optimizer = self.hparams.model_opt_class(
-             self.hparams.model.parameters()
-         )
-
-         if self.checkpointer is not None:
-             self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
-
-     def zero_grad(self, set_to_none=False):
-         if not self.hparams.wav2vec2.freeze:
-             self.wav2vec_optimizer.zero_grad(set_to_none)
-         self.model_optimizer.zero_grad(set_to_none)
-
-
- # Define custom data procedure
- def dataio_prepare(hparams):
-     """This function prepares the datasets to be used in the brain class.
-     It also defines the data processing pipeline through user-defined functions."""
-
-     # 1. Define datasets
-     data_folder = hparams["data_folder"]
-
-     train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
-         csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
-     )
-
-     if hparams["sorting"] == "ascending":
-         # we sort training data to speed up training and get better results.
-         train_data = train_data.filtered_sorted(
-             sort_key="duration",
-             key_max_value={"duration": hparams["avoid_if_longer_than"]},
-         )
-         # when sorting, do not shuffle in the dataloader! otherwise it is pointless
-         hparams["dataloader_options"]["shuffle"] = False
-
-     elif hparams["sorting"] == "descending":
-         train_data = train_data.filtered_sorted(
-             sort_key="duration",
-             reverse=True,
-             key_max_value={"duration": hparams["avoid_if_longer_than"]},
-         )
-         # when sorting, do not shuffle in the dataloader! otherwise it is pointless
-         hparams["dataloader_options"]["shuffle"] = False
-
-     elif hparams["sorting"] == "random":
-         pass
-
-     else:
-         raise NotImplementedError(
-             "sorting must be random, ascending or descending"
-         )
-
-     valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
-         csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
-     )
-     # We also sort the validation data so it is faster to validate
-     valid_data = valid_data.filtered_sorted(sort_key="duration")
-     test_datasets = {}
-     for csv_file in hparams["test_csv"]:
-         name = Path(csv_file).stem
-         test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
-             csv_path=csv_file, replacements={"data_root": data_folder}
-         )
-         test_datasets[name] = test_datasets[name].filtered_sorted(
-             sort_key="duration"
-         )
-
-     datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
-
-     # 2. Define audio pipeline:
-     @sb.utils.data_pipeline.takes("wav")
-     @sb.utils.data_pipeline.provides("sig")
-     def audio_pipeline(wav):
-         info = torchaudio.info(wav)
-         sig = sb.dataio.dataio.read_audio(wav)
-         resampled = torchaudio.transforms.Resample(
-             info.sample_rate, hparams["sample_rate"],
-         )(sig)
-         return resampled
-
-     sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
-     label_encoder = sb.dataio.encoder.CTCTextEncoder()
-
-     # 3. Define text pipeline:
-     @sb.utils.data_pipeline.takes("wrd")
-     @sb.utils.data_pipeline.provides(
-         "wrd", "char_list", "tokens_list", "tokens"
-     )
-     def text_pipeline(wrd):
-         yield wrd
-         char_list = list(wrd)
-         yield char_list
-         tokens_list = label_encoder.encode_sequence(char_list)
-         yield tokens_list
-         tokens = torch.LongTensor(tokens_list)
-         yield tokens
-
-     sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
-     lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
-     special_labels = {
-         "blank_label": hparams["blank_index"],
-         "unk_label": hparams["unk_index"],
-     }
-     label_encoder.load_or_create(
-         path=lab_enc_file,
-         from_didatasets=[train_data],
-         output_key="char_list",
-         special_labels=special_labels,
-         sequence_input=True,
-     )
-
-     # 4. Set output:
-     sb.dataio.dataset.set_output_keys(
-         datasets, ["id", "sig", "wrd", "char_list", "tokens"],
-     )
-     return train_data, valid_data, test_datasets, label_encoder
-
-
- if __name__ == "__main__":
-
-     # Load hyperparameters file with command-line overrides
-     hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
-     with open(hparams_file) as fin:
-         hparams = load_hyperpyyaml(fin, overrides)
-
-     # If --distributed_launch then
-     # create ddp_group with the right communication protocol
-     sb.utils.distributed.ddp_init_group(run_opts)
-
-     # Create experiment directory
-     sb.create_experiment_directory(
-         experiment_directory=hparams["output_folder"],
-         hyperparams_to_save=hparams_file,
-         overrides=overrides,
-     )
-
-     # Due to DDP, we do the preparation ONLY on the main python process
-     # Defining tokenizer and loading it
-     # Create the datasets objects as well as tokenization and encoding :-D
-     train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)
-     if hparams["use_language_modelling"]:
-         print("using language modelling")
-         from pyctcdecode import build_ctcdecoder
-
-         ind2lab = label_encoder.ind2lab
-         print(ind2lab)
-         labels = [ind2lab[x] for x in range(len(ind2lab))]
-         # Replace the <blank> token with a blank character, needed for PyCTCdecode
-         labels = [""] + labels[1:-1] + ["1"]
-         print(labels)
-         decoder = build_ctcdecoder(
-             labels,
-             kenlm_model_path=hparams["ngram_lm_path"],  # .arpa or .bin
-             alpha=0.5,  # Default by KenLM
-             beta=1.0,  # Default by KenLM
-         )
-
-     # Trainer initialization
-     asr_brain = ASR(
-         modules=hparams["modules"],
-         hparams=hparams,
-         run_opts=run_opts,
-         checkpointer=hparams["checkpointer"],
-     )
-
-     # Adding objects to trainer.
-     asr_brain.tokenizer = label_encoder
-
-     # Training
-     asr_brain.fit(
-         asr_brain.hparams.epoch_counter,
-         train_data,
-         valid_data,
-         train_loader_kwargs=hparams["dataloader_options"],
-         valid_loader_kwargs=hparams["test_dataloader_options"],
-     )
-
-     # Test
-     for k in test_datasets.keys():  # keys are test_clean, test_other etc
-         asr_brain.hparams.wer_file = os.path.join(
-             hparams["output_folder"], "wer_{}.txt".format(k)
-         )
-         asr_brain.evaluate(
-             test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
-         )
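Note on `fit_batch` above: it implements gradient accumulation, dividing the loss by `grad_accumulation_factor`, letting gradients build up across batches, and stepping the optimizers (with the warmup gate on the wav2vec2 side) only when `should_step` fires. A stripped-down sketch of the same pattern with a toy model, leaving out AMP, DDP `no_sync`, and the gradient checks:

```python
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
grad_accumulation_factor = 4  # one optimizer step per 4 batches

for step in range(1, 9):
    x, y = torch.randn(3, 4), torch.randn(3, 2)
    loss = torch.nn.functional.mse_loss(model(x), y)
    # Dividing the loss makes the accumulated gradient an average over
    # the virtual batch, matching fit_batch in the recipe above.
    (loss / grad_accumulation_factor).backward()
    if step % grad_accumulation_factor == 0:
        opt.step()
        opt.zero_grad()
```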
 
1234/wer_test.txt DELETED
The diff for this file is too large to render.