gpt-omni committed
Commit
eb83dcd
1 Parent(s): 69a5822
Files changed (1)
  1. app.py +270 -5
app.py CHANGED
@@ -7,18 +7,283 @@ import numpy as np
 import spaces
 import torch

-
- from inference import OmniInference
+ import os
+ import lightning as L
+ import torch
+ import time
+ import spaces
+ from snac import SNAC
+ from litgpt import Tokenizer
+ from litgpt.utils import (
+     num_parameters,
+ )
+ from litgpt.generate.base import (
+     generate_AA,
+     generate_ASR,
+     generate_TA,
+     generate_TT,
+     generate_AT,
+     generate_TA_BATCH,
+ )
+ from typing import Any, Literal, Optional
+ import soundfile as sf
+ from litgpt.model import GPT, Config
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
+ from utils.snac_utils import get_snac, generate_audio_data
+ import whisper
+ from tqdm import tqdm
+ from huggingface_hub import snapshot_download
+ from litgpt.generate.base import sample


 device = "cuda" if torch.cuda.is_available() else "cpu"
- omni_client = OmniInference('./checkpoint', device)
- omni_client.warm_up()
+ ckpt_dir = "./checkpoint"
+

 OUT_CHUNK = 4096
 OUT_RATE = 24000
 OUT_CHANNELS = 1

+ # TODO
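+ # Token-id layout (inferred from the offsets used below): text ids occupy
+ # [0, padded_text_vocabsize); each of the 7 SNAC codebook layers is then
+ # shifted into the shared vocabulary by padded_text_vocabsize + i * padded_audio_vocabsize
+ # (the layershift() helper), with per-modality EOS/PAD/INPUT/ANSWER specials at the end.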
+ text_vocabsize = 151936
+ text_specialtokens = 64
+ audio_vocabsize = 4096
+ audio_specialtokens = 64
+
+ padded_text_vocabsize = text_vocabsize + text_specialtokens
+ padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
+
+ _eot = text_vocabsize
+ _pad_t = text_vocabsize + 1
+ _input_t = text_vocabsize + 2
+ _answer_t = text_vocabsize + 3
+ _asr = text_vocabsize + 4
+
+ _eoa = audio_vocabsize
+ _pad_a = audio_vocabsize + 1
+ _input_a = audio_vocabsize + 2
+ _answer_a = audio_vocabsize + 3
+ _split = audio_vocabsize + 4
+
+
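+ # Fetch the mini-omni checkpoint on first run, then load the SNAC vocoder,
+ # the Whisper encoder, the text tokenizer and the GPT backbone once at import time.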
+ def download_model(ckpt_dir):
+     repo_id = "gpt-omni/mini-omni"
+     snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
+
+
+ if not os.path.exists(ckpt_dir):
+     print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
+     download_model(ckpt_dir)
+
+
+ snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
+ whispermodel = whisper.load_model("small").to(device)
+ text_tokenizer = Tokenizer(ckpt_dir)
+ fabric = L.Fabric(devices=1, strategy="auto")
+ config = Config.from_file(ckpt_dir + "/model_config.yaml")
+ config.post_adapter = False
+
+ model = GPT(config, device=device)
+
+ # model = fabric.setup(model)
+ state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
+ model.load_state_dict(state_dict, strict=True)
+ model = model.to(device)
+ model.eval()
+
+
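+ # Build a batch of two prompts over the same Whisper features: one laid out for
+ # audio+audio-token (AA) decoding and one for audio-to-text (AT). Each of the 7
+ # audio layers gets its own layer-shifted input/pad/eoa/answer tokens; the 8th
+ # row is the text stream.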
+ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
+     with torch.no_grad():
+         mel = mel.unsqueeze(0).to(device)
+         # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
+         audio_feature = whispermodel.embed_audio(mel)[0][:leng]
+     T = audio_feature.size(0)
+     input_ids_AA = []
+     for i in range(7):
+         input_ids_item = []
+         input_ids_item.append(layershift(_input_a, i))
+         input_ids_item += [layershift(_pad_a, i)] * T
+         input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
+         input_ids_AA.append(torch.tensor(input_ids_item))
+     input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
+     input_ids_AA.append(input_id_T)
+
+     input_ids_AT = []
+     for i in range(7):
+         input_ids_item = []
+         input_ids_item.append(layershift(_input_a, i))
+         input_ids_item += [layershift(_pad_a, i)] * T
+         input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
+         input_ids_AT.append(torch.tensor(input_ids_item))
+     input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
+     input_ids_AT.append(input_id_T)
+
+     input_ids = [input_ids_AA, input_ids_AT]
+     stacked_inputids = [[] for _ in range(8)]
+     for i in range(2):
+         for j in range(8):
+             stacked_inputids[j].append(input_ids[i][j])
+     stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
+     return torch.stack([audio_feature, audio_feature]), stacked_inputids
+
+
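+ # One decoding step for the 2-element batch: audio logits are read from batch
+ # row 0 and the text logit from row 1, then sampled with litgpt's sample().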
+ def next_token_batch(
+     model: GPT,
+     audio_features: torch.Tensor,
+     input_ids: list,
+     whisper_lens: int,
+     task: list,
+     input_pos: torch.Tensor,
+     **kwargs: Any,
+ ) -> torch.Tensor:
+     input_pos = input_pos.to(model.device)
+     input_ids = [input_id.to(model.device) for input_id in input_ids]
+     logits_a, logit_t = model(
+         audio_features, input_ids, input_pos, whisper_lens=whisper_lens, task=task
+     )
+
+     for i in range(7):
+         logits_a[i] = logits_a[i][0].unsqueeze(0)
+     logit_t = logit_t[1].unsqueeze(0)
+
+     next_audio_tokens = []
+     for logit_a in logits_a:
+         next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
+         next_audio_tokens.append(next_a)
+     next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
+     return next_audio_tokens, next_t
+
+
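+ # Load and pad the audio to Whisper's 30 s window; the returned length is the
+ # number of encoder frames covered by the input (roughly one frame per 20 ms).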
+ def load_audio(path):
+     audio = whisper.load_audio(path)
+     duration_ms = (len(audio) / 16000) * 1000
+     audio = whisper.pad_or_trim(audio)
+     mel = whisper.log_mel_spectrogram(audio)
+     return mel, int(duration_ms / 20) + 1
+
+
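+ # Streaming inference: feed the Whisper features once, then decode token by token,
+ # yielding PCM chunks as soon as enough SNAC codes have accumulated.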
+ # @torch.inference_mode()
+ @spaces.GPU
+ def run_AT_batch_stream(
+     audio_path,
+     stream_stride=4,
+     max_returned_tokens=2048,
+     temperature=0.9,
+     top_k=1,
+     top_p=1.0,
+     eos_id_a=_eoa,
+     eos_id_t=_eot,
+ ):
+
+     assert os.path.exists(audio_path), f"audio file {audio_path} not found"
+
+     # with fabric.init_tensor():
+     # KV cache for the 2-row batch (row 0 is read for audio tokens, row 1 for text)
+     model.set_kv_cache(batch_size=2)
+
+     mel, leng = load_audio(audio_path)
+     audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
+     T = input_ids[0].size(1)
+
+     assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
+
+     if model.max_seq_length < max_returned_tokens - 1:
+         raise NotImplementedError(
+             f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
+         )
+
+     input_pos = torch.tensor([T], device=device)
+     list_output = [[] for i in range(8)]
+     tokens_A, token_T = next_token_batch(
+         model,
+         audio_feature.to(torch.float32).to(model.device),
+         input_ids,
+         [T - 3, T - 3],
+         ["A1T2", "A1T2"],
+         input_pos=torch.arange(0, T, device=device),
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p,
+     )
+
+     for i in range(7):
+         list_output[i].append(tokens_A[i].tolist()[0])
+     list_output[7].append(token_T.tolist()[0])
+
+     model_input_ids = [[] for i in range(8)]
+     for i in range(7):
+         tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
+         model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
+         model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
+         model_input_ids[i] = torch.stack(model_input_ids[i])
+
+     model_input_ids[-1].append(token_T.clone().to(torch.int32))
+     model_input_ids[-1].append(token_T.clone().to(torch.int32))
+     model_input_ids[-1] = torch.stack(model_input_ids[-1])
+
+     text_end = False
+     index = 1
+     nums_generate = stream_stride
+     begin_generate = False
+     current_index = 0
+     for _ in tqdm(range(2, max_returned_tokens - T + 1)):
+         tokens_A, token_T = next_token_batch(
+             model,
+             None,
+             model_input_ids,
+             None,
+             None,
+             input_pos=input_pos,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+         )
+
+         if text_end:
+             token_T = torch.tensor([_pad_t], device=device)
+
+         if tokens_A[-1] == eos_id_a:
+             break
+
+         if token_T == eos_id_t:
+             text_end = True
+
+         for i in range(7):
+             list_output[i].append(tokens_A[i].tolist()[0])
+         list_output[7].append(token_T.tolist()[0])
+
+         model_input_ids = [[] for i in range(8)]
+         for i in range(7):
+             tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
+             model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
+             model_input_ids[i].append(
+                 torch.tensor([layershift(4097, i)], device=device)
+             )
+             model_input_ids[i] = torch.stack(model_input_ids[i])
+
+         model_input_ids[-1].append(token_T.clone().to(torch.int32))
+         model_input_ids[-1].append(token_T.clone().to(torch.int32))
+         model_input_ids[-1] = torch.stack(model_input_ids[-1])
+
+         # wait until all 7 SNAC layers have produced usable codes before streaming
+         if index == 7:
+             begin_generate = True
+
+         # every nums_generate steps, decode the newly accumulated SNAC codes and yield a PCM chunk
+         if begin_generate:
+             current_index += 1
+             if current_index == nums_generate:
+                 current_index = 0
+                 snac = get_snac(list_output, index, nums_generate)
+                 audio_stream = generate_audio_data(snac, snacmodel, device)
+                 yield audio_stream
+
+         input_pos = input_pos.add_(1)
+         index += 1
+     text = text_tokenizer.decode(torch.tensor(list_output[-1]))
+     print(f"text output: {text}")
+     model.clear_kv_cache()
+     return list_output
+

 def process_audio(audio):
     filepath = audio
@@ -28,7 +293,7 @@ def process_audio(audio):

     cnt = 0
     tik = time.time()
-     for chunk in omni_client.run_AT_batch_stream(filepath):
+     for chunk in run_AT_batch_stream(filepath):
         # Convert chunk to numpy array
         if cnt == 0:
             print(f"first chunk time cost: {time.time() - tik:.3f}")