import os
import sys
import copy
import json
import torch
import random
import argparse
import subprocess
import numpy as np
import soundfile as sf
import concurrent.futures

from vita.model.vita_tts.decoder.decoder import LLM2TTSCodecAR
from vita.model.vita_tts.decoder.ticodec.vqvae_tester import VqvaeTester

class llm2TTS():
    """Streaming LLM-to-speech decoder: an autoregressive codec-token decoder
    (LLM2TTSCodecAR) followed by a VQ-VAE codec vocoder."""

    def __init__(self, model_path):
        self.model = self.get_model(model_path).cuda().to(
            torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        )
        self.infer = self.model.infer

        self.codec_model = VqvaeTester(config_path=model_path + "/codec/model.json",
                                       model_path=model_path + "/codec/final.pt",
                                       sample_rate=24000)
        self.codec_model = self.codec_model.cuda()
        self.codec_model.vqvae.generator.remove_weight_norm()
        self.codec_model.vqvae.encoder.remove_weight_norm()
        self.codec_model.eval()

    def get_model_conf(self, model_path):
        model_conf = model_path + "/decoder/model.json"
        with open(model_conf, "rb") as f:
            print('reading a config file from ' + model_conf)
            confs = json.load(f)
        # for asr, tts, mt
        idim, odim, args = confs
        return argparse.Namespace(**args)
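    # Note on the assumed config layout: decoder/model.json is read as a
    # three-element JSON array, [idim, odim, {training args}], and only the
    # args dict is wrapped in a Namespace here; get_model() below re-reads
    # idim/odim from that Namespace, so the dict is expected to also contain
    # "idim" and "odim" keys. The remaining keys are model-specific and not
    # enumerated here.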

    def get_model(self, model_path):
        args_load = self.get_model_conf(model_path)
        args_load = vars(args_load)
        args = argparse.Namespace(**args_load)
        odim = args.odim
        idim = args.idim
        model = LLM2TTSCodecAR(idim, odim, args)
        # Resume from a snapshot
        snapshot_dict = torch.load(model_path + "/decoder/final.pt", map_location=lambda storage, loc: storage)
        if 'model' in snapshot_dict.keys():
            resume_model_dict = snapshot_dict['model']
        else:
            resume_model_dict = snapshot_dict
        model_dict = model.state_dict()
        for key in model_dict.keys():
            if key in resume_model_dict.keys():
                if model_dict[key].shape == resume_model_dict[key].shape:
                    model_dict[key] = resume_model_dict[key]
                else:
                    print('Key {} has different shape, {} VS {}'.format(key, model_dict[key].shape,
                                                                        resume_model_dict[key].shape))
            else:
                print('Key {} is not in the resume model'.format(key))
        model.load_state_dict(model_dict)
        model.eval()
        return model

    def find_min_sum_index(self, buffer, syn, N, threshold):
        """
        Find the index with the minimum sum of a sliding window in the given audio segment
        and perform operations based on this index.

        Parameters:
        - buffer (torch.Tensor): The buffer containing previously processed audio segments.
        - syn (torch.Tensor): The current audio segment to be processed.
        - N (int): The size of the sliding window used to calculate the sum.
        - threshold (float): Threshold value to determine whether to concatenate buffer and current segment or not.

        Returns:
        - tuple: A tuple containing the updated buffer and the processed audio segment.
        """
        arr = syn[0, 0, :]
        L = len(arr)
        mid = L // 2
        kernel = torch.ones(N).to(arr.device)
        window_sums = torch.nn.functional.conv1d(arr.abs().view(1, 1, -1), kernel.view(1, 1, -1), padding=0).squeeze()
        start_index = mid - (N // 2)
        min_sum, min_index = torch.min(window_sums[start_index:], dim=0)
        # get the start and end index of the window
        start_index = max(0, min_index.item() + start_index)
        end_index = min(L, start_index + N)
        # calculate the real min_sum and min_index
        min_sum_real, min_index_real = torch.min(arr[start_index: end_index].abs(), dim=0)
        min_index = min_index_real.item() + start_index
        min_sum = min_sum / N
        syn_clone = syn.clone()
        if min_sum < threshold:
            syn = torch.cat([buffer.clone(), syn[:, :, :min_index]], dim=-1)
            buffer = syn_clone[:, :, min_index:]
        else:
            buffer = torch.cat([buffer, syn_clone], dim=-1)
            syn = None
        return buffer, syn
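    # Illustrative walk-through (hypothetical numbers): with N=2401 (roughly
    # 0.1 s of audio at the 24 kHz codec rate) and threshold=0.01, the method
    # slides an N-sample sum of |x| over the later part of the chunk, picks
    # the quietest window, then the single quietest sample inside it. If that
    # window's mean |x| is below the threshold, the chunk is cut there: the
    # buffer plus everything before the cut is returned for playback and the
    # remainder becomes the new buffer; otherwise the whole chunk is appended
    # to the buffer and None is returned, so nothing is emitted yet.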

    def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10,
            penalty_window_size=-1, penalty=1.1, N=2401, seg_threshold=0.01):
        """
        Run the speech decoder process.

        Parameters:
        - hidden (torch.Tensor): The output of the embedding layer of the language model.
        - top_k (int): The number of top-k tokens to consider during inference.
        - prefix (optional): The prefix hidden state from the language model; currently overridden to None inside this method.
        - codec_chunk_size (int, default=40): The size of each chunk (in codec tokens) to process in the codec model.
        - codec_padding_size (int, default=10): The amount of padding to add on each side of the codec chunk.
        - penalty_window_size (int, default=-1): The window size for applying penalties during decoding.
        - penalty (float, default=1.1): The penalty factor.
        - N (int, default=2401): Sliding-window size (in samples) passed to find_min_sum_index to locate a quiet split point.
        - seg_threshold (float, default=0.01): Mean-amplitude threshold below which find_min_sum_index splits and emits a segment.

        Yields:
        - torch.Tensor: Intermediate audio segments generated by the codec model.
        """
        codec_upsample_rate = 600
        left_padding = 0
        right_padding = codec_padding_size
        prefix = None
        buffer = torch.zeros([1, 1, 0]).to(hidden.device)
        with torch.no_grad():
            with torch.autocast(device_type="cuda",
                                dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32):
                print("Starting TTS...")
                token = torch.full((1, 0), self.model.vocab_size, dtype=torch.long, device=hidden.device)
                for next_token_id in self.infer(hidden, top_k, prefix, penalty_window_size, penalty):
                    token = torch.cat([token, next_token_id], dim=-1)
                    if token.size(1) == left_padding + codec_chunk_size + right_padding:
                        syn = self.codec_model.vqvae(token.unsqueeze(-1),
                                                     torch.tensor(self.codec_model.vqvae.h.global_tokens,
                                                                  device=token.device).unsqueeze(0).unsqueeze(0))
                        print("Codec Done")
                        syn = syn[:, :, left_padding * codec_upsample_rate: -right_padding * codec_upsample_rate]
                        left_padding = codec_padding_size
                        token = token[:, -(left_padding + right_padding):]
                        buffer, syn = self.find_min_sum_index(buffer, syn, N, seg_threshold)
                        if syn is not None:
                            yield syn
                if token.size(1) > 0:
                    print("Codec Done")
                    syn = self.codec_model.vqvae(token.unsqueeze(-1),
                                                 torch.tensor(self.codec_model.vqvae.h.global_tokens,
                                                              device=token.device).unsqueeze(0).unsqueeze(0))
                    syn = syn[:, :, left_padding * codec_upsample_rate:]
                    yield torch.cat([buffer, syn], dim=-1)
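

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The checkpoint
    # directory, hidden size, and top_k below are illustrative assumptions;
    # `hidden` would normally be the LLM embedding-layer output for the text
    # to be spoken, and a random tensor is used here only to show the call
    # shape expected by run().
    model_dir = "./checkpoints/vita_tts"  # hypothetical path containing decoder/ and codec/
    tts = llm2TTS(model_dir)
    hidden = torch.randn(1, 32, 896, device="cuda").to(
        torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
    )  # (batch, seq_len, hidden_dim); hidden_dim assumed
    wav_chunks = []
    for chunk in tts.run(hidden, top_k=2, prefix=None):
        wav_chunks.append(chunk.squeeze().float().cpu().numpy())
    # The codec model above is configured for 24 kHz audio.
    sf.write("demo.wav", np.concatenate(wav_chunks), 24000)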