# File: maskgct/models/tts/valle_v2/valle_nar_trainer.py
# Repository page metadata: Hecheng0625 - "Upload 409 files" - c968fc3 (verified)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torchaudio
import numpy as np
import time
from .valle_ar_trainer import ValleARTrainer, make_pad_mask
class ValleNARTrainer(ValleARTrainer):
    """Trainer for the non-autoregressive (NAR) VALL-E model.

    Reuses everything from ``ValleARTrainer`` and overrides model
    construction plus the per-batch train / test steps so that they drive a
    ``ValleNAR`` model on discrete codec tokens.
    """

    # Keys of the per-layer accuracy buffers created in ``__init__``.
    _NAR_LAYERS = range(1, 8)
    # Waveform samples per codec frame ("our codec downsamples 320x").
    _CODEC_DOWNSAMPLE = 320
    # Number of leading codec frames handed to the model as the prompt in
    # ``_test_step``; the remaining frames are passed as ``first_stage_ids``.
    _PROMPT_FRAMES = 150
    # Sample rate used when writing debug wav files in ``_test_step``.
    # NOTE(review): the EnCodec branch is described as 24 kHz in the original
    # comments, yet both branches are saved at 16 kHz -- confirm this is
    # intended.
    _WAV_SAMPLE_RATE = 16000

    def __init__(self, args=None, cfg=None):
        """Initialise the parent trainer and the per-layer accuracy buffers.

        Args:
            args: Forwarded unchanged to ``ValleARTrainer.__init__``.
            cfg: Forwarded unchanged to ``ValleARTrainer.__init__``.
        """
        super().__init__(args, cfg)
        print("simple NAR")
        # One empty list per key 1..7. Nothing visible in this class appends
        # to these buffers; they are kept as public attributes in case other
        # code reads them.
        self.top1_accuracies = {layer: [] for layer in self._NAR_LAYERS}
        self.top5_accuracies = {layer: [] for layer in self._NAR_LAYERS}
        self.top10_accuracies = {layer: [] for layer in self._NAR_LAYERS}

    def _build_model(self):
        """Instantiate the NAR model from ``self.cfg.model``.

        Returns:
            A ``ValleNAR`` built with ``self.cfg.model`` expanded as keyword
            arguments.
        """
        # Imported here rather than at module level, as in the original.
        from .valle_nar import ValleNAR

        return ValleNAR(**self.cfg.model)

    def _move_batch_to_device(self, batch):
        """Move every tensor value of ``batch`` to the accelerator device.

        The dict is updated in place; non-tensor values are left untouched.
        """
        device = self.accelerator.device
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch[key] = value.to(device)

    def _encode_speech(self, speech):
        """Convert a waveform batch into discrete codec token ids.

        Args:
            speech: Waveform tensor; a channel axis is inserted at dim 1
                before it is handed to ``self.codec_encoder.encode``.
                (The original notes give the shape as ``[B, T]``.)

        Returns:
            The codec token ids. With ``cfg.use_speechtokenizer`` the encoder
            output is returned as is; otherwise the first element of every
            encoded chunk is concatenated along the last axis and the first
            two axes are swapped. (The original notes describe the result as
            ``(n_q, B, T_frames)`` -- not verified here.)
        """
        with torch.no_grad():
            encoded = self.codec_encoder.encode(speech.unsqueeze(1))
            if self.cfg.use_speechtokenizer:
                # SpeechTokenizer path (16 kHz per the original notes).
                return encoded
            # EnCodec path (24 kHz per the original notes).
            return torch.cat([chunk[0] for chunk in encoded], dim=-1).transpose(
                0, 1
            )

    def _train_step(self, batch):
        """Run one training step and return the loss.

        Args:
            batch: dict with keys ``'speech'``, ``'speech_len'``,
                ``'phone_ids'`` and ``'phone_lens'``. Per the original notes:
                speech ``[B, T]``, speech_len ``[B]``, phone_ids ``[B, T]``,
                phone_lens ``[B]``. The dict is modified in place: tensors
                are moved to the accelerator device, ``'speech'`` is replaced
                by its codec token ids and ``'speech_len'`` is converted from
                samples to codec frames.

        Returns:
            ``out.loss`` from the model forward pass.
        """
        self._move_batch_to_device(batch)

        batch["speech"] = self._encode_speech(batch["speech"])
        batch["speech_len"] = batch["speech_len"] // self._CODEC_DOWNSAMPLE
        # Sanity check: no frame length may exceed the encoded sequence.
        assert batch["speech_len"].max() <= batch["speech"].shape[-1]

        # make_pad_mask marks padding; invert it so 1 means "valid position".
        phone_mask = 1 - make_pad_mask(
            batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False
        ).to(torch.long)
        speech_mask = 1 - make_pad_mask(
            batch["speech_len"], max_len=batch["speech"].size(-1)
        ).to(torch.long)

        # Re-seed NumPy's global RNG with a per-process offset.
        # NOTE(review): the seed has one-second resolution and is reset on
        # every step, so steps that start within the same second repeat the
        # same NumPy random draws -- confirm this is intended.
        np.random.seed(int(time.time()) - 5 * self.accelerator.process_index)

        # Dropout defaults to 0.0 when the train config does not define it.
        dropout = getattr(self.cfg.train, "dropout", 0.0)

        out = self.model(
            phone_ids=batch["phone_ids"],
            phone_mask=phone_mask,
            target_ids=batch["speech"],
            target_mask=speech_mask,
            dropout=dropout,
        )

        # Log the top-k accuracies under the quantization layer that the
        # model reports for this step.
        layer = out.target_quantization_layer
        self.accelerator.log(
            {
                f"Train/NAR L{layer} Top1 acc": out.top1_acc,
                f"Train/NAR L{layer} Top5 acc": out.top5_acc,
                f"Train/NAR L{layer} Top10 acc": out.top10_acc,
            },
            step=self.step,
        )
        return out.loss

    def _test_step(self, batch):
        """Synthesise one sample and write ground-truth / generated wavs.

        Writes ``gt.wav`` (codec reconstruction of the input speech) and
        ``a.wav`` (reconstruction of the tokens returned by
        ``self.model.sample_hf``) to the current working directory.

        Args:
            batch: dict with keys ``'speech'``, ``'speech_len'``,
                ``'phone_ids'`` and ``'phone_lens'`` (same layout as in
                ``_train_step``). Tensors are moved to the accelerator device
                and ``'speech'`` is replaced by its codec token ids, in place.

        Returns:
            None.
        """
        self._move_batch_to_device(batch)

        # Pure inference: keep autograd off for encoding, sampling, decoding.
        with torch.no_grad():
            codes = self._encode_speech(batch["speech"])
            batch["speech"] = codes

            # Save the codec reconstruction of the ground-truth tokens.
            if self.cfg.use_speechtokenizer:
                reference_audio = self.codec_encoder.decode(codes)
            else:
                reference_audio = self.codec_encoder.decode(
                    [(codes.transpose(0, 1), None)]
                )
            torchaudio.save(
                "gt.wav", reference_audio[0].cpu(), self._WAV_SAMPLE_RATE
            )

            self.model.eval()
            # The first _PROMPT_FRAMES frames act as the prompt; the rest of
            # the first code layer of item 0 is passed as first-stage ids.
            generated_codes = self.model.sample_hf(
                phone_ids=batch["phone_ids"][:1],
                prompt_ids=batch["speech"][:, :1, : self._PROMPT_FRAMES],
                first_stage_ids=batch["speech"][0, :1, self._PROMPT_FRAMES :],
            )

            # Reconstruct audio from the generated tokens.
            if self.cfg.use_speechtokenizer:
                generated_audio = self.codec_encoder.decode(generated_codes)
            else:
                generated_audio = self.codec_encoder.decode(
                    [(generated_codes.transpose(0, 1)[:1], None)]
                )
            torchaudio.save(
                "a.wav", generated_audio[0].cpu(), self._WAV_SAMPLE_RATE
            )