KingNish committed on
Commit e02e40c · verified · 1 Parent(s): f72005d

Upload ./vocos/experiment.py with huggingface_hub

Files changed (1)
  1. vocos/experiment.py +371 -0
vocos/experiment.py ADDED
@@ -0,0 +1,371 @@
+ import math
+
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch
+ import torchaudio
+ import transformers
+
+ from vocos.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
+ from vocos.feature_extractors import FeatureExtractor
+ from vocos.heads import FourierHead
+ from vocos.helpers import plot_spectrogram_to_numpy
+ from vocos.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss
+ from vocos.models import Backbone
+ from vocos.modules import safe_log
+
+
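+ # Lightning experiment that trains the vocoder as a GAN: the feature extractor, backbone and
+ # Fourier head form the generator, while multi-period and multi-resolution discriminators
+ # provide the adversarial signal.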
+ class VocosExp(pl.LightningModule):
+     # noinspection PyUnusedLocal
+     def __init__(
+         self,
+         feature_extractor: FeatureExtractor,
+         backbone: Backbone,
+         head: FourierHead,
+         sample_rate: int,
+         initial_learning_rate: float,
+         num_warmup_steps: int = 0,
+         mel_loss_coeff: float = 45,
+         mrd_loss_coeff: float = 1.0,
+         pretrain_mel_steps: int = 0,
+         decay_mel_coeff: bool = False,
+         evaluate_utmos: bool = False,
+         evaluate_pesq: bool = False,
+         evaluate_periodicty: bool = False,
+     ):
+         """
+         Args:
+             feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals.
+             backbone (Backbone): An instance of Backbone model.
+             head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform.
+             sample_rate (int): Sampling rate of the audio signals.
+             initial_learning_rate (float): Initial learning rate for the optimizer.
+             num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0.
+             mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45.
+             mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0.
+             pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0.
+             decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False.
+             evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run.
+             evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run.
+             evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run.
+         """
+         super().__init__()
+         self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"])
+
+         self.feature_extractor = feature_extractor
+         self.backbone = backbone
+         self.head = head
+
+         self.multiperioddisc = MultiPeriodDiscriminator()
+         self.multiresddisc = MultiResolutionDiscriminator()
+
+         self.disc_loss = DiscriminatorLoss()
+         self.gen_loss = GeneratorLoss()
+         self.feat_matching_loss = FeatureMatchingLoss()
+         self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)
+
+         self.train_discriminator = False
+         self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
+
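+     # Two optimizers: index 0 updates the discriminators, index 1 updates the generator
+     # (feature extractor, backbone, head). Both use AdamW with a cosine schedule; max_steps
+     # is halved because each training batch steps both optimizers.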
+     def configure_optimizers(self):
+         disc_params = [
+             {"params": self.multiperioddisc.parameters()},
+             {"params": self.multiresddisc.parameters()},
+         ]
+         gen_params = [
+             {"params": self.feature_extractor.parameters()},
+             {"params": self.backbone.parameters()},
+             {"params": self.head.parameters()},
+         ]
+
+         opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9))
+         opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9))
+
+         max_steps = self.trainer.max_steps // 2  # Max steps per optimizer
+         scheduler_disc = transformers.get_cosine_schedule_with_warmup(
+             opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
+         )
+         scheduler_gen = transformers.get_cosine_schedule_with_warmup(
+             opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
+         )
+
+         return (
+             [opt_disc, opt_gen],
+             [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
+         )
+
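+     # Full generator pass: waveform -> features -> backbone -> spectral coefficients -> waveform.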
+     def forward(self, audio_input, **kwargs):
+         features = self.feature_extractor(audio_input, **kwargs)
+         x = self.backbone(features, **kwargs)
+         audio_output = self.head(x)
+         return audio_output
+
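+     # Called once per optimizer per batch. optimizer_idx 0 trains the discriminators (only once
+     # self.train_discriminator has been enabled after the mel-only pretraining phase);
+     # optimizer_idx 1 trains the generator with adversarial, feature-matching and
+     # mel-spectrogram reconstruction terms.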
+     def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
+         audio_input = batch
+
+         # train discriminator
+         if optimizer_idx == 0 and self.train_discriminator:
+             with torch.no_grad():
+                 audio_hat = self(audio_input, **kwargs)
+
+             real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
+             real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
+             loss_mp, loss_mp_real, _ = self.disc_loss(
+                 disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
+             )
+             loss_mrd, loss_mrd_real, _ = self.disc_loss(
+                 disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
+             )
+             loss_mp /= len(loss_mp_real)
+             loss_mrd /= len(loss_mrd_real)
+             loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd
+
+             self.log("discriminator/total", loss, prog_bar=True)
+             self.log("discriminator/multi_period_loss", loss_mp)
+             self.log("discriminator/multi_res_loss", loss_mrd)
+             return loss
+
+         # train generator
+         if optimizer_idx == 1:
+             audio_hat = self(audio_input, **kwargs)
+             if self.train_discriminator:
+                 _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
+                     y=audio_input, y_hat=audio_hat, **kwargs,
+                 )
+                 _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
+                     y=audio_input, y_hat=audio_hat, **kwargs,
+                 )
+                 loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
+                 loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
+                 loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
+                 loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
+                 loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp)
+                 loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd)
+
+                 self.log("generator/multi_period_loss", loss_gen_mp)
+                 self.log("generator/multi_res_loss", loss_gen_mrd)
+                 self.log("generator/feature_matching_mp", loss_fm_mp)
+                 self.log("generator/feature_matching_mrd", loss_fm_mrd)
+             else:
+                 loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0
+
+             mel_loss = self.melspec_loss(audio_hat, audio_input)
+             loss = (
+                 loss_gen_mp
+                 + self.hparams.mrd_loss_coeff * loss_gen_mrd
+                 + loss_fm_mp
+                 + self.hparams.mrd_loss_coeff * loss_fm_mrd
+                 + self.mel_loss_coeff * mel_loss
+             )
+
+             self.log("generator/total_loss", loss, prog_bar=True)
+             self.log("mel_loss_coeff", self.mel_loss_coeff)
+             self.log("generator/mel_loss", mel_loss)
+
+             if self.global_step % 1000 == 0 and self.global_rank == 0:
+                 self.logger.experiment.add_audio(
+                     "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate
+                 )
+                 self.logger.experiment.add_audio(
+                     "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate
+                 )
+                 with torch.no_grad():
+                     mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
+                     mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0]))
+                 self.logger.experiment.add_image(
+                     "train/mel_target",
+                     plot_spectrogram_to_numpy(mel.data.cpu().numpy()),
+                     self.global_step,
+                     dataformats="HWC",
+                 )
+                 self.logger.experiment.add_image(
+                     "train/mel_pred",
+                     plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
+                     self.global_step,
+                     dataformats="HWC",
+                 )
+
+             return loss
+
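+     # Lazily instantiate the UTMOS predictor once, only when UTMOS evaluation is requested.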
+     def on_validation_epoch_start(self):
+         if self.hparams.evaluate_utmos:
+             from metrics.UTMOS import UTMOSScore
+
+             if not hasattr(self, "utmos_model"):
+                 self.utmos_model = UTMOSScore(device=self.device)
+
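+     # Validation resamples audio to 16 kHz for the UTMOS, PESQ and periodicity metrics;
+     # the reported val_loss combines the mel loss with (5 - UTMOS) and (5 - PESQ) so that
+     # higher quality scores lower the loss.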
+     def validation_step(self, batch, batch_idx, **kwargs):
+         audio_input = batch
+         audio_hat = self(audio_input, **kwargs)
+
+         audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000)
+         audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000)
+
+         if self.hparams.evaluate_periodicty:
+             from metrics.periodicity import calculate_periodicity_metrics
+
+             periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
+         else:
+             periodicity_loss = pitch_loss = f1_score = 0
+
+         if self.hparams.evaluate_utmos:
+             utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
+         else:
+             utmos_score = torch.zeros(1, device=self.device)
+
+         if self.hparams.evaluate_pesq:
+             from pesq import pesq
+
+             pesq_score = 0
+             for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()):
+                 pesq_score += pesq(16000, ref, deg, "wb", on_error=1)
+             pesq_score /= len(audio_16_khz)
+             pesq_score = torch.tensor(pesq_score)
+         else:
+             pesq_score = torch.zeros(1, device=self.device)
+
+         mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
+         total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score)
+
+         return {
+             "val_loss": total_loss,
+             "mel_loss": mel_loss,
+             "utmos_score": utmos_score,
+             "pesq_score": pesq_score,
+             "periodicity_loss": periodicity_loss,
+             "pitch_loss": pitch_loss,
+             "f1_score": f1_score,
+             "audio_input": audio_input[0],
+             "audio_pred": audio_hat[0],
+         }
+
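+     # On rank 0, log example audio and mel spectrograms from the first validation batch,
+     # then average the metrics across all batches and log them (synced across devices).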
+     def validation_epoch_end(self, outputs):
+         if self.global_rank == 0:
+             *_, audio_in, audio_pred = outputs[0].values()
+             self.logger.experiment.add_audio(
+                 "val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
+             )
+             self.logger.experiment.add_audio(
+                 "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
+             )
+             mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
+             mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
+             self.logger.experiment.add_image(
+                 "val_mel_target",
+                 plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()),
+                 self.global_step,
+                 dataformats="HWC",
+             )
+             self.logger.experiment.add_image(
+                 "val_mel_hat",
+                 plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
+                 self.global_step,
+                 dataformats="HWC",
+             )
+         avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
+         mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean()
+         utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean()
+         pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean()
+         periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean()
+         pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean()
+         f1_score = np.array([x["f1_score"] for x in outputs]).mean()
+
+         self.log("val_loss", avg_loss, sync_dist=True)
+         self.log("val/mel_loss", mel_loss, sync_dist=True)
+         self.log("val/utmos_score", utmos_score, sync_dist=True)
+         self.log("val/pesq_score", pesq_score, sync_dist=True)
+         self.log("val/periodicity_loss", periodicity_loss, sync_dist=True)
+         self.log("val/pitch_loss", pitch_loss, sync_dist=True)
+         self.log("val/f1_score", f1_score, sync_dist=True)
+
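+     # With two optimizers, Lightning's step counter advances per optimizer step, so this
+     # override reports the number of processed batches instead.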
+     @property
+     def global_step(self):
+         """
+         Override global_step so that it returns the total number of batches processed
+         """
+         return self.trainer.fit_loop.epoch_loop.total_batch_idx
+
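+     # Enable the GAN objective only after the mel-only pretraining phase (pretrain_mel_steps).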
+     def on_train_batch_start(self, *args):
+         if self.global_step >= self.hparams.pretrain_mel_steps:
+             self.train_discriminator = True
+         else:
+             self.train_discriminator = False
+
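+     # When decay_mel_coeff is set, the mel loss coefficient follows a half-cosine decay after
+     # the warmup steps: coeff = base_mel_coeff * 0.5 * (1 + cos(pi * progress)), mirroring the
+     # learning-rate schedule.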
+     def on_train_batch_end(self, *args):
+         def mel_loss_coeff_decay(current_step, num_cycles=0.5):
+             max_steps = self.trainer.max_steps // 2
+             if current_step < self.hparams.num_warmup_steps:
+                 return 1.0
+             progress = float(current_step - self.hparams.num_warmup_steps) / float(
+                 max(1, max_steps - self.hparams.num_warmup_steps)
+             )
+             return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+         if self.hparams.decay_mel_coeff:
+             self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)
+
+
+ class VocosEncodecExp(VocosExp):
+     """
+     VocosEncodecExp is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN.
+     It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to
+     a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step,
+     while during validation, a fixed bandwidth_id is used.
+     """
+
+     def __init__(
+         self,
+         feature_extractor: FeatureExtractor,
+         backbone: Backbone,
+         head: FourierHead,
+         sample_rate: int,
+         initial_learning_rate: float,
+         num_warmup_steps: int,
+         mel_loss_coeff: float = 45,
+         mrd_loss_coeff: float = 1.0,
+         pretrain_mel_steps: int = 0,
+         decay_mel_coeff: bool = False,
+         evaluate_utmos: bool = False,
+         evaluate_pesq: bool = False,
+         evaluate_periodicty: bool = False,
+     ):
+         super().__init__(
+             feature_extractor,
+             backbone,
+             head,
+             sample_rate,
+             initial_learning_rate,
+             num_warmup_steps,
+             mel_loss_coeff,
+             mrd_loss_coeff,
+             pretrain_mel_steps,
+             decay_mel_coeff,
+             evaluate_utmos,
+             evaluate_pesq,
+             evaluate_periodicty,
+         )
+         # Override with conditional discriminators
+         self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
+         self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
+
+     def training_step(self, *args):
+         bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,)
+         output = super().training_step(*args, bandwidth_id=bandwidth_id)
+         return output
+
+     def validation_step(self, *args):
+         bandwidth_id = torch.tensor([0], device=self.device)
+         output = super().validation_step(*args, bandwidth_id=bandwidth_id)
+         return output
+
+     def validation_epoch_end(self, outputs):
+         if self.global_rank == 0:
+             *_, audio_in, _ = outputs[0].values()
+             # Resynthesis with encodec for reference
+             self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
+             encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
+             self.logger.experiment.add_audio(
+                 "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate,
+             )
+
+         super().validation_epoch_end(outputs)