Porjaz committed on
Commit 515be22 · verified · 1 Parent(s): 39793b6

Create train.py

Files changed (1): train.py (+335, -0)
train.py ADDED
@@ -0,0 +1,335 @@
#!/usr/bin/env python3

import logging
import sys
from pathlib import Path
import os

import librosa

import torch
from torch.utils.data import DataLoader
from hyperpyyaml import load_hyperpyyaml

import speechbrain as sb
from speechbrain.utils.distributed import if_main_process, run_on_main

from jiwer import wer, cer

logger = logging.getLogger(__name__)


# Define training procedure
class ASR(sb.Brain):
    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities."""
        batch = batch.to(self.device)
        sig, self.sig_lens = batch.sig
        tokens_bos, _ = batch.tokens_bos
        sig, self.sig_lens = sig.to(self.device), self.sig_lens.to(self.device)

        # Add waveform augmentation if specified.
        if stage == sb.Stage.TRAIN:
            sig, self.sig_lens = self.hparams.wav_augment(sig, self.sig_lens)

        # Forward pass
        encoded_outputs = self.modules.encoder_w2v2(sig.detach())
        embedded_tokens = self.modules.embedding(tokens_bos)
        decoder_outputs, _ = self.modules.decoder(embedded_tokens, encoded_outputs, self.sig_lens)

        # Output layer for seq2seq log-probabilities
        logits = self.modules.seq_lin(decoder_outputs)
        predictions = {"seq_logprobs": self.hparams.log_softmax(logits)}

        if self.is_ctc_active(stage):
            # Output layer for CTC log-probabilities
            ctc_logits = self.modules.ctc_lin(encoded_outputs)
            predictions["ctc_logprobs"] = self.hparams.log_softmax(ctc_logits)
        elif stage == sb.Stage.VALID:
            predictions["tokens"], _, _, _ = self.hparams.greedy_search(encoded_outputs, self.sig_lens)
        elif stage == sb.Stage.TEST:
            predictions["tokens"], _, _, _ = self.hparams.test_search(encoded_outputs, self.sig_lens)

        return predictions

    def is_ctc_active(self, stage):
        """Check if CTC is currently active.

        Arguments
        ---------
        stage : sb.Stage
            Currently executing stage.
        """
        if stage != sb.Stage.TRAIN:
            return False
        current_epoch = self.hparams.epoch_counter.current
        return current_epoch <= self.hparams.number_of_ctc_epochs

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC+NLL) given predictions and targets."""
        ids = batch.id
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        tokens, tokens_lens = batch.tokens

        loss = self.hparams.nll_cost(
            log_probabilities=predictions["seq_logprobs"],
            targets=tokens_eos,
            length=tokens_eos_lens,
        )

        if self.is_ctc_active(stage):
            # Load tokens without EOS as CTC targets
            loss_ctc = self.hparams.ctc_cost(predictions["ctc_logprobs"], tokens, self.sig_lens, tokens_lens)
            loss *= 1 - self.hparams.ctc_weight
            loss += self.hparams.ctc_weight * loss_ctc

        if stage != sb.Stage.TRAIN:
            predicted_words = [
                self.hparams.tokenizer.decode_ids(prediction).split(" ")
                for prediction in predictions["tokens"]
            ]
            target_words = [words.split(" ") for words in batch.transcript]
            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch."""
        if stage != sb.Stage.TRAIN:
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch."""
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            old_lr, new_lr = self.hparams.lr_annealing(stage_stats["WER"])
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": old_lr},
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"WER": stage_stats["WER"]},
                min_keys=["WER"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            if if_main_process():
                with open(self.hparams.test_wer_file, "w") as w:
                    self.wer_metric.write_stats(w)

    def run_inference(
        self,
        dataset,  # Must be obtained from the dataio_prepare function
        min_key,  # We load the model with the lowest error rate
        loader_kwargs,  # opts for the dataloading
    ):
        # If dataset isn't a DataLoader, we create it.
        if not isinstance(dataset, DataLoader):
            loader_kwargs["ckpt_prefix"] = None
            dataset = self.make_dataloader(dataset, sb.Stage.TEST, **loader_kwargs)

        self.checkpointer.recover_if_possible(min_key=min_key)
        self.modules.eval()  # We set the model to eval mode (remove dropout etc.)

        with torch.no_grad():
            true_labels = []
            pred_labels = []
            for batch in dataset:
                # Make sure that compute_forward returns the predictions:
                # in this script, when stage = TEST, a beam search is applied
                # in compute_forward().
                predictions = self.compute_forward(batch, stage=sb.Stage.TEST)

                pred_batch = []
                predicted_words = [
                    self.hparams.tokenizer.decode_ids(prediction).split(" ")
                    for prediction in predictions["tokens"]
                ]
                for sent in predicted_words:
                    sent = filter_repetitions(sent, 3)
                    sent = " ".join(sent)
                    pred_batch.append(sent)

                pred_labels.append(pred_batch[0])
                true_labels.append(batch.transcript[0])

            print("WER: ", wer(true_labels, pred_labels) * 100)
            print("CER: ", cer(true_labels, pred_labels) * 100)


def filter_repetitions(seq, max_repetition_length):
    """Remove excessive n-gram repetitions from a decoded token sequence."""
    seq = list(seq)
    output = []
    max_n = len(seq) // 2
    for n in range(max_n, 0, -1):
        max_repetitions = max(max_repetition_length // n, 1)
        # Don't need to iterate over impossible n values:
        # len(seq) can change a lot during iteration
        if (len(seq) <= n * 2) or (len(seq) <= max_repetition_length):
            continue
        iterator = enumerate(seq)
        # Fill first buffers:
        buffers = [[next(iterator)[1]] for _ in range(n)]
        for seq_index, token in iterator:
            current_buffer = seq_index % n
            if token != buffers[current_buffer][-1]:
                # No repeat, we can flush some tokens
                buf_len = sum(map(len, buffers))
                flush_start = (current_buffer - buf_len) % n
                # Keep n-1 tokens, but possibly mark some for removal
                for flush_index in range(buf_len - buf_len % n):
                    if (buf_len - flush_index) > n - 1:
                        to_flush = buffers[(flush_index + flush_start) % n].pop(0)
                    else:
                        to_flush = None
                    # Here, repetitions get removed:
                    if (flush_index // n < max_repetitions) and to_flush is not None:
                        output.append(to_flush)
                    elif (flush_index // n >= max_repetitions) and to_flush is None:
                        output.append(to_flush)
            buffers[current_buffer].append(token)
        # At the end, final flush
        current_buffer += 1
        buf_len = sum(map(len, buffers))
        flush_start = (current_buffer - buf_len) % n
        for flush_index in range(buf_len):
            to_flush = buffers[(flush_index + flush_start) % n].pop(0)
            # Here, repetitions just get removed:
            if flush_index // n < max_repetitions:
                output.append(to_flush)
        seq = []
        to_delete = 0
        for token in output:
            if token is None:
                to_delete += 1
            elif to_delete > 0:
                to_delete -= 1
            else:
                seq.append(token)
        output = []
    return seq

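Note: filter_repetitions is only used above to post-process decoded hypotheses before jiwer scoring. A minimal sketch of calling it on a toy hypothesis (tokens are made up, not from any dataset; the exact output depends on which n-gram repetitions the function detects):

# Toy, hypothetical hypothesis with a looping word, as seq2seq decoders sometimes produce.
toy_hyp = ["the", "cat", "sat", "sat", "sat", "sat", "sat", "on", "the", "mat"]
# Collapse excess repetitions using the same budget (3) as run_inference above.
print(" ".join(filter_repetitions(toy_hyp, 3)))
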
def dataio_prepare(hparams):
    """This function prepares the datasets to be used in the brain class.
    It also defines the data processing pipeline through user-defined functions.
    """
    data_folder = hparams["data_folder"]

    # 1. Load the train/dev/test manifests and sort by duration:
    train_data = sb.dataio.dataset.DynamicItemDataset.from_json(
        json_path=os.path.join(hparams["data_folder"], "train.json"),
        replacements={"data_root": data_folder},
    )
    train_data = train_data.filtered_sorted(sort_key="duration")
    hparams["train_dataloader_opts"]["shuffle"] = False

    valid_data = sb.dataio.dataset.DynamicItemDataset.from_json(
        json_path=os.path.join(hparams["data_folder"], "dev.json"),
        replacements={"data_root": data_folder},
    )
    valid_data = valid_data.filtered_sorted(sort_key="duration")

    test_data = sb.dataio.dataset.DynamicItemDataset.from_json(
        json_path=os.path.join(hparams["data_folder"], "test.json"),
        replacements={"data_root": data_folder},
    )

    datasets = [train_data, valid_data, test_data]

    # We get the tokenizer as we need it to encode the labels when creating
    # mini-batches.
    tokenizer = hparams["tokenizer"]

    # 2. Define audio pipeline:
    @sb.utils.data_pipeline.takes("data_path")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(data_path):
        sig, sr = librosa.load(data_path, sr=16000)
        # sig = sb.dataio.dataio.read_audio(data_path)  # alternatively, use the SpeechBrain loader
        return sig

    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)

    # 3. Define text pipeline:
    @sb.utils.data_pipeline.takes("transcript")
    @sb.utils.data_pipeline.provides("transcript", "tokens_list", "tokens_bos", "tokens_eos", "tokens")
    def text_pipeline(transcript):
        yield transcript
        tokens_list = tokenizer.encode_as_ids(transcript)
        yield tokens_list
        tokens_bos = torch.LongTensor([hparams["bos_index"]] + tokens_list)
        yield tokens_bos
        tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
        yield tokens_eos
        tokens = torch.LongTensor(tokens_list)
        yield tokens

    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)

    # 4. Set output:
    sb.dataio.dataset.set_output_keys(
        datasets,
        ["id", "sig", "transcript", "tokens_list", "tokens_bos", "tokens_eos", "tokens"],
    )

    return train_data, valid_data, test_data

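Note: the train.json / dev.json / test.json manifests loaded above are assumed to map an utterance id to at least the fields the pipelines consume: "data_path" (audio), "transcript" (text), and "duration" (used for sorting). A minimal, hypothetical sketch of one such entry (paths and values are placeholders; the replacements={"data_root": ...} argument above is intended to expand a data_root placeholder if the stored paths use one):

import json

# Hypothetical manifest entry; keys match what dataio_prepare consumes.
manifest = {
    "utt_0001": {
        "data_path": "/path/to/audio/utt_0001.wav",  # read by audio_pipeline via librosa
        "transcript": "example transcript of the utterance",  # consumed by text_pipeline
        "duration": 3.42,  # seconds, used by filtered_sorted(sort_key="duration")
    }
}
print(json.dumps(manifest, ensure_ascii=False, indent=2))
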
if __name__ == "__main__":
    # CLI:
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # Create ddp_group with the right communication protocol
    sb.utils.distributed.ddp_init_group(run_opts)

    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Here we create the dataset objects as well as tokenization and encoding
    (train_data, valid_data, test_data) = dataio_prepare(hparams)

    # Load the pretrained components declared in the hparams file (via the pretrainer)
    run_on_main(hparams["pretrainer"].collect_files)
    hparams["pretrainer"].load_collected()

    # Trainer initialization
    asr_brain = ASR(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # We dynamically add the tokenizer to our brain class.
    # NB: This tokenizer corresponds to the one used for the LM!
    asr_brain.tokenizer = hparams["tokenizer"]
    train_dataloader_opts = hparams["train_dataloader_opts"]
    valid_dataloader_opts = hparams["valid_dataloader_opts"]

    # Training/validation loop
    if not hparams["skip_training"]:
        print("Training...")
        asr_brain.fit(
            asr_brain.hparams.epoch_counter,
            train_data,
            valid_data,
            train_loader_kwargs=train_dataloader_opts,
            valid_loader_kwargs=valid_dataloader_opts,
        )
    else:
        # Evaluate on the test set with the best (lowest-WER) checkpoint
        print("Evaluating")
        asr_brain.run_inference(test_data, "WER", hparams["test_dataloader_opts"])
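Note: sb.parse_arguments expects the HyperPyYAML hyperparameter file as the first command-line argument, followed by optional overrides. The hparams file is not part of this commit, so the paths below are assumptions:

# Hypothetical invocations (the YAML path is an assumption; any key defined in the
# hparams file, e.g. data_folder or skip_training, can be overridden on the command line):
#   python train.py hparams/train_w2v2.yaml --data_folder=/path/to/data
#   python train.py hparams/train_w2v2.yaml --skip_training=True   # run evaluation only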