import gradio as gr
import torch
import torchaudio
from transformers import AutoTokenizer, AutoModelForCausalLM
from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal
import bitsandbytes as bnb  # only needed if 8-bit loading is enabled below
import numpy as np
from uuid import uuid4

# Load the tokenizer for the SALT checkpoint
model_path = "Vikhrmodels/salt-116k"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Special tokens marking the audio span inside a sequence
start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"

# Constants
n_codebooks = 3
max_seq_length = 1024
top_k = 20
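
# Token layout implied by the encode/decode logic below: the SpeechTokenizer
# codebook entries occupy the last 1024 ids of the vocabulary, so adding
# `len(tokenizer) - 1024` maps a raw codebook index into vocabulary space and
# `% (len(tokenizer) - 1024)` maps a vocabulary id back to a codebook index.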


def convert_to_16_bit_wav(data):
    # Convert a numpy audio buffer of various dtypes to 16-bit PCM
    if data.dtype == np.float32:
        peak = np.abs(data).max()
        if peak > 0:  # avoid dividing by zero on silent input
            data = data / peak
        data = (data * 32767).astype(np.int16)
    elif data.dtype == np.int32:
        data = (data / 65536).astype(np.int16)  # 2**16: int32 -> int16 range
    elif data.dtype == np.int16:
        pass
    elif data.dtype == np.uint8:
        data = (data * 257 - 32768).astype(np.int16)  # 0..255 -> -32768..32767
    else:
        raise ValueError("Audio data cannot be converted to 16-bit int format.")
    return data
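
# Hypothetical usage: gr.Audio also accepts (sample_rate, int16 ndarray) tuples,
# so a float waveform (here `waveform`, illustrative) could be prepared with:
#   pcm16 = convert_to_16_bit_wav(waveform.astype(np.float32))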

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the language model (8-bit quantization is currently disabled; see below)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    cache_dir=".",
    load_in_8bit=False,  # set to True to load the weights in INT8 via bitsandbytes
    device_map="auto"    # automatically place the model on available devices
)
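
# A minimal sketch of actually enabling INT8 loading (assumes a CUDA device and
# the bitsandbytes package; the bare `load_in_8bit` kwarg is deprecated in newer
# transformers releases in favor of an explicit quantization config):
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       model_path,
#       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#       device_map="auto",
#   )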

# Configuration and checkpoint paths for the SpeechTokenizer
config_path = "audiotokenizer/speechtokenizer_hubert_avg_config.json"
ckpt_path = "audiotokenizer/SpeechTokenizer.pt"
quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
quantizer.eval()

# Freeze all quantizer parameters: the speech tokenizer is inference-only here
def freeze_entire_model(model):
    for p in model.parameters():
        p.requires_grad = False
    return model

for _, child in quantizer.named_children():
    child.to(device)
    freeze_entire_model(child)

# Create padding tokens for audio by encoding a single sample of silence;
# this yields one code per codebook to pad partial frames with
def get_audio_padding_tokens(quantizer):
    audio = torch.zeros((1, 1, 1)).to(device)
    codes = quantizer.encode(audio)  # shape: (n_q, batch, frames)
    del audio
    torch.cuda.empty_cache()
    return {"audio_tokens": codes.squeeze(1)}

# Decode audio from tokens: extract the span between <soa> and <eoa>,
# map the ids back to codebook indices, and run the quantizer decoder
def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
    start = torch.nonzero(tokens == tokenizer(start_audio_token)["input_ids"][-1])
    end = torch.nonzero(tokens == tokenizer(end_audio_token)["input_ids"][-1])
    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # Map vocabulary ids back into codebook index space
    audio_tokens = tokens[start:end] % n_original_tokens
    remainder = audio_tokens.shape[-1] % n_codebooks

    if remainder:
        # Pad the last partial frame so the length is a multiple of n_codebooks
        audio_tokens = torch.cat([audio_tokens, pad_tokens[remainder:n_codebooks]], dim=0)

    # Interleaved flat stream -> (n_codebooks, 1, frames) expected by the quantizer
    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)
    torch.cuda.empty_cache()
    out_path = str(uuid4()) + ".wav"
    AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate).write(out_path)
    return out_path
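
# Worked example of the interleaving above: with n_codebooks == 3, a flat
# stream [a0, b0, c0, a1, b1, c1] reshapes via view(-1, 3).t() into codebook
# rows [[a0, a1], [b0, b1], [c0, c1]] before decoding.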


# Inference functions
def infer_text_to_audio(text):
    text_tokenized = tokenizer(str(text), return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)

    # Append <soa> so the model continues the prompt with audio tokens
    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    output_audio_tokens = model.generate(
        text_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
    )

    padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"].to(device)
    audio_signal = decode_audio(output_audio_tokens[0], quantizer, padding_tokens.t()[0], len(tokenizer) - 1024)

    return audio_signal

def infer_audio_to_text(audio_path):
    audio_data, sample_rate = torchaudio.load(audio_path)

    audio = audio_data.view(1, 1, -1).float().to(device)
    codes = quantizer.encode(audio)  # shape: (n_q, batch, frames)
    n_codebooks_a = 1
    # Keep only the first codebook and shift its indices into vocabulary space
    raw_audio_tokens = codes[:n_codebooks_a] + len(tokenizer) - 1024

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    audio_tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)

    attention_mask = torch.ones(audio_tokens.size(), device=device)

    output_text_tokens = model.generate(
        audio_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
    )

    # Keep only ids below the audio-token range, then decode to text
    output_text_tokens = output_text_tokens.cpu()[0]
    output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)

    return decoded_text
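
# Hypothetical usage (paths and strings are illustrative):
#   wav_path = infer_text_to_audio("HELLO, WORLD")  # returns a path to a .wav file
#   answer = infer_audio_to_text("question.wav")    # returns the decoded text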

# Thin wrappers for the Gradio interface (the underlying functions take a
# single argument; model, tokenizer, and quantizer are module-level globals)
def infer_text_to_audio_gr(text):
    return infer_text_to_audio(text.strip().upper())

def infer_audio_to_text_gr(audio_path):
    return infer_audio_to_text(audio_path)

# Gradio Interface
text_to_audio_interface = gr.Interface(
    fn=infer_text_to_audio_gr,
    inputs=gr.Textbox(label="Input Text"),
    outputs=gr.Audio(label="Audio Answer"),
    title="T2S",
    description="Model in text to audio mode",
    allow_flagging='never',
)

audio_to_text_interface = gr.Interface(
    fn=infer_audio_to_text_gr,
    inputs=gr.Audio(type="filepath", label="Input Audio"),
    outputs=gr.Textbox(label="Text Answer"),
    title="S2T",
    description="Model in audio to text mode",
    allow_flagging='never'
)
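
# Note: the two Interface objects above are standalone definitions; the app
# launched at the bottom of the file is the gr.Blocks demo.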

# Gradio Demo
#demo = gr.TabbedInterface([text_to_audio_interface, audio_to_text_interface], ["Text - Audio", "Audio - Text"])

# Custom CSS for centered links (passed to gr.Blocks below)
custom_css = """
.center {
    text-align: center;
}
"""

# Gradio description with centered links
description = """
# **Salt: Speech And Language Transformer**

Welcome to the demo of **Vikhr Salt**, a speech and language model that handles both **Text-to-Speech (T2S)** and **Speech-to-Text (S2T)** tasks, making it a versatile tool for turning text into speech and speech into text. Built on a pre-trained large language model, Vikhr Salt represents audio as discrete tokens produced by neural codecs such as **EnCodec** and **SpeechTokenizer**, enabling robust performance across both modalities.

## **🛠 Features**
- **Text-to-Speech (T2S)**: Enter text and generate high-quality audio outputs.
- **Speech-to-Text (S2T)**: Upload an audio file and convert it into accurate text.

## **🚀 Try it out:**
Explore the tabs to try the **Text - Audio** and **Audio - Text** modes!


### **📄 Preprint**  
[Read the paper](https://docs.google.com/document/d/1ZvV47W4BCyZM_JfDC1BKj-0ozwPck5t2yNB8jORVshI/edit?usp=sharing)  

### **📂 Code**  
[Explore the code](https://github.com/VikhrModels/Vikhr4o)  


"""
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown(description)
    with gr.Tabs():
        with gr.TabItem("Text - Audio"):
            gr.Markdown("### Text-to-Speech (T2S) Mode")
            input_text = gr.Textbox(label="Input Text")
            output_audio = gr.Audio(label="Audio Answer")
            generate_button = gr.Button("Generate")
            generate_button.click(infer_text_to_audio_gr, inputs=input_text, outputs=output_audio)
        with gr.TabItem("Audio - Text"):
            gr.Markdown("### Speech-to-Text (S2T) Mode")
            input_audio = gr.Audio(type="filepath", label="Input Audio")
            output_text = gr.Textbox(label="Text Answer")
            generate_button = gr.Button("Generate")
            generate_button.click(infer_audio_to_text_gr, inputs=input_audio, outputs=output_text)

# Launch the demo
demo.launch(share=True)