import os
import sys
import copy
import json
import torch
import random
import argparse
import subprocess
import numpy as np
import soundfile as sf
import concurrent.futures
from vita.model.vita_tts.decoder.decoder import LLM2TTSCodecAR
from vita.model.vita_tts.decoder.ticodec.vqvae_tester import VqvaeTester


class llm2TTS:
    """Streaming text-to-speech decoder: an autoregressive LLM-to-codec-token
    model (LLM2TTSCodecAR) followed by a VQ-VAE codec that renders 24 kHz audio."""

    def __init__(self, model_path):
        self.model = self.get_model(model_path).cuda().to(
            torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        )
        self.infer = self.model.infer

        self.codec_model = VqvaeTester(config_path=model_path + "/codec/model.json",
                                       model_path=model_path + "/codec/final.pt",
                                       sample_rate=24000)
        self.codec_model = self.codec_model.cuda()
        # Weight norm is only needed during training; strip it for inference.
        self.codec_model.vqvae.generator.remove_weight_norm()
        self.codec_model.vqvae.encoder.remove_weight_norm()
        self.codec_model.eval()

    def get_model_conf(self, model_path):
        model_conf = model_path + "/decoder/model.json"
        with open(model_conf, "rb") as f:
            print('Reading decoder config from ' + model_conf)
            confs = json.load(f)
        # model.json stores (idim, odim, training-args dict), as used for asr/tts/mt
        idim, odim, args = confs
        return argparse.Namespace(**args)
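
    # For reference, a minimal sketch of the layout `get_model_conf` expects in
    # decoder/model.json: a three-element JSON list of input dim, output dim,
    # and a dict of training args (the values below are illustrative, not the
    # shipped configuration):
    #
    #   idim, odim, args = json.loads('[80, 1024, {"idim": 80, "odim": 1024}]')
    #   argparse.Namespace(**args).odim  # -> 1024, read back the same way in get_model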

    def get_model(self, model_path):
        args_load = self.get_model_conf(model_path)
        args_load = vars(args_load)
        args = argparse.Namespace(**args_load)
        odim = args.odim
        idim = args.idim
        model = LLM2TTSCodecAR(idim, odim, args)

        # Resume from a snapshot
        snapshot_dict = torch.load(model_path + "/decoder/final.pt", map_location=lambda storage, loc: storage)
        if 'model' in snapshot_dict.keys():
            resume_model_dict = snapshot_dict['model']
        else:
            resume_model_dict = snapshot_dict

        # Copy checkpoint weights only where the key exists and the shape matches
        model_dict = model.state_dict()
        for key in model_dict.keys():
            if key in resume_model_dict.keys():
                if model_dict[key].shape == resume_model_dict[key].shape:
                    model_dict[key] = resume_model_dict[key]
                else:
                    print('Key {} has a different shape: {} vs {}'.format(key, model_dict[key].shape,
                                                                          resume_model_dict[key].shape))
            else:
                print('Key {} is not in the resume model'.format(key))
        model.load_state_dict(model_dict)
        model.eval()
        return model
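
    # A minimal sketch (not executed here) of the shape-tolerant loading pattern
    # above, on a hypothetical toy module: checkpoint entries are copied only
    # when the key exists and the shape matches, so partial or mismatched
    # checkpoints load with a warning instead of raising.
    #
    #   model = torch.nn.Linear(4, 4)
    #   ckpt = {"weight": torch.zeros(4, 4), "bias": torch.zeros(8)}  # bias mismatches
    #   state = model.state_dict()
    #   for key in state:
    #       if key in ckpt and state[key].shape == ckpt[key].shape:
    #           state[key] = ckpt[key]
    #   model.load_state_dict(state)  # only `weight` was replaced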

    def find_min_sum_index(self, buffer, syn, N, threshold):
        """
        Find the sliding window with the minimum amplitude sum in the given audio
        segment and, if it is quiet enough, split the stream at its quietest sample.

        Parameters:
        - buffer (torch.Tensor): The buffer containing previously processed audio segments.
        - syn (torch.Tensor): The current audio segment to be processed.
        - N (int): The size of the sliding window used to calculate the sum.
        - threshold (float): Threshold on the windowed mean amplitude; below it the
          stream is split, otherwise the segment is appended to the buffer.

        Returns:
        - tuple: The updated buffer and the audio segment ready to be emitted
          (None if no split point was found).
        """
        arr = syn[0, 0, :]
        L = len(arr)
        mid = L // 2
        kernel = torch.ones(N).to(arr.device)
        # Sum of |amplitude| over every length-N window
        window_sums = torch.nn.functional.conv1d(arr.abs().view(1, 1, -1), kernel.view(1, 1, -1), padding=0).squeeze()

        start_index = mid - (N // 2)
        min_sum, min_index = torch.min(window_sums[start_index:], dim=0)

        # get the start and end index of the quietest window
        start_index = max(0, min_index.item() + start_index)
        end_index = min(L, start_index + N)

        # within that window, locate the single quietest sample
        min_sum_real, min_index_real = torch.min(arr[start_index: end_index].abs(), dim=0)
        min_index = min_index_real.item() + start_index
        min_sum = min_sum / N

        syn_clone = syn.clone()
        if min_sum < threshold:
            # Quiet enough: emit everything up to the split point, keep the rest
            syn = torch.cat([buffer.clone(), syn[:, :, :min_index]], dim=-1)
            buffer = syn_clone[:, :, min_index:]
        else:
            # No usable split point: keep accumulating
            buffer = torch.cat([buffer, syn_clone], dim=-1)
            syn = None
        return buffer, syn
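
    # A minimal sketch (not executed here) of the splitting behaviour above:
    # a chunk with a quiet region is cut at its quietest sample, so the audio
    # before the cut can be emitted while the remainder stays buffered. The
    # method does not touch `self`, so it can be exercised unbound on the CPU:
    #
    #   buffer = torch.zeros([1, 1, 0])
    #   chunk = torch.randn(1, 1, 24000) * 0.5
    #   chunk[:, :, 10000:14000] = 0.0  # synthetic quiet region
    #   buffer, ready = llm2TTS.find_min_sum_index(None, buffer, chunk, 2401, 0.01)
    #   # `ready` now ends at the quietest sample; `buffer` holds the rest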

    def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10,
            penalty_window_size=-1, penalty=1.1, N=2401, seg_threshold=0.01):
        """
        Run the speech decoder process.

        Parameters:
        - hidden (torch.Tensor): The hidden states from the language model's embedding layer.
        - top_k (int): The number of top-k tokens to consider during inference.
        - prefix (torch.Tensor, optional): A prefix hidden state from the language model
          (currently forced to None below).
        - codec_chunk_size (int, default=40): The number of tokens decoded per codec chunk.
        - codec_padding_size (int, default=10): The amount of padding added on each side of the codec chunk.
        - penalty_window_size (int, default=-1): The window size for applying penalties during decoding.
        - penalty (float, default=1.1): The penalty factor.
        - N (int, default=2401): Sliding-window size (in samples) used to locate quiet split points.
        - seg_threshold (float, default=0.01): Quietness threshold passed to find_min_sum_index.

        Yields:
        - torch.Tensor: Intermediate audio segments generated by the codec model.
        """
        codec_upsample_rate = 600  # output samples per codec token
        left_padding = 0
        right_padding = codec_padding_size
        prefix = None  # the prefix argument is ignored in this version
        buffer = torch.zeros([1, 1, 0]).to(hidden.device)
        with torch.no_grad():
            with torch.autocast(device_type="cuda",
                                dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32):
                print("Starting TTS...")
                token = torch.full((1, 0), self.model.vocab_size, dtype=torch.long, device=hidden.device)
                for next_token_id in self.infer(hidden, top_k, prefix, penalty_window_size, penalty):
                    token = torch.cat([token, next_token_id], dim=-1)
                    # Once a full chunk (plus padding) of tokens is collected, vocode it
                    if token.size(1) == left_padding + codec_chunk_size + right_padding:
                        syn = self.codec_model.vqvae(token.unsqueeze(-1),
                                                     torch.tensor(self.codec_model.vqvae.h.global_tokens,
                                                                  device=token.device).unsqueeze(0).unsqueeze(0))
                        print("Codec Done")
                        # Trim the padded regions, then carry the overlap into the next chunk
                        syn = syn[:, :, left_padding * codec_upsample_rate: -right_padding * codec_upsample_rate]
                        left_padding = codec_padding_size
                        token = token[:, -(left_padding + right_padding):]
                        buffer, syn = self.find_min_sum_index(buffer, syn, N, seg_threshold)
                        if syn is not None:
                            yield syn
                # Flush whatever tokens remain after the last full chunk
                if token.size(1) > 0:
                    syn = self.codec_model.vqvae(token.unsqueeze(-1),
                                                 torch.tensor(self.codec_model.vqvae.h.global_tokens,
                                                              device=token.device).unsqueeze(0).unsqueeze(0))
                    print("Codec Done")
                    syn = syn[:, :, left_padding * codec_upsample_rate:]
                    yield torch.cat([buffer, syn], dim=-1)
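

def _demo_stream_tts(model_path, hidden):
    """A minimal usage sketch (not called by the app) of the streaming decoder.
    `model_path` and `hidden` are assumed inputs: `hidden` would come from the
    language model, as `run` expects, and the 24 kHz rate matches the codec above."""
    tts = llm2TTS(model_path)
    pieces = []
    for seg in tts.run(hidden, top_k=2, prefix=None):
        # each yielded segment is a [1, 1, samples] float tensor at 24 kHz
        pieces.append(seg.squeeze().float().cpu().numpy())
    sf.write("demo_tts.wav", np.concatenate(pieces), 24000)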