Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
import torchaudio | |
class ValleInference(torch.nn.Module):
    """Inference wrapper chaining a VALL-E AR model, a NAR model and a codec.

    ``forward`` encodes the prompt speech into codec token ids, lets the AR
    model sample the first token layer, lets the NAR model produce the
    remaining layers, and finally decodes the token ids back to a waveform.

    The codec is one of:

    * SpeechTokenizer (``use_speechtokenizer=True``), used for both encoding
      and decoding;
    * Encodec 24 kHz (``use_speechtokenizer=False``), used for encoding and,
      unless ``use_vocos`` is set, for decoding as well;
    * Encodec for encoding plus a pretrained Vocos decoder
      (``use_speechtokenizer=False, use_vocos=True``).
    """

    def __init__(
        self,
        use_vocos=False,
        use_speechtokenizer=True,
        ar_path=None,
        nar_path=None,
        speechtokenizer_path=None,
        device="cuda",
    ):
        """Build both VALL-E stages and the codec, and load their weights.

        Args:
            use_vocos: decode with the pretrained Vocos decoder instead of
                Encodec. Mutually exclusive with ``use_speechtokenizer``.
            use_speechtokenizer: use SpeechTokenizer as the codec instead of
                Encodec.
            ar_path: checkpoint file holding the AR model ``state_dict``.
                Required.
            nar_path: checkpoint file holding the NAR model ``state_dict``.
                Required.
            speechtokenizer_path: directory containing ``config.json`` and
                ``SpeechTokenizer.pt``. Required when ``use_speechtokenizer``
                is true.
            device: device all sub-models are moved to.

        Raises:
            AssertionError: if a required path is missing or both
                ``use_vocos`` and ``use_speechtokenizer`` are true.
        """
        super().__init__()

        # Cheap argument checks come first so that a bad configuration fails
        # before any (potentially large) checkpoint is read from disk.
        assert ar_path is not None, "ar_path must point to a trained AR checkpoint"
        assert nar_path is not None, "nar_path must point to a trained NAR checkpoint"
        assert not (
            use_speechtokenizer and use_vocos
        ), "Only one of use_speechtokenizer and use_vocos can be True"
        assert (
            not use_speechtokenizer or speechtokenizer_path is not None
        ), "speechtokenizer_path is required when use_speechtokenizer is True"

        self.device = device
        self.use_speechtokenizer = use_speechtokenizer
        self.use_vocos = use_vocos

        # Both stages are constructed with the same vocabulary sizes and
        # special-token ids, so the settings are kept in one place.
        model_kwargs = dict(
            phone_vocab_size=300,
            target_vocab_size=1024,
            pad_token_id=1324,
            bos_target_id=1325,
            eos_target_id=1326,
            bos_phone_id=1327,
            eos_phone_id=1328,
            bos_prompt_id=1329,
            eos_prompt_id=1330,
            num_hidden_layers=16,
        )

        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        # Weights are loaded on CPU first and moved to `device` afterwards.

        # prepare pretrained VALLE AR model
        from .valle_ar import ValleAR

        self.ar_model = ValleAR(**model_kwargs)
        self.ar_model.load_state_dict(torch.load(ar_path, map_location="cpu"))
        self.ar_model.eval().to(self.device)

        # prepare pretrained VALLE NAR model
        from .valle_nar import ValleNAR

        self.nar_model = ValleNAR(**model_kwargs)
        self.nar_model.load_state_dict(torch.load(nar_path, map_location="cpu"))
        self.nar_model.eval().to(self.device)

        # prepare codec encoder (and, for Vocos, a separate decoder)
        if use_speechtokenizer:
            from models.codec.speechtokenizer.model import SpeechTokenizer

            # download from https://huggingface.co./fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg
            config_path = speechtokenizer_path + "/config.json"
            ckpt_path = speechtokenizer_path + "/SpeechTokenizer.pt"
            self.codec_encoder = SpeechTokenizer.load_from_checkpoint(
                config_path, ckpt_path
            )
            self.codec_encoder.eval()
            self.codec_encoder.to(self.device)
            print(f"Loaded SpeechTokenizer from {config_path} and {ckpt_path}")
        else:
            # use Encodec
            from encodec import EncodecModel

            self.codec_encoder = EncodecModel.encodec_model_24khz()
            self.codec_encoder.set_target_bandwidth(6.0)
            self.codec_encoder.to(self.device)
            if use_vocos:
                from vocos import Vocos

                self.codec_decoder = Vocos.from_pretrained(
                    "charactr/vocos-encodec-24khz"
                )
                self.codec_decoder.to(self.device)
                print("Loaded Vocos")
            print("Loaded EncodecModel")

    def decode(self, vq_ids):
        """Turn codec token ids back into a waveform.

        Args:
            vq_ids: codec token ids, documented by the original author as
                ``[8, B, T]`` (layers, batch, frames).

        Returns:
            The decoded audio, documented by the original author as
            ``[B, 1, T]``.
        """
        if self.use_speechtokenizer:
            # SpeechTokenizer decodes its own token layout directly.
            return self.codec_encoder.decode(vq_ids)  # [B, 1, T]

        if not self.use_vocos:
            # Encodec decoder: it is handed a list with a single
            # (codes, scale) frame, with the first two axes of vq_ids swapped.
            return self.codec_encoder.decode([(vq_ids.transpose(0, 1), None)])

        # Vocos decoder.
        # NOTE(review): squeeze(1) only drops axis 1 when it has size 1, so
        # this path presumably expects a single utterance (B == 1) -- confirm.
        features = self.codec_decoder.codes_to_features(vq_ids.squeeze(1))
        # NOTE(review): presumably the index of the 6 kbps entry in Vocos's
        # bandwidth table, matching set_target_bandwidth(6.0) -- confirm.
        bandwidth_id = torch.tensor([2], device=vq_ids.device)
        return self.codec_decoder.decode(
            features, bandwidth_id=bandwidth_id
        ).unsqueeze(0)

    def forward(self, batch, chunk_configs: list, return_prompt=False, prompt_len=None):
        """Synthesize audio for ``batch`` conditioned on its prompt speech.

        Args:
            batch: dict with at least
                ``speech``: ``[B, T]`` prompt waveform, and
                ``phone_ids``: ``[B, T]`` phone token ids.
                Tensor values are moved to ``self.device`` in place.
            chunk_configs: sampling configs; each is a mapping with the keys
                ``top_p``, ``top_k``, ``temperature``, ``num_beams``,
                ``repeat_penalty`` and ``max_length``. Typically a single
                config is given. NOTE: only the first entry is used;
                additional entries are currently ignored.
            return_prompt: if true, prepend the prompt's token ids to the
                generated ids before decoding, so that the output audio
                starts with the prompt.
            prompt_len: maximum number of prompt frames to condition on;
                ``None`` means no limit.

        Returns:
            ``[B, 1, T]`` audio.

        Raises:
            ValueError: if ``chunk_configs`` is empty.
        """
        # Fail early and loudly instead of running the codec and then ending
        # without having produced any audio.
        chunk = next(iter(chunk_configs), None)
        if chunk is None:
            raise ValueError("chunk_configs must contain at least one sampling config")

        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch[key] = value.to(self.device)

        with torch.no_grad():
            speech = batch["speech"].unsqueeze(1)  # [B, T] -> [B, 1, T]
            if self.use_speechtokenizer:
                vq_id = self.codec_encoder.encode(speech)  # (n_q, B, T)
            else:
                # Encodec returns a list of frames; element 0 of each frame is
                # its code tensor. Join the frames along time, then swap the
                # first two axes to get the same layout as above.
                frames = self.codec_encoder.encode(speech)
                vq_id = torch.cat([encoded[0] for encoded in frames], dim=-1)
                vq_id = vq_id.transpose(0, 1)

            # Slicing with prompt_len=None keeps every frame.
            prompt_ids = vq_id[:, :, :prompt_len]

            # Stage 1: the AR model samples the first token layer, conditioned
            # on the phones and the first layer of the prompt.
            ar_vq_ids = self.ar_model.sample_hf(
                batch["phone_ids"],
                prompt_ids[0],
                top_p=chunk["top_p"],
                top_k=chunk["top_k"],
                temperature=chunk["temperature"],
                num_beams=chunk["num_beams"],
                repeat_penalty=chunk["repeat_penalty"],
                max_length=chunk["max_length"],
            )

            # Stage 2: the NAR model produces the full token stack from the
            # complete prompt and the AR output.
            nar_vq_ids = self.nar_model.sample_hf(
                phone_ids=batch["phone_ids"],
                prompt_ids=prompt_ids,
                first_stage_ids=ar_vq_ids,
            )

            if return_prompt:
                nar_vq_ids = torch.cat([prompt_ids, nar_vq_ids], dim=-1)

            return self.decode(nar_vq_ids)  # [B, 1, T]