AlexWortega committed
Commit 473ca63
1 Parent(s): ffd4f38

Update main.py

Files changed (1)
main.py +43 -25
main.py CHANGED
@@ -24,39 +24,24 @@ top_k = 20
 from safetensors.torch import load_file
 
 def convert_to_16_bit_wav(data):
-    # Based on: https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.write.html
-    # breakpoint()
     if data.dtype == np.float32:
-        # warnings.warn(
-        #     "Audio data is not in 16-bit integer format."
-        #     "Trying to convert to 16-bit int format."
-        # )
         data = data / np.abs(data).max()
         data = data * 32767
         data = data.astype(np.int16)
     elif data.dtype == np.int32:
-        # warnings.warn(
-        #     "Audio data is not in 16-bit integer format."
-        #     "Trying to convert to 16-bit int format."
-        # )
         data = data / 65538
         data = data.astype(np.int16)
     elif data.dtype == np.int16:
         pass
     elif data.dtype == np.uint8:
-        # warnings.warn(
-        #     "Audio data is not in 16-bit integer format."
-        #     "Trying to convert to 16-bit int format."
-        # )
         data = data * 257 - 32768
         data = data.astype(np.int16)
     else:
-        raise ValueError("Audio data cannot be converted to " "16-bit int format.")
+        raise ValueError("Audio data cannot be converted to 16-bit int format.")
     return data
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-
 # Load the model with INT8 quantization
 model = AutoModelForCausalLM.from_pretrained(
     model_path,
@@ -71,7 +56,7 @@ ckpt_path = "audiotokenizer/SpeechTokenizer.pt"
 quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
 quantizer.eval()
 
-# Move all quantizer layers to the device and freeze them
+# Freeze layers in the quantizer
 def freeze_entire_model(model):
     for n, p in model.named_parameters():
         p.requires_grad = False
@@ -81,7 +66,7 @@ for n, child in quantizer.named_children():
     child.to(device)
     child = freeze_entire_model(child)
 
-# Function for creating padding tokens for audio
+# Create padding tokens for audio
 def get_audio_padding_tokens(quantizer):
     audio = torch.zeros((1, 1, 1)).to(device)
     codes = quantizer.encode(audio)
@@ -89,7 +74,7 @@ def get_audio_padding_tokens(quantizer):
     torch.cuda.empty_cache()
     return {"audio_tokens": codes.squeeze(1)}
 
-# Function for decoding audio from tokens
+# Decode audio from tokens
 def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
     start = torch.nonzero(tokens == tokenizer(start_audio_token)["input_ids"][-1])
     end = torch.nonzero(tokens == tokenizer(end_audio_token)["input_ids"][-1])
@@ -112,9 +97,7 @@ def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
     return xp
 
 
-# Usage example
-
-# Inference function for text input and audio output
+# Inference functions
 def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
     text_tokenized = tokenizer(text, return_tensors="pt")
     text_input_tokens = text_tokenized["input_ids"].to(device)
@@ -132,7 +115,6 @@ def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024,
 
     return audio_signal
 
-# Inference function for audio input and text output
 def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
     audio_data, sample_rate = torchaudio.load(audio_path)
 
@@ -155,7 +137,7 @@ def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=
 
     return decoded_text
 
-# Functions for inference
+# Functions for Gradio Interface
 def infer_text_to_audio_gr(text):
    audio_signal = infer_text_to_audio(text.strip().upper(), model, tokenizer, quantizer)
    return audio_signal
@@ -183,6 +165,42 @@ audio_to_text_interface = gr.Interface(
     allow_flagging='never'
 )
 
-# Launch Gradio App
+# Gradio Demo
 demo = gr.TabbedInterface([text_to_audio_interface, audio_to_text_interface], ["Text - Audio", "Audio - Text"])
+
+# Custom CSS for centered links
+custom_css = """
+<style>
+.center {
+    text-align: center;
+}
+</style>
+"""
+
+# Add Gradio description with centered links
+description = f"""
+# **Salt: Speech And Language Transformer**
+
+Welcome to the demo of **Salt**, a speech and language model. Vikhr Salt is capable of both **Text-to-Speech (T2S)** and **Speech-to-Text (S2T)** tasks, making it a versatile tool for transforming language into speech and vice versa. Built on a pre-trained large language model, Vikhr Salt incorporates audio tokens using cutting-edge techniques like **Encodec** and **SpeechTokenizer**, enabling robust performance across multiple modalities.
+
+## **🛠 Features**
+- **Text-to-Speech (T2S)**: Enter text and generate high-quality audio outputs.
+- **Speech-to-Text (S2T)**: Upload an audio file and convert it into accurate text.
+
+## **🚀 Try it out:**
+Explore the tabs to try the **Text - Audio** and **Audio - Text** modes!
+
+---
+
+<div class="center">
+### **📄 Preprint**
+[Read the paper](https://docs.google.com/document/d/1ZvV47W4BCyZM_JfDC1BKj-0ozwPck5t2yNB8jORVshI/edit?usp=sharing)
+
+### **📂 Code**
+[Explore the code](https://github.com/VikhrModels/Vikhr4o)
+</div>
+
+"""
+
+# Launch Gradio App
 demo.launch(share=True)
 