File size: 7,416 Bytes
bc752b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import sys
import copy
import json
import torch
import random
import argparse
import subprocess
import numpy as np
import soundfile as sf
import subprocess
import concurrent.futures

from vita.model.vita_tts.decoder.decoder import LLM2TTSCodecAR
from vita.model.vita_tts.decoder.ticodec.vqvae_tester import VqvaeTester

class llm2TTS():
    def __init__(self, model_path):
        """
        Build the speech decoder and the codec model from a model directory.

        Parameters:
        - model_path (str): Directory containing the "decoder" and "codec"
          sub-folders read here; it is joined to file names with "+", so it
          must be a string.

        """
        # Decoder: move to the GPU first, then cast to bfloat16 when the device
        # supports it (float32 otherwise).
        decoder = self.get_model(model_path).cuda()
        if torch.cuda.is_bf16_supported():
            decoder_dtype = torch.bfloat16
        else:
            decoder_dtype = torch.float32
        self.model = decoder.to(decoder_dtype)
        # Shortcut so callers can invoke the decoder's inference entry point directly.
        self.infer = self.model.infer

        # Codec: load, move to the GPU, strip weight normalisation from the
        # generator and encoder, and switch to evaluation mode.
        codec = VqvaeTester(
            config_path=model_path + "/codec/model.json",
            model_path=model_path + "/codec/final.pt",
            sample_rate=24000,
        )
        self.codec_model = codec.cuda()
        vqvae = self.codec_model.vqvae
        vqvae.generator.remove_weight_norm()
        vqvae.encoder.remove_weight_norm()
        self.codec_model.eval()

    def get_model_conf(self, model_path):
        model_conf = model_path + "/decoder/model.json"
        with open(model_conf, "rb") as f:
            print('reading a config file from ' + model_conf)
            confs = json.load(f)
        # for asr, tts, mt
        idim, odim, args = confs
        return argparse.Namespace(**args)

    def get_model(self, model_path):
        """
        Instantiate the decoder and load every snapshot weight that fits.

        Parameters:
        - model_path (str): Directory containing "decoder/model.json" and
          "decoder/final.pt".

        Returns:
        - LLM2TTSCodecAR: The decoder in evaluation mode. Parameters that are
          missing from the snapshot, or whose shape differs, keep their freshly
          initialised values and are reported on stdout.

        """
        # Copy the configuration into a fresh namespace before handing it to the model.
        conf = vars(self.get_model_conf(model_path))
        args = argparse.Namespace(**conf)
        odim = args.odim
        idim = args.idim
        model = LLM2TTSCodecAR(idim, odim, args)

        # Resume from a snapshot
        checkpoint = torch.load(model_path + "/decoder/final.pt", map_location=lambda storage, loc: storage)
        # The weights are either stored under a 'model' entry or are the snapshot itself.
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint

        # Update the model's own state dict in place so the object passed to
        # load_state_dict is the one produced by state_dict().
        state = model.state_dict()
        for name in state:
            if name not in weights:
                print('Key {} has not in resume model'.format(name))
                continue
            if state[name].shape != weights[name].shape:
                print('Key {} has different shape, {} VS {}'.format(name, state[name].shape, 
                                                                    weights[name].shape))
                continue
            state[name] = weights[name]
        model.load_state_dict(state)
        model.eval()
        return model
    
    def find_min_sum_index(self, buffer, syn, N, threshold):
        """
        Find the index with the minimum sum of a sliding window in the given audio segment 
        and perform operations based on this index.

        Parameters:
        - buffer (torch.Tensor): The buffer containing previously processed audio segments.
        - syn (torch.Tensor): The current audio segment to be processed.
        - N (int): The size of the sliding window used to calculate the sum.
        - threshold (float): Threshold value to determine whether to concatenate buffer and current segment or not.

        Returns:
        - tuple: A tuple containing the updated buffer and the processed audio segment.

        """
        arr = syn[0, 0, :]
        L = len(arr)
        mid = L // 2
        
        kernel = torch.ones(N).to(arr.device)
        window_sums = torch.nn.functional.conv1d(arr.abs().view(1, 1, -1), kernel.view(1, 1, -1), padding=0).squeeze()
        
        start_index = mid - (N // 2)
        min_sum, min_index = torch.min(window_sums[start_index:], dim=0)

        # get the start and end index of the window
        start_index = max(0, min_index.item() + start_index)
        end_index = min(L, min_index.item() + N + start_index)
        
        # calculate the real min_sum and min_index
        min_sum_real, min_index_real = torch.min(arr[start_index: end_index].abs(), dim=0)
        min_index = min_index_real.item() + start_index

        min_sum = min_sum / N
        syn_clone = syn.clone()

        if min_sum < threshold:
            syn = torch.cat([buffer.clone(), syn[:, :, :min_index]], dim=-1)
            buffer = syn_clone[:, :, min_index:]
        else:
            buffer = torch.cat([buffer, syn_clone], dim=-1)
            syn = None
        return buffer, syn

    def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10, 
            penalty_window_size=-1, penalty=1.1, N=2401, seg_threshold=0.01):
        """
        Run the speech decoder process.

        Parameters:
        - hidden (torch.Tensor): The output for embedding layer of the language model.
        - top_k (int): The number of top-k tokens to consider during inference.
        - prefix (str, optional): The hidden state from the language model.
        - codec_chunk_size (int, default=40): The size of each chunk to process in the codec model.
        - codec_padding_size (int, default=10): The amount of padding to add on each side of the codec chunk.
        - penalty_window_size (int, default=20): The window size for applying penalties during decoding.
        - penalty (float, default=1.1): The penalty factor.

        Yields:
        - torch.Tensor: Intermediate audio segments generated by the codec model.

        """
        codec_upsample_rate = 600
        left_padding = 0
        right_padding = codec_padding_size
        prefix = None
        buffer = torch.zeros([1, 1, 0]).to(hidden.device)
        with torch.no_grad():
            with torch.autocast(device_type="cuda", 
                    dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32):
                print("Starting TTS...")
                token = torch.full((1, 0), self.model.vocab_size, dtype=torch.long, device=hidden.device)
                for next_token_id in self.infer(hidden, top_k, prefix, penalty_window_size, penalty):
                    token = torch.cat([token, next_token_id], dim=-1)
                    if token.size(1) == left_padding + codec_chunk_size + right_padding:
                        syn = self.codec_model.vqvae(token.unsqueeze(-1), 
                                                     torch.tensor(self.codec_model.vqvae.h.global_tokens, 
                                                     device=token.device).unsqueeze(0).unsqueeze(0))
                        print("Codec Done")
                        syn = syn[:, :, left_padding * codec_upsample_rate: -right_padding * codec_upsample_rate]
                        left_padding = codec_padding_size
                        token = token[:, -(left_padding + right_padding):]
                        buffer, syn = self.find_min_sum_index(buffer, syn, N, seg_threshold)
                        if syn is not None:
                            yield syn
                if token.size(1) > 0:
                    print("Codec Done")
                    syn = self.codec_model.vqvae(token.unsqueeze(-1), 
                                                 torch.tensor(self.codec_model.vqvae.h.global_tokens, 
                                                 device=token.device).unsqueeze(0).unsqueeze(0))
                    syn = syn[:, :, left_padding * codec_upsample_rate:]
                    yield torch.cat([buffer, syn], dim=-1)