gpt-omni commited on
Commit
5e4b316
1 Parent(s): 8fc1cf4
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A simple web interactive chat demo based on gradio."""
2
+
3
+ import os
4
+ import time
5
+ import gradio as gr
6
+ import numpy as np
7
+ import spaces
8
+ import torch
9
+
10
+
11
+ from inference import OmniInference
12
+
13
+
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ omni_client = OmniInference('./checkpoint', device)
16
+ omni_client.warm_up()
17
+
18
+
19
+ OUT_CHUNK = 4096
20
+ OUT_RATE = 24000
21
+ OUT_CHANNELS = 1
22
+
23
+
24
+ @spaces.GPU
25
+ def process_audio(audio):
26
+ filepath = audio
27
+ print(f"filepath: {filepath}")
28
+ if filepath is None:
29
+ return
30
+
31
+ cnt = 0
32
+ tik = time.time()
33
+ for chunk in omni_client.run_AT_batch_stream(filepath):
34
+ # Convert chunk to numpy array
35
+ if cnt == 0:
36
+ print(f"first chunk time cost: {time.time() - tik:.3f}")
37
+ cnt += 1
38
+ audio_data = np.frombuffer(chunk, dtype=np.int16)
39
+ audio_data = audio_data.reshape(-1, OUT_CHANNELS)
40
+ yield OUT_RATE, audio_data.astype(np.int16)
41
+
42
+
43
+ demo = gr.Interface(
44
+ process_audio,
45
+ inputs=gr.Audio(type="filepath", label="Microphone"),
46
+ outputs=[gr.Audio(label="Response", streaming=True, autoplay=True)],
47
+ title="Chat Mini-Omni Demo",
48
+ live=True,
49
+ )
50
+ demo.queue().launch()
inference.py ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import lightning as L
3
+ import torch
4
+ import time
5
+ from snac import SNAC
6
+ from litgpt import Tokenizer
7
+ from litgpt.utils import (
8
+ num_parameters,
9
+ )
10
+ from litgpt.generate.base import (
11
+ generate_AA,
12
+ generate_ASR,
13
+ generate_TA,
14
+ generate_TT,
15
+ generate_AT,
16
+ generate_TA_BATCH,
17
+ next_token_batch
18
+ )
19
+ import soundfile as sf
20
+ from litgpt.model import GPT, Config
21
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
22
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
23
+ from utils.snac_utils import get_snac, generate_audio_data
24
+ import whisper
25
+ from tqdm import tqdm
26
+ from huggingface_hub import snapshot_download
27
+
28
+
29
+ torch.set_printoptions(sci_mode=False)
30
+
31
+
32
+ # TODO
33
+ text_vocabsize = 151936
34
+ text_specialtokens = 64
35
+ audio_vocabsize = 4096
36
+ audio_specialtokens = 64
37
+
38
+ padded_text_vocabsize = text_vocabsize + text_specialtokens
39
+ padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
40
+
41
+ _eot = text_vocabsize
42
+ _pad_t = text_vocabsize + 1
43
+ _input_t = text_vocabsize + 2
44
+ _answer_t = text_vocabsize + 3
45
+ _asr = text_vocabsize + 4
46
+
47
+ _eoa = audio_vocabsize
48
+ _pad_a = audio_vocabsize + 1
49
+ _input_a = audio_vocabsize + 2
50
+ _answer_a = audio_vocabsize + 3
51
+ _split = audio_vocabsize + 4
52
+
53
+
54
+ def get_input_ids_TA(text, text_tokenizer):
55
+ input_ids_item = [[] for _ in range(8)]
56
+ text_tokens = text_tokenizer.encode(text)
57
+ for i in range(7):
58
+ input_ids_item[i] = [layershift(_pad_a, i)] * (len(text_tokens) + 2) + [
59
+ layershift(_answer_a, i)
60
+ ]
61
+ input_ids_item[i] = torch.tensor(input_ids_item[i]).unsqueeze(0)
62
+ input_ids_item[-1] = [_input_t] + text_tokens.tolist() + [_eot] + [_answer_t]
63
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
64
+ return input_ids_item
65
+
66
+
67
+ def get_input_ids_TT(text, text_tokenizer):
68
+ input_ids_item = [[] for i in range(8)]
69
+ text_tokens = text_tokenizer.encode(text).tolist()
70
+
71
+ for i in range(7):
72
+ input_ids_item[i] = torch.tensor(
73
+ [layershift(_pad_a, i)] * (len(text_tokens) + 3)
74
+ ).unsqueeze(0)
75
+ input_ids_item[-1] = [_input_t] + text_tokens + [_eot] + [_answer_t]
76
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
77
+
78
+ return input_ids_item
79
+
80
+
81
+ def get_input_ids_whisper(
82
+ mel, leng, whispermodel, device,
83
+ special_token_a=_answer_a, special_token_t=_answer_t,
84
+ ):
85
+
86
+ with torch.no_grad():
87
+ mel = mel.unsqueeze(0).to(device)
88
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
89
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
90
+
91
+ T = audio_feature.size(0)
92
+ input_ids = []
93
+ for i in range(7):
94
+ input_ids_item = []
95
+ input_ids_item.append(layershift(_input_a, i))
96
+ input_ids_item += [layershift(_pad_a, i)] * T
97
+ input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
98
+ input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
99
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
100
+ input_ids.append(input_id_T.unsqueeze(0))
101
+ return audio_feature.unsqueeze(0), input_ids
102
+
103
+
104
+ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
105
+ with torch.no_grad():
106
+ mel = mel.unsqueeze(0).to(device)
107
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
108
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
109
+ T = audio_feature.size(0)
110
+ input_ids_AA = []
111
+ for i in range(7):
112
+ input_ids_item = []
113
+ input_ids_item.append(layershift(_input_a, i))
114
+ input_ids_item += [layershift(_pad_a, i)] * T
115
+ input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
116
+ input_ids_AA.append(torch.tensor(input_ids_item))
117
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
118
+ input_ids_AA.append(input_id_T)
119
+
120
+ input_ids_AT = []
121
+ for i in range(7):
122
+ input_ids_item = []
123
+ input_ids_item.append(layershift(_input_a, i))
124
+ input_ids_item += [layershift(_pad_a, i)] * T
125
+ input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
126
+ input_ids_AT.append(torch.tensor(input_ids_item))
127
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
128
+ input_ids_AT.append(input_id_T)
129
+
130
+ input_ids = [input_ids_AA, input_ids_AT]
131
+ stacked_inputids = [[] for _ in range(8)]
132
+ for i in range(2):
133
+ for j in range(8):
134
+ stacked_inputids[j].append(input_ids[i][j])
135
+ stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
136
+ return torch.stack([audio_feature, audio_feature]), stacked_inputids
137
+
138
+
139
+ def load_audio(path):
140
+ audio = whisper.load_audio(path)
141
+ duration_ms = (len(audio) / 16000) * 1000
142
+ audio = whisper.pad_or_trim(audio)
143
+ mel = whisper.log_mel_spectrogram(audio)
144
+ return mel, int(duration_ms / 20) + 1
145
+
146
+
147
+ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
148
+ snacmodel, out_dir=None):
149
+ with fabric.init_tensor():
150
+ model.set_kv_cache(batch_size=2)
151
+ tokenlist = generate_TA_BATCH(
152
+ model,
153
+ audio_feature,
154
+ input_ids,
155
+ [leng, leng],
156
+ ["A1A2", "A1T2"],
157
+ max_returned_tokens=2048,
158
+ temperature=0.9,
159
+ top_k=1,
160
+ eos_id_a=_eoa,
161
+ eos_id_t=_eot,
162
+ pad_id_t=_pad_t,
163
+ shift=padded_text_vocabsize,
164
+ include_prompt=True,
165
+ generate_text=True,
166
+ )
167
+ text_tokenlist = tokenlist[-1]
168
+ if text_vocabsize in text_tokenlist:
169
+ text_tokenlist = text_tokenlist[: text_tokenlist.index(text_vocabsize)]
170
+ text = text_tokenizer.decode(torch.tensor(text_tokenlist)).strip()
171
+
172
+ audio_tokenlist = tokenlist[:-1]
173
+ audiolist = reconscruct_snac(audio_tokenlist)
174
+ audio = reconstruct_tensors(audiolist)
175
+ if out_dir is None:
176
+ out_dir = "./output/default/A1-A2-batch"
177
+ else:
178
+ out_dir = out_dir + "/A1-A2-batch"
179
+ if not os.path.exists(out_dir):
180
+ os.makedirs(out_dir)
181
+ with torch.inference_mode():
182
+ audio_hat = snacmodel.decode(audio)
183
+ sf.write(
184
+ f"{out_dir}/{step:02d}.wav",
185
+ audio_hat.squeeze().cpu().numpy(),
186
+ 24000,
187
+ )
188
+ model.clear_kv_cache()
189
+ return text
190
+
191
+
192
+ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
193
+ with fabric.init_tensor():
194
+ model.set_kv_cache(batch_size=1)
195
+ tokenlist = generate_AT(
196
+ model,
197
+ audio_feature,
198
+ input_ids,
199
+ [leng],
200
+ ["AT"],
201
+ max_returned_tokens=2048,
202
+ temperature=0.9,
203
+ top_k=1,
204
+ eos_id_a=_eoa,
205
+ eos_id_t=_eot,
206
+ pad_id_t=_pad_t,
207
+ shift=padded_text_vocabsize,
208
+ include_prompt=True,
209
+ generate_text=True,
210
+ )
211
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
212
+
213
+
214
+ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
215
+ snacmodel, out_dir=None):
216
+ with fabric.init_tensor():
217
+ model.set_kv_cache(batch_size=1)
218
+ tokenlist = generate_AA(
219
+ model,
220
+ audio_feature,
221
+ input_ids,
222
+ [leng],
223
+ ["A1T2"],
224
+ max_returned_tokens=2048,
225
+ temperature=0.9,
226
+ top_k=1,
227
+ eos_id_a=_eoa,
228
+ eos_id_t=_eot,
229
+ pad_id_t=_pad_t,
230
+ shift=padded_text_vocabsize,
231
+ include_prompt=True,
232
+ generate_text=True,
233
+ )
234
+ audiolist = reconscruct_snac(tokenlist)
235
+ tokenlist = tokenlist[-1]
236
+ if text_vocabsize in tokenlist:
237
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
238
+ if out_dir is None:
239
+ out_dir = "./output/default/A1-A2"
240
+ else:
241
+ out_dir = out_dir + "/A1-A2"
242
+ if not os.path.exists(out_dir):
243
+ os.makedirs(out_dir)
244
+
245
+ audio = reconstruct_tensors(audiolist)
246
+ with torch.inference_mode():
247
+ audio_hat = snacmodel.decode(audio)
248
+ sf.write(
249
+ f"{out_dir}/{step:02d}.wav",
250
+ audio_hat.squeeze().cpu().numpy(),
251
+ 24000,
252
+ )
253
+ model.clear_kv_cache()
254
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
255
+
256
+
257
+ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
258
+ with fabric.init_tensor():
259
+ model.set_kv_cache(batch_size=1)
260
+ tokenlist = generate_ASR(
261
+ model,
262
+ audio_feature,
263
+ input_ids,
264
+ [leng],
265
+ ["A1T1"],
266
+ max_returned_tokens=2048,
267
+ temperature=0.9,
268
+ top_k=1,
269
+ eos_id_a=_eoa,
270
+ eos_id_t=_eot,
271
+ pad_id_t=_pad_t,
272
+ shift=padded_text_vocabsize,
273
+ include_prompt=True,
274
+ generate_text=True,
275
+ )
276
+ model.clear_kv_cache()
277
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
278
+
279
+
280
+ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
281
+ snacmodel, out_dir=None):
282
+ with fabric.init_tensor():
283
+ model.set_kv_cache(batch_size=1)
284
+ tokenlist = generate_TA(
285
+ model,
286
+ None,
287
+ input_ids,
288
+ None,
289
+ ["T1A2"],
290
+ max_returned_tokens=2048,
291
+ temperature=0.9,
292
+ top_k=1,
293
+ eos_id_a=_eoa,
294
+ eos_id_t=_eot,
295
+ pad_id_t=_pad_t,
296
+ shift=padded_text_vocabsize,
297
+ include_prompt=True,
298
+ generate_text=True,
299
+ )
300
+
301
+ audiolist = reconscruct_snac(tokenlist)
302
+ tokenlist = tokenlist[-1]
303
+
304
+ if text_vocabsize in tokenlist:
305
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
306
+ audio = reconstruct_tensors(audiolist)
307
+ if out_dir is None:
308
+ out_dir = "./output/default/T1-A2"
309
+ else:
310
+ out_dir = out_dir + "/T1-A2"
311
+ if not os.path.exists(out_dir):
312
+ os.makedirs(out_dir)
313
+
314
+ with torch.inference_mode():
315
+ audio_hat = snacmodel.decode(audio)
316
+ sf.write(
317
+ f"{out_dir}/{step:02d}.wav",
318
+ audio_hat.squeeze().cpu().numpy(),
319
+ 24000,
320
+ )
321
+ model.clear_kv_cache()
322
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
323
+
324
+
325
+ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
326
+
327
+ with fabric.init_tensor():
328
+ model.set_kv_cache(batch_size=1)
329
+ tokenlist = generate_TT(
330
+ model,
331
+ None,
332
+ input_ids,
333
+ None,
334
+ ["T1T2"],
335
+ max_returned_tokens=2048,
336
+ temperature=0.9,
337
+ top_k=1,
338
+ eos_id_a=_eoa,
339
+ eos_id_t=_eot,
340
+ pad_id_t=_pad_t,
341
+ shift=padded_text_vocabsize,
342
+ include_prompt=True,
343
+ generate_text=True,
344
+ )
345
+ model.clear_kv_cache()
346
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
347
+
348
+
349
+ def load_model(ckpt_dir, device):
350
+ snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
351
+ whispermodel = whisper.load_model("small").to(device)
352
+ text_tokenizer = Tokenizer(ckpt_dir)
353
+ fabric = L.Fabric(devices=1, strategy="auto")
354
+ config = Config.from_file(ckpt_dir + "/model_config.yaml")
355
+ config.post_adapter = False
356
+
357
+ with fabric.init_module(empty_init=False):
358
+ model = GPT(config)
359
+
360
+ model = fabric.setup(model)
361
+ state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
362
+ model.load_state_dict(state_dict, strict=True)
363
+ model.to(device).eval()
364
+
365
+ return fabric, model, text_tokenizer, snacmodel, whispermodel
366
+
367
+
368
+ def download_model(ckpt_dir):
369
+ repo_id = "gpt-omni/mini-omni"
370
+ snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
371
+
372
+
373
+ class OmniInference:
374
+
375
+ def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
376
+ self.device = device
377
+ if not os.path.exists(ckpt_dir):
378
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
379
+ download_model(ckpt_dir)
380
+ self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
381
+
382
+ def warm_up(self, sample='./data/samples/output1.wav'):
383
+ for _ in self.run_AT_batch_stream(sample):
384
+ pass
385
+
386
+ @torch.inference_mode()
387
+ def run_AT_batch_stream(self,
388
+ audio_path,
389
+ stream_stride=4,
390
+ max_returned_tokens=2048,
391
+ temperature=0.9,
392
+ top_k=1,
393
+ top_p=1.0,
394
+ eos_id_a=_eoa,
395
+ eos_id_t=_eot,
396
+ ):
397
+
398
+ assert os.path.exists(audio_path), f"audio file {audio_path} not found"
399
+ model = self.model
400
+
401
+ with self.fabric.init_tensor():
402
+ model.set_kv_cache(batch_size=2)
403
+
404
+ mel, leng = load_audio(audio_path)
405
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
406
+ T = input_ids[0].size(1)
407
+ device = input_ids[0].device
408
+
409
+ assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
410
+
411
+ if model.max_seq_length < max_returned_tokens - 1:
412
+ raise NotImplementedError(
413
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
414
+ )
415
+
416
+ input_pos = torch.tensor([T], device=device)
417
+ list_output = [[] for i in range(8)]
418
+ tokens_A, token_T = next_token_batch(
419
+ model,
420
+ audio_feature.to(torch.float32).to(model.device),
421
+ input_ids,
422
+ [T - 3, T - 3],
423
+ ["A1T2", "A1T2"],
424
+ input_pos=torch.arange(0, T, device=device),
425
+ temperature=temperature,
426
+ top_k=top_k,
427
+ top_p=top_p,
428
+ )
429
+
430
+ for i in range(7):
431
+ list_output[i].append(tokens_A[i].tolist()[0])
432
+ list_output[7].append(token_T.tolist()[0])
433
+
434
+ model_input_ids = [[] for i in range(8)]
435
+ for i in range(7):
436
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
437
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
438
+ model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
439
+ model_input_ids[i] = torch.stack(model_input_ids[i])
440
+
441
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
442
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
443
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
444
+
445
+ text_end = False
446
+ index = 1
447
+ nums_generate = stream_stride
448
+ begin_generate = False
449
+ current_index = 0
450
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
451
+ tokens_A, token_T = next_token_batch(
452
+ model,
453
+ None,
454
+ model_input_ids,
455
+ None,
456
+ None,
457
+ input_pos=input_pos,
458
+ temperature=temperature,
459
+ top_k=top_k,
460
+ top_p=top_p,
461
+ )
462
+
463
+ if text_end:
464
+ token_T = torch.tensor([_pad_t], device=device)
465
+
466
+ if tokens_A[-1] == eos_id_a:
467
+ break
468
+
469
+ if token_T == eos_id_t:
470
+ text_end = True
471
+
472
+ for i in range(7):
473
+ list_output[i].append(tokens_A[i].tolist()[0])
474
+ list_output[7].append(token_T.tolist()[0])
475
+
476
+ model_input_ids = [[] for i in range(8)]
477
+ for i in range(7):
478
+ tokens_A[i] = tokens_A[i].clone() +padded_text_vocabsize + i * padded_audio_vocabsize
479
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
480
+ model_input_ids[i].append(
481
+ torch.tensor([layershift(4097, i)], device=device)
482
+ )
483
+ model_input_ids[i] = torch.stack(model_input_ids[i])
484
+
485
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
486
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
487
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
488
+
489
+ if index == 7:
490
+ begin_generate = True
491
+
492
+ if begin_generate:
493
+ current_index += 1
494
+ if current_index == nums_generate:
495
+ current_index = 0
496
+ snac = get_snac(list_output, index, nums_generate)
497
+ audio_stream = generate_audio_data(snac, self.snacmodel)
498
+ yield audio_stream
499
+
500
+ input_pos = input_pos.add_(1)
501
+ index += 1
502
+ text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
503
+ print(f"text output: {text}")
504
+ model.clear_kv_cache()
505
+ return list_output
506
+
507
+
508
+ def test_infer():
509
+ device = "cuda:0"
510
+ out_dir = f"./output/{get_time_str()}"
511
+ ckpt_dir = f"./checkpoint"
512
+ if not os.path.exists(ckpt_dir):
513
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
514
+ download_model(ckpt_dir)
515
+
516
+ fabric, model, text_tokenizer, snacmodel, whispermodel = load_model(ckpt_dir, device)
517
+
518
+ task = ['A1A2', 'asr', "T1A2", "AA-BATCH", 'T1T2', 'AT']
519
+
520
+ # prepare test data
521
+ # TODO
522
+ test_audio_list = sorted(os.listdir('./data/samples'))
523
+ test_audio_list = [os.path.join('./data/samples', path) for path in test_audio_list]
524
+ test_audio_transcripts = [
525
+ "What is your name?",
526
+ "what are your hobbies?",
527
+ "Do you like beijing",
528
+ "How are you feeling today?",
529
+ "what is the weather like today?",
530
+ ]
531
+ test_text_list = [
532
+ "What is your name?",
533
+ "How are you feeling today?",
534
+ "Can you describe your surroundings?",
535
+ "What did you do yesterday?",
536
+ "What is your favorite book and why?",
537
+ "How do you make a cup of tea?",
538
+ "What is the weather like today?",
539
+ "Can you explain the concept of time?",
540
+ "Can you tell me a joke?",
541
+ ]
542
+
543
+ # LOAD MODEL
544
+ with torch.no_grad():
545
+ if "A1A2" in task:
546
+ print("===============================================================")
547
+ print(" testing A1A2")
548
+ print("===============================================================")
549
+ step = 0
550
+ for path in test_audio_list:
551
+ try:
552
+ mel, leng = load_audio(path)
553
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device)
554
+ text = A1_A2(
555
+ fabric,
556
+ audio_feature,
557
+ input_ids,
558
+ leng,
559
+ model,
560
+ text_tokenizer,
561
+ step,
562
+ snacmodel,
563
+ out_dir=out_dir,
564
+ )
565
+ print(f"input: {test_audio_transcripts[step]}")
566
+ print(f"output: {text}")
567
+ step += 1
568
+ print(
569
+ "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
570
+ )
571
+ except:
572
+ print(f"[error] failed to process {path}")
573
+ print("===============================================================")
574
+
575
+ if 'asr' in task:
576
+ print("===============================================================")
577
+ print(" testing asr")
578
+ print("===============================================================")
579
+
580
+ index = 0
581
+ step = 0
582
+ for path in test_audio_list:
583
+ mel, leng = load_audio(path)
584
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device, special_token_a=_pad_a, special_token_t=_asr)
585
+ output = A1_T1(fabric, audio_feature, input_ids ,leng, model, text_tokenizer, index).lower().replace(',','').replace('.','').replace('?','')
586
+ print(f"audio_path: {path}")
587
+ print(f"audio transcript: {test_audio_transcripts[index]}")
588
+ print(f"asr output: {output}")
589
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
590
+ index += 1
591
+
592
+ if "T1A2" in task:
593
+ step = 0
594
+ print("\n")
595
+ print("===============================================================")
596
+ print(" testing T1A2")
597
+ print("===============================================================")
598
+ for text in test_text_list:
599
+ input_ids = get_input_ids_TA(text, text_tokenizer)
600
+ text_output = T1_A2(fabric, input_ids, model, text_tokenizer, step,
601
+ snacmodel, out_dir=out_dir)
602
+ print(f"input: {text}")
603
+ print(f"output: {text_output}")
604
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
605
+ step += 1
606
+ print("===============================================================")
607
+
608
+ if "T1T2" in task:
609
+ step = 0
610
+ print("\n")
611
+ print("===============================================================")
612
+ print(" testing T1T2")
613
+ print("===============================================================")
614
+
615
+ for text in test_text_list:
616
+ input_ids = get_input_ids_TT(text, text_tokenizer)
617
+ text_output = T1_T2(fabric, input_ids, model, text_tokenizer, step)
618
+ print(f" Input: {text}")
619
+ print(f"Output: {text_output}")
620
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
621
+ print("===============================================================")
622
+
623
+ if "AT" in task:
624
+ print("===============================================================")
625
+ print(" testing A1T2")
626
+ print("===============================================================")
627
+ step = 0
628
+ for path in test_audio_list:
629
+ mel, leng = load_audio(path)
630
+ audio_feature, input_ids = get_input_ids_whisper(
631
+ mel, leng, whispermodel, device,
632
+ special_token_a=_pad_a, special_token_t=_answer_t
633
+ )
634
+ text = A1_T2(
635
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step
636
+ )
637
+ print(f"input: {test_audio_transcripts[step]}")
638
+ print(f"output: {text}")
639
+ step += 1
640
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
641
+ print("===============================================================")
642
+
643
+ if "AA-BATCH" in task:
644
+ print("===============================================================")
645
+ print(" testing A1A2-BATCH")
646
+ print("===============================================================")
647
+ step = 0
648
+ for path in test_audio_list:
649
+ mel, leng = load_audio(path)
650
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
651
+ text = A1_A2_batch(
652
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
653
+ snacmodel, out_dir=out_dir
654
+ )
655
+ print(f"input: {test_audio_transcripts[step]}")
656
+ print(f"output: {text}")
657
+ step += 1
658
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
659
+ print("===============================================================")
660
+
661
+ print("*********************** test end *****************************")
662
+
663
+
664
+
665
+ if __name__ == "__main__":
666
+ test_infer()
litgpt/.DS_Store ADDED
Binary file (6.15 kB). View file
 
litgpt/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ import logging
4
+ import re
5
+ from litgpt.model import GPT # needs to be imported before config
6
+ from litgpt.config import Config
7
+ from litgpt.tokenizer import Tokenizer
8
+
9
+ # Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
10
+ pattern = re.compile(".*Profiler function .* will be ignored")
11
+ logging.getLogger("torch._dynamo.variables.torch").addFilter(
12
+ lambda record: not pattern.search(record.getMessage())
13
+ )
14
+
15
+ # Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint
16
+ logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True
17
+ logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True
18
+
19
+ __all__ = ["GPT", "Config", "Tokenizer"]
litgpt/config.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any, Literal, Optional, Type, Union
7
+
8
+ import torch
9
+ import yaml
10
+ from typing_extensions import Self
11
+
12
+ import litgpt.model
13
+ from litgpt.utils import find_multiple
14
+
15
+
16
+ @dataclass
17
+ class Config:
18
+ name: str = ""
19
+ hf_config: dict = field(default_factory=dict)
20
+ scale_embeddings: bool = False
21
+ block_size: int = 4096
22
+ vocab_size: int = 50254
23
+ padding_multiple: int = 512
24
+ padded_vocab_size: Optional[int] = None
25
+ n_layer: int = 16
26
+ n_head: int = 32
27
+ head_size: Optional[int] = None
28
+ n_embd: int = 4096
29
+ rotary_percentage: float = 0.25
30
+ parallel_residual: bool = True
31
+ bias: bool = True
32
+ lm_head_bias: bool = False
33
+ # to use multi-head attention (MHA), set this to `n_head` (default)
34
+ # to use multi-query attention (MQA), set this to 1
35
+ # to use grouped-query attention (GQA), set this to a value in between
36
+ # Example with `n_head=4`
37
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
38
+ # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
39
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
40
+ # │ │ │ │ │ │ │
41
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
42
+ # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
43
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
44
+ # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
45
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
46
+ # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
47
+ # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
48
+ # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
49
+ # MHA GQA MQA
50
+ # n_query_groups=4 n_query_groups=2 n_query_groups=1
51
+ #
52
+ # credit https://arxiv.org/pdf/2305.13245.pdf
53
+ n_query_groups: Optional[int] = None
54
+ shared_attention_norm: bool = False
55
+ norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
56
+ norm_eps: float = 1e-5
57
+ mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
58
+ "GptNeoxMLP"
59
+ )
60
+ gelu_approximate: str = "none"
61
+ intermediate_size: Optional[int] = None
62
+ rope_condense_ratio: int = 1
63
+ rope_base: int = 10000
64
+ n_expert: int = 0
65
+ n_expert_per_token: int = 0
66
+
67
+ add_qkv_bias: Optional[bool] = None
68
+ prompt_vocab_size: Optional[int] = None
69
+ attn_dropout: float = 0.0
70
+ pos_type: str = "rope"
71
+ force_align: bool = False
72
+ use_pretrain_phoneme_emb: bool = False
73
+ tie_word_embeddings: bool = False
74
+
75
+ # setting for mini-omni
76
+ text_vocab_size:int = 152000
77
+ cat_audio_vocab_size: int = 29120
78
+ audio_vocab_size: int = 4160
79
+ whisper_adapter_dim: int = 768
80
+
81
+ post_adapter: bool = False
82
+ post_adapter_layers: int = 6
83
+ asr_adapter: str = "llamamlp"
84
+
85
+ def __post_init__(self):
86
+ if not self.name:
87
+ self.name = self.hf_config.get("name", self.name)
88
+
89
+ if self.head_size is None:
90
+ assert self.n_embd % self.n_head == 0
91
+ self.head_size = self.n_embd // self.n_head
92
+
93
+ # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
94
+ if self.padded_vocab_size is None:
95
+ self.padded_vocab_size = find_multiple(
96
+ self.vocab_size, self.padding_multiple
97
+ )
98
+ else:
99
+ # vocab size shouldn't be larger than padded vocab size
100
+ self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
101
+
102
+ # compute the number of query groups
103
+ if self.n_query_groups is not None:
104
+ assert self.n_head % self.n_query_groups == 0
105
+ else:
106
+ self.n_query_groups = self.n_head
107
+
108
+ # compute the intermediate size for MLP if not set
109
+ if self.intermediate_size is None:
110
+ if self.mlp_class_name == "LLaMAMLP":
111
+ raise ValueError(
112
+ f"The config {self.name!r}, needs to set the `intermediate_size`"
113
+ )
114
+ self.intermediate_size = 4 * self.n_embd
115
+
116
+ self.rope_n_elem = int(self.rotary_percentage * self.head_size)
117
+
118
+ if self.add_qkv_bias is None:
119
+ self.add_qkv_bias = self.bias
120
+
121
+ @classmethod
122
+ def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
123
+ if name not in name_to_config:
124
+ # search through all `config['hf_config']['name']`
125
+ try:
126
+ conf_dict = next(
127
+ config
128
+ for config in configs
129
+ if name == config["hf_config"]["name"]
130
+ or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
131
+ == name
132
+ )
133
+ except StopIteration:
134
+ raise ValueError(f"{name!r} is not a supported config name")
135
+ else:
136
+ conf_dict = name_to_config[name]
137
+
138
+ conf_dict = conf_dict.copy()
139
+ conf_dict.update(kwargs)
140
+ return cls(**conf_dict)
141
+
142
+ @classmethod
143
+ def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
144
+ with open(path, encoding="utf-8") as fp:
145
+ file_kwargs = yaml.safe_load(fp)
146
+ if file_kwargs is None:
147
+ raise ValueError(f"{path} is empty which is likely unexpected.")
148
+ file_kwargs.update(kwargs)
149
+ return cls(**file_kwargs)
150
+
151
+ @classmethod
152
+ def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
153
+ """Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
154
+ if (config_path := path / "model_config.yaml").is_file():
155
+ return cls.from_file(config_path, **kwargs)
156
+ if (model_name := path.name) in name_to_config:
157
+ return cls.from_name(model_name, **kwargs)
158
+ raise FileNotFoundError(
159
+ f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
160
+ )
161
+
162
+ @property
163
+ def mlp_class(self) -> Type:
164
+ # `self.mlp_class_name` cannot be the type to keep the config serializable
165
+ return getattr(litgpt.model, self.mlp_class_name)
166
+
167
+ @property
168
+ def norm_class(self) -> Type:
169
+ # `self.norm_class_name` cannot be the type to keep the config serializable
170
+ if self.norm_class_name == "RMSNorm":
171
+ from functools import partial
172
+
173
+ from litgpt.model import RMSNorm
174
+
175
+ return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
176
+ return getattr(torch.nn, self.norm_class_name)
177
+
178
+
179
+ configs = []
180
+ name_to_config = {config["name"]: config for config in configs}
litgpt/generate/__init__.py ADDED
File without changes
litgpt/generate/base.py ADDED
@@ -0,0 +1,795 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from typing import Any, Literal, Optional
4
+
5
+ import torch
6
+ # import torch._dynamo.config
7
+ # import torch._inductor.config
8
+
9
+ from litgpt.model import GPT
10
+ from utils.snac_utils import layershift, snac_config
11
+ from tqdm import tqdm
12
+
13
+
14
+ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
15
+ if torch._dynamo.is_compiling():
16
+ # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
17
+ distribution = torch.empty_like(probs).exponential_(1)
18
+ return torch.argmax(probs / distribution, dim=-1, keepdim=True)
19
+ return torch.multinomial(probs, num_samples=1)
20
+
21
+
22
+ def sample_top_p(logits_A: torch.Tensor, top_p: float) -> torch.Tensor:
23
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
24
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
25
+ # Example:
26
+ # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
27
+ # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
28
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
29
+ # Keep at least 1 token always to prevent the case where no token is selected
30
+ # In this case the most probable one is always kept
31
+ sorted_indices_to_remove[-1:] = 0
32
+ indices_to_remove = sorted_indices_to_remove.scatter(
33
+ 0, sorted_indices, sorted_indices_to_remove
34
+ )
35
+ logits = logits.masked_fill(indices_to_remove, float("-inf"))
36
+ return logits
37
+
38
+
39
+ def sample(
40
+ logits: torch.Tensor,
41
+ temperature: float = 1.0,
42
+ top_k: Optional[int] = None,
43
+ top_p: float = 1.0,
44
+ ) -> torch.Tensor:
45
+ if top_p < 0.0 or top_p > 1.0:
46
+ raise ValueError(f"top_p must be in [0, 1], got {top_p}")
47
+ logits = logits[0, -1]
48
+ # optionally crop the logits to only the top k options
49
+ if top_k is not None:
50
+ v, i = torch.topk(logits, min(top_k, logits.size(-1)))
51
+ # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
52
+ logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
53
+ # optionally scale the logits and sample from a probability distribution
54
+ if temperature > 0.0 or top_p > 0.0:
55
+ if temperature > 0.0:
56
+ logits = logits / temperature
57
+ # optionally crop the logits to smallest set of logits with a cumulative probability above top_p
58
+ if top_p < 1.0:
59
+ logits = sample_top_p(logits, top_p)
60
+ probs = torch.nn.functional.softmax(logits, dim=-1)
61
+ return multinomial_num_samples_1(probs)
62
+ return torch.argmax(logits, dim=-1, keepdim=True)
63
+
64
+
65
+ def next_token(
66
+ model: GPT, input_pos: torch.Tensor, x: list, **kwargs: Any
67
+ ) -> torch.Tensor:
68
+ input_pos = input_pos.to(model.device)
69
+ logits_a, logit_t = model(x, input_pos)
70
+
71
+ next_audio_tokens = []
72
+ for logit_a in logits_a:
73
+ next_a = sample(logit_a, **kwargs).to(dtype=x[0].dtype)
74
+ next_audio_tokens.append(next_a)
75
+ next_t = sample(logit_t, **kwargs).to(dtype=x[0].dtype)
76
+ return next_audio_tokens, next_t
77
+
78
+
79
+ def next_token_asr(
80
+ model: GPT,
81
+ input_pos: torch.Tensor,
82
+ audio_features: torch.tensor,
83
+ lens: int,
84
+ input_ids: list,
85
+ **kwargs: Any,
86
+ ) -> torch.Tensor:
87
+ input_pos = input_pos.to(model.device)
88
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
89
+ logits_a, logit_t = model(audio_features, input_ids, input_pos, whisper_lens=lens)
90
+
91
+ next_audio_tokens = []
92
+ for logit_a in logits_a:
93
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
94
+ next_audio_tokens.append(next_a)
95
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
96
+ return next_audio_tokens, next_t
97
+
98
+
99
+ def next_token_A1T2(
100
+ model: GPT,
101
+ audio_features: torch.tensor,
102
+ input_ids: list,
103
+ whisper_lens: int,
104
+ task: list,
105
+ input_pos: torch.Tensor,
106
+ **kwargs: Any,
107
+ ) -> torch.Tensor:
108
+ input_pos = input_pos.to(model.device)
109
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
110
+ logits_a, logit_t = model(
111
+ audio_features, input_ids, input_pos, whisper_lens=whisper_lens, task=task
112
+ )
113
+
114
+ next_audio_tokens = []
115
+ for logit_a in logits_a:
116
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
117
+ next_audio_tokens.append(next_a)
118
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
119
+ return next_audio_tokens, next_t
120
+
121
+
122
+ def next_token_A1T1(
123
+ model: GPT,
124
+ audio_features: torch.tensor,
125
+ input_ids: list,
126
+ whisper_lens: int,
127
+ task: list,
128
+ input_pos: torch.Tensor,
129
+ **kwargs: Any,
130
+ ) -> torch.Tensor:
131
+ input_pos = input_pos.to(model.device)
132
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
133
+ logits_a, logit_t = model(
134
+ audio_features, input_ids, input_pos, whisper_lens=whisper_lens, task=task
135
+ )
136
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
137
+ return next_t
138
+
139
+
140
+ def next_token_batch(
141
+ model: GPT,
142
+ audio_features: torch.tensor,
143
+ input_ids: list,
144
+ whisper_lens: int,
145
+ task: list,
146
+ input_pos: torch.Tensor,
147
+ **kwargs: Any,
148
+ ) -> torch.Tensor:
149
+ input_pos = input_pos.to(model.device)
150
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
151
+ logits_a, logit_t = model(
152
+ audio_features, input_ids, input_pos, whisper_lens=whisper_lens, task=task
153
+ )
154
+
155
+ for i in range(7):
156
+ logits_a[i] = logits_a[i][0].unsqueeze(0)
157
+ logit_t = logit_t[1].unsqueeze(0)
158
+
159
+ next_audio_tokens = []
160
+ for logit_a in logits_a:
161
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
162
+ next_audio_tokens.append(next_a)
163
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
164
+ return next_audio_tokens, next_t
165
+
166
+
167
+ # torch._dynamo.config.automatic_dynamic_shapes = True
168
+ # torch._inductor.config.triton.unique_kernel_names = True
169
+ # torch._inductor.config.coordinate_descent_tuning = True
170
+ # next_token = torch.compile(next_token, mode="reduce-overhead")
171
+
172
+
173
+ @torch.inference_mode()
174
+ def generate(
175
+ model: GPT,
176
+ input_ids: list,
177
+ max_returned_tokens: int,
178
+ *,
179
+ temperature: float = 1.0,
180
+ top_k: Optional[int] = None,
181
+ top_p: float = 1.0,
182
+ eos_id_a: Optional[int] = None,
183
+ eos_id_t: Optional[int] = None,
184
+ pad_id: Optional[int] = None,
185
+ shift: Optional[int] = None,
186
+ include_prompt: bool = True,
187
+ generate_text=False,
188
+ ) -> torch.Tensor:
189
+ # print("eos_id_a:", eos_id_a)
190
+ # print("eos_id_t:", eos_id_t)
191
+ # print("pad_id:", pad_id)
192
+ """
193
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
194
+ The implementation of this function is modified from A. Karpathy's nanoGPT.
195
+
196
+ Args:
197
+ model: The model to use.
198
+ prompt: Tensor of shape (T) with indices of the prompt sequence.
199
+ max_returned_tokens: The maximum number of tokens to return (given plus generated).
200
+ temperature: Scales the predicted logits by 1 / temperature.
201
+ top_k: If specified, only sample among the tokens with the k highest probabilities.
202
+ top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
203
+ In top-p sampling, the next token is sampled from the highest probability tokens
204
+ whose cumulative probability exceeds the threshold `top_p`. When specified,
205
+ it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
206
+ to sampling the most probable token, while `top_p=1` samples from the whole distribution.
207
+ It can be used in conjunction with `top_k` and `temperature` with the following order
208
+ of application:
209
+
210
+ 1. `top_k` sampling
211
+ 2. `temperature` scaling
212
+ 3. `top_p` sampling
213
+
214
+ For more details, see https://arxiv.org/abs/1904.09751
215
+ or https://huyenchip.com/2024/01/16/sampling.html#top_p
216
+ eos_id: If specified, stop generating any more token once the <eos> token is triggered.
217
+ include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.
218
+ """
219
+ T = input_ids[0].size(0)
220
+ device = input_ids[0].device
221
+ assert max_returned_tokens > T
222
+ if model.max_seq_length < max_returned_tokens - 1:
223
+ # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
224
+ # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
225
+ # not support it to avoid negatively impacting the overall speed
226
+ raise NotImplementedError(
227
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
228
+ )
229
+
230
+ for input_id in input_ids:
231
+ input_id = [input_id]
232
+ (
233
+ tokens_A1,
234
+ tokens_A2,
235
+ tokens_A3,
236
+ tokens_A4,
237
+ tokens_A5,
238
+ tokens_A6,
239
+ tokens_A7,
240
+ tokens_T,
241
+ ) = input_ids
242
+
243
+ tokens_A1_output = [tokens_A1]
244
+ tokens_A2_output = [tokens_A2]
245
+ tokens_A3_output = [tokens_A3]
246
+ tokens_A4_output = [tokens_A4]
247
+ tokens_A5_output = [tokens_A5]
248
+ tokens_A6_output = [tokens_A6]
249
+ tokens_A7_output = [tokens_A7]
250
+ tokens_T_output = [tokens_T]
251
+
252
+ list_output = [
253
+ tokens_A1_output,
254
+ tokens_A2_output,
255
+ tokens_A3_output,
256
+ tokens_A4_output,
257
+ tokens_A5_output,
258
+ tokens_A6_output,
259
+ tokens_A7_output,
260
+ tokens_T_output,
261
+ ]
262
+
263
+ input_pos = torch.tensor([T], device=device)
264
+ model_input_ids = [
265
+ tokens_A1.view(1, -1),
266
+ tokens_A2.view(1, -1),
267
+ tokens_A3.view(1, -1),
268
+ tokens_A4.view(1, -1),
269
+ tokens_A5.view(1, -1),
270
+ tokens_A6.view(1, -1),
271
+ tokens_A7.view(1, -1),
272
+ tokens_T.view(1, -1),
273
+ ]
274
+
275
+ tokens_A, token_T = next_token(
276
+ model,
277
+ torch.arange(0, T, device=device),
278
+ model_input_ids,
279
+ temperature=temperature,
280
+ top_k=top_k,
281
+ top_p=top_p,
282
+ )
283
+ for i in range(7):
284
+ list_output[i].append(tokens_A[i].clone())
285
+ list_output[7].append(token_T.clone())
286
+
287
+ # prepare the input for the next iteration
288
+ for i in range(7):
289
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
290
+ token_T = token_T.clone()
291
+
292
+ text_end = False
293
+ max_returned_tokens = 1000
294
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
295
+ model_input_ids = [
296
+ token_a.view(1, -1).to(torch.int32) for token_a in tokens_A
297
+ ] + [token_T.view(1, -1).to(torch.int32)]
298
+ tokens_A, token_T = next_token(
299
+ model,
300
+ input_pos,
301
+ model_input_ids,
302
+ temperature=temperature,
303
+ top_k=top_k,
304
+ top_p=top_p,
305
+ )
306
+ if text_end:
307
+ token_T = torch.tensor([pad_id], device=device)
308
+
309
+ for i in range(7):
310
+ list_output[i].append(tokens_A[i].clone())
311
+ list_output[7].append(token_T.clone())
312
+
313
+ if tokens_A[-1] == eos_id_a:
314
+ break
315
+ if token_T == eos_id_t:
316
+ if generate_text:
317
+ break
318
+ text_end = True
319
+
320
+ for i in range(7):
321
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
322
+ token_T = token_T.clone()
323
+ input_pos = input_pos.add_(1)
324
+
325
+ for i in range(len(list_output)):
326
+ list_output[i] = torch.cat(list_output[i])
327
+ return list_output
328
+
329
+
330
+ @torch.inference_mode()
331
+ def generate_TA_BATCH(
332
+ model: GPT,
333
+ audio_features: torch.Tensor,
334
+ input_ids: list,
335
+ leng,
336
+ task,
337
+ max_returned_tokens: int = 1000,
338
+ *,
339
+ temperature: float = 1.0,
340
+ top_k: Optional[int] = None,
341
+ top_p: float = 1.0,
342
+ eos_id_a: Optional[int] = None,
343
+ eos_id_t: Optional[int] = None,
344
+ pad_id_t: Optional[int] = None,
345
+ shift: Optional[int] = None,
346
+ include_prompt: bool = True,
347
+ generate_text=False,
348
+ ) -> torch.Tensor:
349
+
350
+ T = input_ids[0].size(1)
351
+ device = input_ids[0].device
352
+ assert max_returned_tokens > T
353
+ if model.max_seq_length < max_returned_tokens - 1:
354
+ raise NotImplementedError(
355
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
356
+ )
357
+
358
+ input_pos = torch.tensor([T], device=device)
359
+ model_input_ids = input_ids
360
+
361
+ list_output = [[] for i in range(8)]
362
+
363
+ tokens_A, token_T = next_token_batch(
364
+ model,
365
+ audio_features.to(torch.float32).to(model.device),
366
+ input_ids,
367
+ [T - 3, T - 3],
368
+ ["A1T2", "A1T2"],
369
+ input_pos=torch.arange(0, T, device=device),
370
+ temperature=temperature,
371
+ top_k=top_k,
372
+ top_p=top_p,
373
+ )
374
+
375
+ for i in range(7):
376
+ list_output[i].append(tokens_A[i].tolist()[0])
377
+ list_output[7].append(token_T.tolist()[0])
378
+
379
+ model_input_ids = [[] for i in range(8)]
380
+ for i in range(7):
381
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
382
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
383
+ model_input_ids[i].append(torch.tensor([layershift(snac_config.end_of_audio, i)], device=device))
384
+ model_input_ids[i] = torch.stack(model_input_ids[i])
385
+
386
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
387
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
388
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
389
+
390
+ text_end = False
391
+
392
+ for _ in range(2, max_returned_tokens - T + 1):
393
+ tokens_A, token_T = next_token_batch(
394
+ model,
395
+ None,
396
+ model_input_ids,
397
+ None,
398
+ None,
399
+ input_pos=input_pos,
400
+ temperature=temperature,
401
+ top_k=top_k,
402
+ top_p=top_p,
403
+ )
404
+
405
+ if text_end:
406
+ token_T = torch.tensor([pad_id_t], device=device)
407
+
408
+ if tokens_A[-1] == eos_id_a:
409
+ break
410
+ if token_T == eos_id_t:
411
+ text_end = True
412
+
413
+ for i in range(7):
414
+ list_output[i].append(tokens_A[i].tolist()[0])
415
+ list_output[7].append(token_T.tolist()[0])
416
+
417
+ model_input_ids = [[] for i in range(8)]
418
+ for i in range(7):
419
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
420
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
421
+ model_input_ids[i].append(
422
+ torch.tensor([layershift(snac_config.end_of_audio, i)], device=device)
423
+ )
424
+ model_input_ids[i] = torch.stack(model_input_ids[i])
425
+
426
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
427
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
428
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
429
+
430
+ input_pos = input_pos.add_(1)
431
+
432
+ return list_output
433
+
434
+
435
+ @torch.inference_mode()
436
+ def generate_TT(
437
+ model: GPT,
438
+ audio_features: torch.Tensor,
439
+ input_ids: list,
440
+ leng,
441
+ task,
442
+ max_returned_tokens: int = 2048,
443
+ *,
444
+ temperature: float = 1.0,
445
+ top_k: Optional[int] = None,
446
+ top_p: float = 1.0,
447
+ eos_id_a: Optional[int] = None,
448
+ eos_id_t: Optional[int] = None,
449
+ pad_id_t: Optional[int] = None,
450
+ shift: Optional[int] = None,
451
+ include_prompt: bool = True,
452
+ generate_text=False,
453
+ ) -> torch.Tensor:
454
+
455
+ T = input_ids[0].size(1)
456
+ device = input_ids[0].device
457
+
458
+ output = []
459
+ token_T = next_token_A1T1(
460
+ model,
461
+ None,
462
+ input_ids,
463
+ None,
464
+ None,
465
+ input_pos=torch.arange(0, T, device=device),
466
+ temperature=temperature,
467
+ top_k=top_k,
468
+ top_p=top_p,
469
+ )
470
+
471
+ output.append(token_T.clone().tolist()[0])
472
+ input_pos = torch.tensor([T], device=device)
473
+
474
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
475
+ model_input_ids = []
476
+ for i in range(7):
477
+ model_input_ids.append(
478
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
479
+ .view(1, -1)
480
+ .to(torch.int32)
481
+ .to(device)
482
+ )
483
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
484
+ token_T = next_token_A1T1(
485
+ model,
486
+ None,
487
+ model_input_ids,
488
+ None,
489
+ None,
490
+ input_pos=input_pos,
491
+ temperature=temperature,
492
+ top_k=top_k,
493
+ top_p=top_p,
494
+ )
495
+ if token_T == eos_id_t:
496
+ break
497
+ output.append(token_T.clone().tolist()[0])
498
+ input_pos = input_pos.add_(1)
499
+ return output
500
+
501
+
502
+ @torch.inference_mode()
503
+ def generate_AT(
504
+ model: GPT,
505
+ audio_features: torch.Tensor,
506
+ input_ids: list,
507
+ leng,
508
+ task,
509
+ max_returned_tokens: int = 2048,
510
+ *,
511
+ temperature: float = 1.0,
512
+ top_k: Optional[int] = None,
513
+ top_p: float = 1.0,
514
+ eos_id_a: Optional[int] = None,
515
+ eos_id_t: Optional[int] = None,
516
+ pad_id_t: Optional[int] = None,
517
+ shift: Optional[int] = None,
518
+ include_prompt: bool = True,
519
+ generate_text=False,
520
+ ) -> torch.Tensor:
521
+
522
+ T = input_ids[0].size(1)
523
+ device = input_ids[0].device
524
+
525
+ output = []
526
+ token_T = next_token_A1T1(
527
+ model,
528
+ audio_features.to(torch.float32).to(model.device),
529
+ input_ids,
530
+ [T - 3],
531
+ ["AT"],
532
+ input_pos=torch.arange(0, T, device=device),
533
+ temperature=temperature,
534
+ top_k=top_k,
535
+ top_p=top_p,
536
+ )
537
+ output.append(token_T.clone().tolist()[0])
538
+ input_pos = torch.tensor([T], device=device)
539
+ text_end = False
540
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
541
+ model_input_ids = []
542
+ for i in range(7):
543
+ model_input_ids.append(
544
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
545
+ .view(1, -1)
546
+ .to(torch.int32)
547
+ .to(device)
548
+ )
549
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
550
+ token_T = next_token_A1T1(
551
+ model,
552
+ None,
553
+ model_input_ids,
554
+ None,
555
+ None,
556
+ input_pos=input_pos,
557
+ temperature=temperature,
558
+ top_k=top_k,
559
+ top_p=top_p,
560
+ )
561
+ if token_T == eos_id_t:
562
+ break
563
+ output.append(token_T.clone().tolist()[0])
564
+ input_pos = input_pos.add_(1)
565
+ return output
566
+
567
+
568
+ @torch.inference_mode()
569
+ def generate_TA(
570
+ model: GPT,
571
+ audio_features: torch.Tensor,
572
+ input_ids: list,
573
+ leng,
574
+ task,
575
+ max_returned_tokens: int = 2048,
576
+ *,
577
+ temperature: float = 1.0,
578
+ top_k: Optional[int] = None,
579
+ top_p: float = 1.0,
580
+ eos_id_a: Optional[int] = None,
581
+ eos_id_t: Optional[int] = None,
582
+ pad_id_t: Optional[int] = None,
583
+ shift: Optional[int] = None,
584
+ include_prompt: bool = True,
585
+ generate_text=False,
586
+ ) -> torch.Tensor:
587
+
588
+ T = input_ids[0].size(1)
589
+ device = input_ids[0].device
590
+
591
+ output = [[] for _ in range(8)]
592
+ tokens_A, token_T = next_token_A1T2(
593
+ model,
594
+ None,
595
+ input_ids,
596
+ None,
597
+ None,
598
+ input_pos=torch.arange(0, T, device=device),
599
+ temperature=temperature,
600
+ top_k=top_k,
601
+ top_p=top_p,
602
+ )
603
+ for i in range(7):
604
+ output[i].append(tokens_A[i].clone().tolist()[0])
605
+ output[7].append(token_T.clone().tolist()[0])
606
+
607
+ input_pos = torch.tensor([T], device=device)
608
+ text_end = False
609
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
610
+
611
+ model_input_ids = []
612
+ for i in range(7):
613
+ model_input_ids.append(
614
+ layershift(tokens_A[i].clone(), i)
615
+ .view(1, -1)
616
+ .to(torch.int32)
617
+ .to(device)
618
+ )
619
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
620
+
621
+ tokens_A, token_T = next_token_A1T2(
622
+ model,
623
+ None,
624
+ model_input_ids,
625
+ None,
626
+ None,
627
+ input_pos=input_pos,
628
+ temperature=temperature,
629
+ top_k=top_k,
630
+ top_p=top_p,
631
+ )
632
+
633
+ if text_end:
634
+ token_T = torch.tensor([pad_id_t], device=device)
635
+
636
+ if tokens_A[-1] == eos_id_a:
637
+ break
638
+
639
+ if token_T == eos_id_t:
640
+ text_end = True
641
+
642
+ for i in range(7):
643
+ output[i].append(tokens_A[i].clone().tolist()[0])
644
+ output[7].append(token_T.clone().tolist()[0])
645
+ input_pos = input_pos.add_(1)
646
+
647
+ return output
648
+
649
+
650
+ @torch.inference_mode()
651
+ def generate_AA(
652
+ model: GPT,
653
+ audio_features: torch.Tensor,
654
+ input_ids: list,
655
+ leng,
656
+ task,
657
+ max_returned_tokens: int = 2048,
658
+ *,
659
+ temperature: float = 1.0,
660
+ top_k: Optional[int] = None,
661
+ top_p: float = 1.0,
662
+ eos_id_a: Optional[int] = None,
663
+ eos_id_t: Optional[int] = None,
664
+ pad_id_t: Optional[int] = None,
665
+ shift: Optional[int] = None,
666
+ include_prompt: bool = True,
667
+ generate_text=False,
668
+ ) -> torch.Tensor:
669
+
670
+ T = input_ids[0].size(1)
671
+ device = input_ids[0].device
672
+
673
+ output = [[] for _ in range(8)]
674
+ tokens_A, token_T = next_token_A1T2(
675
+ model,
676
+ audio_features.to(torch.float32).to(model.device),
677
+ input_ids,
678
+ [T - 3],
679
+ ["A1T2"],
680
+ input_pos=torch.arange(0, T, device=device),
681
+ temperature=temperature,
682
+ top_k=top_k,
683
+ top_p=top_p,
684
+ )
685
+ for i in range(7):
686
+ output[i].append(tokens_A[i].clone().tolist()[0])
687
+ output[7].append(token_T.clone().tolist()[0])
688
+
689
+ input_pos = torch.tensor([T], device=device)
690
+
691
+ text_end = False
692
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
693
+
694
+ model_input_ids = []
695
+ for i in range(7):
696
+ model_input_ids.append(
697
+ layershift(tokens_A[i].clone(), i)
698
+ .view(1, -1)
699
+ .to(torch.int32)
700
+ .to(device)
701
+ )
702
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
703
+
704
+ tokens_A, token_T = next_token_A1T2(
705
+ model,
706
+ None,
707
+ model_input_ids,
708
+ None,
709
+ None,
710
+ input_pos=input_pos,
711
+ temperature=temperature,
712
+ top_k=top_k,
713
+ top_p=top_p,
714
+ )
715
+
716
+ if text_end:
717
+ token_T = torch.tensor([pad_id_t], device=device)
718
+
719
+ if tokens_A[-1] == eos_id_a:
720
+ break
721
+ if token_T == eos_id_t:
722
+ # print("text_end")
723
+ text_end = True
724
+
725
+ for i in range(7):
726
+ output[i].append(tokens_A[i].clone().tolist()[0])
727
+ output[7].append(token_T.clone().tolist()[0])
728
+ input_pos = input_pos.add_(1)
729
+
730
+ return output
731
+
732
+
733
+ @torch.inference_mode()
734
+ def generate_ASR(
735
+ model: GPT,
736
+ audio_features: torch.Tensor,
737
+ input_ids: list,
738
+ leng,
739
+ task,
740
+ max_returned_tokens: int = 1200,
741
+ *,
742
+ temperature: float = 1.0,
743
+ top_k: Optional[int] = None,
744
+ top_p: float = 1.0,
745
+ eos_id_a: Optional[int] = None,
746
+ eos_id_t: Optional[int] = None,
747
+ pad_id_t: Optional[int] = None,
748
+ shift: Optional[int] = None,
749
+ include_prompt: bool = True,
750
+ generate_text=False,
751
+ ) -> torch.Tensor:
752
+
753
+ T = input_ids[0].size(1)
754
+ device = input_ids[0].device
755
+ output = []
756
+ token_T = next_token_A1T1(
757
+ model,
758
+ audio_features.to(torch.float32).to(model.device),
759
+ input_ids,
760
+ [T - 3],
761
+ ["asr"],
762
+ input_pos=torch.arange(0, T, device=device),
763
+ temperature=temperature,
764
+ top_k=top_k,
765
+ top_p=top_p,
766
+ )
767
+ output.append(token_T.clone().tolist()[0])
768
+ input_pos = torch.tensor([T], device=device)
769
+ text_end = False
770
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
771
+ model_input_ids = []
772
+ for i in range(7):
773
+ model_input_ids.append(
774
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
775
+ .view(1, -1)
776
+ .to(torch.int32)
777
+ .to(device)
778
+ )
779
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
780
+ token_T = next_token_A1T1(
781
+ model,
782
+ None,
783
+ model_input_ids,
784
+ None,
785
+ None,
786
+ input_pos=input_pos,
787
+ temperature=temperature,
788
+ top_k=top_k,
789
+ top_p=top_p,
790
+ )
791
+ if token_T == eos_id_t:
792
+ break
793
+ output.append(token_T.clone().tolist()[0])
794
+ input_pos = input_pos.add_(1)
795
+ return output
litgpt/model.py ADDED
@@ -0,0 +1,618 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Full definition of a decoder-only transformer-based language model, all of it in this single file.
4
+
5
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
6
+ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
7
+ """
8
+
9
+ import math
10
+ from typing import Any, Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from typing_extensions import Self
15
+ from litgpt.config import Config
16
+
17
+
18
+ class GPT(nn.Module):
19
+ def __init__(self, config: Config) -> None:
20
+ super().__init__()
21
+ assert config.padded_vocab_size is not None
22
+ self.config = config
23
+ if self.config.asr_adapter == "mlp":
24
+ print("Using MLP adapter for ASR feature")
25
+ self.whisper_adapter = nn.Linear(config.whisper_adapter_dim, config.n_embd)
26
+ elif self.config.asr_adapter == "llamamlp":
27
+ print("using LLAMA MLP adapter for ASR feature")
28
+ self.whisper_adapter = whisperMLP(config=config)
29
+ else:
30
+ raise ValueError("asr_adapter should be mlp or llamamlp")
31
+ self.lm_head = nn.Linear(
32
+ config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
33
+ )
34
+ if config.post_adapter:
35
+ self.transformer = nn.ModuleDict(
36
+ dict(
37
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
38
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
39
+ post_adapter=nn.ModuleList(
40
+ Block(config) for _ in range(config.post_adapter_layers)
41
+ ),
42
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
43
+ post_adapter_audio_ln=config.norm_class(
44
+ config.n_embd, eps=config.norm_eps
45
+ ),
46
+ post_adapter_audio_lm_head=nn.Linear(
47
+ config.n_embd, config.cat_audio_vocab_size, bias=config.lm_head_bias
48
+ ),
49
+ )
50
+ )
51
+ else:
52
+ self.transformer = nn.ModuleDict(
53
+ dict(
54
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
55
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
56
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
57
+ )
58
+ )
59
+ self.max_seq_length = self.config.block_size
60
+ self.mask_cache: Optional[torch.Tensor] = None
61
+ if config.tie_word_embeddings:
62
+ self.lm_head.weight = self.transformer.wte.weight
63
+
64
+ @property
65
+ def max_seq_length(self) -> int:
66
+ return self._max_seq_length
67
+
68
+ @max_seq_length.setter
69
+ def max_seq_length(self, value: int) -> None:
70
+ """
71
+ When doing inference, the sequences used might be shorter than the model's context length.
72
+ This allows setting a smaller number to avoid allocating unused memory
73
+ """
74
+ if value > self.config.block_size:
75
+ raise ValueError(
76
+ f"Cannot attend to {value}, block size is only {self.config.block_size}"
77
+ )
78
+ self._max_seq_length = value
79
+ if not hasattr(self, "cos"):
80
+ # first call
81
+ cos, sin = self.rope_cache()
82
+ self.register_buffer("cos", cos, persistent=False)
83
+ self.register_buffer("sin", sin, persistent=False)
84
+ # override
85
+ elif value != self.cos.size(0):
86
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
87
+ # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
88
+ # if the kv cache is expected
89
+
90
+ def reset_parameters(self) -> None:
91
+ # Trigger resetting the rope-cache
92
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
93
+
94
+ def _init_weights(self, module: nn.Module) -> None:
95
+ """Meant to be used with `gpt.apply(gpt._init_weights)`."""
96
+ if isinstance(module, nn.Linear):
97
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
98
+ if module.bias is not None:
99
+ torch.nn.init.zeros_(module.bias)
100
+ elif isinstance(module, nn.Embedding):
101
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
102
+
103
+ def concat_whisper_feat(self, audio_feature, input_ids, T, task):
104
+ for j in range(len(T)):
105
+ if task[j] != "T1T2" and task[j] != "T1A2":
106
+ for i in range(7):
107
+ input_ids[i][j, 1 : T[j] + 1, :] = audio_feature[j][: T[j]].clone()
108
+ else:
109
+ continue
110
+ return input_ids
111
+
112
+ def forward(
113
+ self,
114
+ audio_features: torch.Tensor,
115
+ input_ids: torch.Tensor,
116
+ input_pos: Optional[torch.Tensor] = None,
117
+ whisper_lens: Optional[list] = None,
118
+ task: Optional[str] = None,
119
+ ) -> torch.Tensor:
120
+
121
+ show = False
122
+ T = input_ids[0].size(1)
123
+ if self.max_seq_length < T:
124
+ raise ValueError(
125
+ f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
126
+ )
127
+
128
+ if input_pos is not None: # use the kv cache
129
+ cos = self.cos.index_select(0, input_pos)
130
+ sin = self.sin.index_select(0, input_pos)
131
+ if self.mask_cache is None:
132
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
133
+ mask = self.mask_cache.index_select(2, input_pos)
134
+ else:
135
+ cos = self.cos[:T]
136
+ sin = self.sin[:T]
137
+ mask = None
138
+
139
+ if audio_features is not None:
140
+ # get whisper feature
141
+ x_a = self.whisper_adapter(audio_features)
142
+ # get input_ids embedding
143
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
144
+
145
+ x0 = self.transformer.wte(x0)
146
+ x1 = self.transformer.wte(x1)
147
+ x2 = self.transformer.wte(x2)
148
+ x3 = self.transformer.wte(x3)
149
+ x4 = self.transformer.wte(x4)
150
+ x5 = self.transformer.wte(x5)
151
+ x6 = self.transformer.wte(x6)
152
+ x7 = self.transformer.wte(x7)
153
+
154
+ # concat whisper feature
155
+ input_emb = self.concat_whisper_feat(
156
+ x_a, [x0, x1, x2, x3, x4, x5, x6, x7], whisper_lens, task
157
+ )
158
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_emb
159
+
160
+ else:
161
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
162
+
163
+ x0 = self.transformer.wte(x0)
164
+ x1 = self.transformer.wte(x1)
165
+ x2 = self.transformer.wte(x2)
166
+ x3 = self.transformer.wte(x3)
167
+ x4 = self.transformer.wte(x4)
168
+ x5 = self.transformer.wte(x5)
169
+ x6 = self.transformer.wte(x6)
170
+ x7 = self.transformer.wte(x7)
171
+
172
+ x = (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
173
+
174
+ if self.config.scale_embeddings:
175
+ x = x * (self.config.n_embd**0.5)
176
+
177
+ for block in self.transformer.h:
178
+ x = block(x, cos, sin, mask, input_pos)
179
+
180
+
181
+ text_vocab_size = self.config.text_vocab_size
182
+ audio_vocab_size = self.config.audio_vocab_size
183
+
184
+ x_ori = x
185
+ x_ori = self.transformer.ln_f(x_ori)
186
+ x_ori = self.lm_head(x_ori) # (b, t, vocab_size)
187
+ xt = x_ori[..., :text_vocab_size]
188
+
189
+ if self.config.post_adapter:
190
+ for block in self.transformer.post_adapter:
191
+ x = block(x, cos, sin, mask, input_pos)
192
+ x = self.transformer.post_adapter_audio_ln(x)
193
+ x = self.transformer.post_adapter_audio_lm_head(x) # (b, t, vocab_size)
194
+ xa = []
195
+ for i in range(7):
196
+ xa.append(x[..., audio_vocab_size * i : audio_vocab_size * (i + 1)])
197
+ else:
198
+ xa = []
199
+ for i in range(7):
200
+ xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])
201
+
202
+ return xa, xt
203
+
204
+ @classmethod
205
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
206
+ return cls(Config.from_name(name, **kwargs))
207
+
208
+ def rope_cache(
209
+ self, device: Optional[torch.device] = None
210
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
211
+ return build_rope_cache(
212
+ seq_len=self.max_seq_length,
213
+ n_elem=self.config.rope_n_elem,
214
+ device=device,
215
+ condense_ratio=self.config.rope_condense_ratio,
216
+ base=self.config.rope_base,
217
+ )
218
+
219
+ def set_kv_cache(
220
+ self,
221
+ batch_size: int,
222
+ rope_cache_length: Optional[int] = None,
223
+ device: Optional[torch.device] = None,
224
+ dtype: Optional[torch.dtype] = None,
225
+ ) -> None:
226
+ if rope_cache_length is None:
227
+ rope_cache_length = self.cos.size(-1)
228
+ max_seq_length = self.max_seq_length
229
+
230
+ # initialize the kv cache for all blocks
231
+ for block in self.transformer.h:
232
+ block.attn.kv_cache = block.attn.build_kv_cache(
233
+ batch_size, max_seq_length, rope_cache_length, device, dtype
234
+ )
235
+ if self.config.post_adapter:
236
+ for block in self.transformer.post_adapter:
237
+ block.attn.kv_cache = block.attn.build_kv_cache(
238
+ batch_size, max_seq_length, rope_cache_length, device, dtype
239
+ )
240
+
241
+ if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
242
+ # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask
243
+ # for the kv-cache support (only during inference), we only create it in that situation
244
+ self.mask_cache = build_mask_cache(max_seq_length, device)
245
+
246
+ def clear_kv_cache(self) -> None:
247
+ self.mask_cache = None
248
+ for block in self.transformer.h:
249
+ block.attn.kv_cache = None
250
+
251
+
252
+ class Block(nn.Module):
253
+
254
+ def __init__(self, config: Config) -> None:
255
+ super().__init__()
256
+ if not config.parallel_residual and config.shared_attention_norm:
257
+ raise NotImplementedError(
258
+ "No checkpoint amongst the ones we support uses this configuration"
259
+ " (non-parallel residual and shared attention norm)."
260
+ )
261
+
262
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
263
+ self.attn = CausalSelfAttention(config)
264
+ self.norm_2 = (
265
+ None
266
+ if config.shared_attention_norm
267
+ else config.norm_class(config.n_embd, eps=config.norm_eps)
268
+ )
269
+ self.mlp = config.mlp_class(config)
270
+
271
+ self.config = config
272
+
273
+ def forward(
274
+ self,
275
+ x: torch.Tensor,
276
+ cos: torch.Tensor,
277
+ sin: torch.Tensor,
278
+ mask: Optional[torch.Tensor] = None,
279
+ input_pos: Optional[torch.Tensor] = None,
280
+ ) -> torch.Tensor:
281
+ """
282
+ Non-parallel residual Parallel residual
283
+ ┌─ x ┌─ x ────────────┐ Note: if `shared_attention_norm` is True,
284
+ │ ↓ │ ↓ ↓ the output from `norm_1` is reused
285
+ │ norm_1 │ norm_1 ───► norm_2
286
+ │ ↓ │ ↓ ↓
287
+ │ attn │ attn mlp
288
+ │ ↓ │ ↓ │
289
+ ┌─ └► + └► + ◄───────────┘
290
+ │ norm_2
291
+ │ ↓
292
+ │ mlp
293
+ │ ↓
294
+ └───► +
295
+ """
296
+
297
+ x_normed = self.norm_1(x)
298
+ attention_output = self.attn(x_normed, cos, sin, mask, input_pos)
299
+
300
+ if self.config.parallel_residual:
301
+ x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x)
302
+ x = self.mlp(x_normed) + attention_output + x
303
+ else:
304
+ x = attention_output + x
305
+ x = self.mlp(self.norm_2(x)) + x
306
+ return x
307
+
308
+
309
+ class CausalSelfAttention(nn.Module):
310
+ def __init__(self, config: Config) -> None:
311
+ super().__init__()
312
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
313
+ # key, query, value projections for all heads, but in a batch
314
+ self.attn = nn.Linear(config.n_embd, shape, bias=config.add_qkv_bias)
315
+ # output projection
316
+ # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head`
317
+ self.proj = nn.Linear(
318
+ config.head_size * config.n_head, config.n_embd, bias=config.bias
319
+ )
320
+ # disabled by default
321
+ self.kv_cache: Optional[KVCache] = None
322
+
323
+ self.config = config
324
+
325
+ def forward(
326
+ self,
327
+ x: torch.Tensor,
328
+ cos: torch.Tensor,
329
+ sin: torch.Tensor,
330
+ mask: Optional[torch.Tensor] = None,
331
+ input_pos: Optional[torch.Tensor] = None,
332
+ ) -> torch.Tensor:
333
+ B, T, C = (
334
+ x.size()
335
+ ) # batch size, sequence length, embedding dimensionality (n_embd)
336
+
337
+ qkv = self.attn(x)
338
+
339
+ # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
340
+ q_per_kv = self.config.n_head // self.config.n_query_groups
341
+ total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
342
+ qkv = qkv.view(
343
+ B, T, self.config.n_query_groups, total_qkv, self.config.head_size
344
+ )
345
+ qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
346
+
347
+ # split batched computation into three
348
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
349
+
350
+ # maybe repeat k and v if for the non multi-head attention cases
351
+ # training: flash attention requires it
352
+ # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
353
+ if self.config.n_query_groups != self.config.n_head and (
354
+ input_pos is None or self.config.n_query_groups != 1
355
+ ):
356
+ k = k.expand(
357
+ B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
358
+ )
359
+ v = v.expand(
360
+ B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
361
+ )
362
+
363
+ q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
364
+ k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
365
+ v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
366
+
367
+ q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
368
+ k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
369
+ q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
370
+ k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
371
+
372
+ if input_pos is not None:
373
+ if not isinstance(self.kv_cache, KVCache):
374
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
375
+ k, v = self.kv_cache(input_pos, k, v)
376
+
377
+ y = self.scaled_dot_product_attention(q, k, v, mask)
378
+
379
+ y = y.reshape(
380
+ B, T, self.config.head_size * self.config.n_head
381
+ ) # re-assemble all head outputs side by side
382
+
383
+ # output projection
384
+ return self.proj(y)
385
+
386
+ def scaled_dot_product_attention(
387
+ self,
388
+ q: torch.Tensor,
389
+ k: torch.Tensor,
390
+ v: torch.Tensor,
391
+ mask: Optional[torch.Tensor] = None,
392
+ ) -> torch.Tensor:
393
+ scale = 1.0 / math.sqrt(self.config.head_size)
394
+ y = torch.nn.functional.scaled_dot_product_attention(
395
+ q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
396
+ )
397
+ return y.transpose(1, 2)
398
+
399
+ def build_kv_cache(
400
+ self,
401
+ batch_size: int,
402
+ max_seq_length: int,
403
+ rope_cache_length: Optional[int] = None,
404
+ device: Optional[torch.device] = None,
405
+ dtype: Optional[torch.dtype] = None,
406
+ ) -> "KVCache":
407
+ heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
408
+ v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
409
+ if rope_cache_length is None:
410
+ if self.config.rotary_percentage != 1.0:
411
+ raise TypeError(
412
+ "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
413
+ )
414
+ k_shape = v_shape
415
+ else:
416
+ k_shape = (
417
+ batch_size,
418
+ heads,
419
+ max_seq_length,
420
+ rope_cache_length + self.config.head_size - self.config.rope_n_elem,
421
+ )
422
+ return KVCache(k_shape, v_shape, device=device, dtype=dtype)
423
+
424
+
425
+ class GptNeoxMLP(nn.Module):
426
+ def __init__(self, config: Config) -> None:
427
+ super().__init__()
428
+ self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
429
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
430
+
431
+ self.config = config
432
+
433
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
434
+ x = self.fc(x)
435
+ x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
436
+ return self.proj(x)
437
+
438
+
439
+ class LLaMAMLP(nn.Module):
440
+ def __init__(self, config: Config) -> None:
441
+ super().__init__()
442
+ self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
443
+ self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
444
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
445
+
446
+ self.config = config
447
+
448
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
449
+ x_fc_1 = self.fc_1(x)
450
+ x_fc_2 = self.fc_2(x)
451
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
452
+ return self.proj(x)
453
+
454
+
455
+ class whisperMLP(nn.Module):
456
+ def __init__(self, config: Config) -> None:
457
+ super().__init__()
458
+ self.fc_1 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
459
+ self.fc_2 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
460
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
461
+
462
+ self.config = config
463
+
464
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
465
+ x_fc_1 = self.fc_1(x)
466
+ x_fc_2 = self.fc_2(x)
467
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
468
+ return self.proj(x)
469
+
470
+
471
+ class GemmaMLP(LLaMAMLP):
472
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
473
+ x_fc_1 = self.fc_1(x)
474
+ x_fc_2 = self.fc_2(x)
475
+ x = (
476
+ torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate)
477
+ * x_fc_2
478
+ )
479
+ return self.proj(x)
480
+
481
+
482
+ class LLaMAMoE(nn.Module):
483
+ def __init__(self, config: Config) -> None:
484
+ super().__init__()
485
+ self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
486
+ self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
487
+
488
+ self.config = config
489
+
490
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
491
+ """
492
+ Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
493
+ See also figure 1 in https://arxiv.org/abs/2211.15841
494
+ """
495
+ B, T, C = (
496
+ x.size()
497
+ ) # batch size, sequence length, embedding dimensionality (n_embd)
498
+ x = x.view(-1, C) # (B*T, C)
499
+ router = self.gate(x) # (B*T, n_expert)
500
+ probs, indices = torch.topk(
501
+ router, self.config.n_expert_per_token
502
+ ) # (B*T, n_expert_per_token)
503
+ probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
504
+ masks = indices.unsqueeze(-1) == torch.arange(
505
+ self.config.n_expert, device=x.device
506
+ )
507
+ masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token)
508
+ y = torch.zeros_like(x) # (B*T, C)
509
+ for mask, expert in zip(masks, self.experts):
510
+ token_idx, expert_idx = torch.where(mask)
511
+ y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
512
+ return y.view(B, T, C)
513
+
514
+
515
+ def build_rope_cache(
516
+ seq_len: int,
517
+ n_elem: int,
518
+ device: Optional[torch.device] = None,
519
+ base: int = 10000,
520
+ condense_ratio: int = 1,
521
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
522
+ """Enhanced Transformer with Rotary Position Embedding.
523
+
524
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
525
+ transformers/rope/__init__.py. MIT License:
526
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
527
+ """
528
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
529
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
530
+
531
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
532
+ seq_idx = torch.arange(seq_len, device=device) / condense_ratio
533
+
534
+ # Calculate the product of position index and $\theta_i$
535
+ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
536
+
537
+ return torch.cos(idx_theta), torch.sin(idx_theta)
538
+
539
+
540
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
541
+ head_size = x.size(-1)
542
+ x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
543
+ x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
544
+ rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
545
+ roped = (x * cos) + (rotated * sin)
546
+ return roped.to(dtype=x.dtype)
547
+
548
+
549
+ class KVCache(nn.Module):
550
+ def __init__(
551
+ self,
552
+ k_shape: Tuple[int, int, int, int],
553
+ v_shape: Tuple[int, int, int, int],
554
+ device: Optional[torch.device] = None,
555
+ dtype: Optional[torch.dtype] = None,
556
+ ) -> None:
557
+ super().__init__()
558
+ self.register_buffer(
559
+ "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
560
+ )
561
+ self.register_buffer(
562
+ "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
563
+ )
564
+
565
+ def forward(
566
+ self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
567
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
568
+ # move the buffer to the activation dtype for when AMP is used
569
+ self.k = self.k.to(k.dtype)
570
+ self.v = self.v.to(v.dtype)
571
+ # update the cache
572
+ k = self.k.index_copy_(2, input_pos, k)
573
+ v = self.v.index_copy_(2, input_pos, v)
574
+ return k, v
575
+
576
+ def reset_parameters(self) -> None:
577
+ torch.nn.init.zeros_(self.k)
578
+ torch.nn.init.zeros_(self.v)
579
+
580
+
581
+ def build_mask_cache(
582
+ max_seq_length: int, device: Optional[torch.device] = None
583
+ ) -> torch.Tensor:
584
+ ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
585
+ return torch.tril(ones).unsqueeze(0).unsqueeze(0)
586
+
587
+
588
+ class RMSNorm(torch.nn.Module):
589
+ """Root Mean Square Layer Normalization.
590
+
591
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
592
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
593
+ """
594
+
595
+ def __init__(
596
+ self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False
597
+ ) -> None:
598
+ super().__init__()
599
+ self.weight = torch.nn.Parameter(torch.ones(size))
600
+ self.eps = eps
601
+ self.dim = dim
602
+ self.add_unit_offset = add_unit_offset
603
+
604
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
605
+ dtype = x.dtype
606
+ x = x.float()
607
+ # NOTE: the original RMSNorm paper implementation is not equivalent
608
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
609
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
610
+ x_normed = x_normed.to(dtype=dtype)
611
+ if self.add_unit_offset:
612
+ # Gemma model requires a unit offset
613
+ # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176
614
+ return x_normed * (1 + self.weight)
615
+ return x_normed * self.weight
616
+
617
+ def reset_parameters(self) -> None:
618
+ torch.nn.init.ones_(self.weight)
litgpt/tokenizer.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Optional, Union
6
+
7
+ import torch
8
+
9
+
10
+ class Tokenizer:
11
+ def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
12
+ checkpoint_dir = Path(checkpoint_dir)
13
+ if not checkpoint_dir.exists():
14
+ raise NotADirectoryError(
15
+ f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
16
+ )
17
+
18
+ self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
19
+ self.bos_id = None
20
+ self.eos_id = None
21
+
22
+ # some checkpoints have both files, `.json` takes precedence
23
+ if (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
24
+ from tokenizers import Tokenizer as HFTokenizer
25
+
26
+ self.processor = HFTokenizer.from_file(str(vocabulary_path))
27
+ self.backend = "huggingface"
28
+
29
+ if (
30
+ special_tokens_path := checkpoint_dir / "tokenizer_config.json"
31
+ ).is_file():
32
+ with open(special_tokens_path, encoding="utf-8") as fp:
33
+ config = json.load(fp)
34
+ bos_token = config.get("bos_token")
35
+ eos_token = config.get("eos_token")
36
+ if bos_token is not None and isinstance(bos_token, dict):
37
+ bos_token = bos_token.get("content")
38
+ if eos_token is not None and isinstance(eos_token, dict):
39
+ eos_token = eos_token.get("content")
40
+ self.bos_id = (
41
+ self.token_to_id(bos_token) if bos_token is not None else None
42
+ )
43
+ self.eos_id = (
44
+ self.token_to_id(eos_token) if eos_token is not None else None
45
+ )
46
+ if (
47
+ special_tokens_path := checkpoint_dir / "generation_config.json"
48
+ ).is_file():
49
+ with open(special_tokens_path, encoding="utf-8") as fp:
50
+ config = json.load(fp)
51
+ if self.bos_id is None:
52
+ self.bos_id = config.get("bos_token_id")
53
+ if self.eos_id is None:
54
+ self.eos_id = config.get("eos_token_id")
55
+
56
+ elif (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
57
+ from sentencepiece import SentencePieceProcessor
58
+
59
+ self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
60
+ self.backend = "sentencepiece"
61
+ self.bos_id = self.processor.bos_id()
62
+ self.eos_id = self.processor.eos_id()
63
+ else:
64
+ raise NotImplementedError
65
+
66
+ @property
67
+ def vocab_size(self) -> int:
68
+ if self.backend == "huggingface":
69
+ return self.processor.get_vocab_size(with_added_tokens=False)
70
+ if self.backend == "sentencepiece":
71
+ return self.processor.vocab_size()
72
+ raise RuntimeError
73
+
74
+ def token_to_id(self, token: str) -> int:
75
+ if self.backend == "huggingface":
76
+ id_ = self.processor.token_to_id(token)
77
+ elif self.backend == "sentencepiece":
78
+ id_ = self.processor.piece_to_id(token)
79
+ else:
80
+ raise RuntimeError
81
+ if id_ is None:
82
+ raise ValueError(f"token {token!r} not found in the collection.")
83
+ return id_
84
+
85
+ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
86
+ if not (
87
+ tokenizer_config_path := checkpoint_dir / "tokenizer_config.json"
88
+ ).is_file():
89
+ return False
90
+ with open(tokenizer_config_path, encoding="utf-8") as fp:
91
+ config = json.load(fp)
92
+ if "add_bos_token" in config:
93
+ return config["add_bos_token"]
94
+ # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
95
+ # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
96
+ return config.get("tokenizer_class") == "LlamaTokenizer"
97
+
98
+ def encode(
99
+ self,
100
+ string: str,
101
+ device: Optional[torch.device] = None,
102
+ bos: Optional[bool] = None,
103
+ eos: bool = False,
104
+ max_length: int = -1,
105
+ ) -> torch.Tensor:
106
+ if self.backend == "huggingface":
107
+ tokens = self.processor.encode(string).ids
108
+ elif self.backend == "sentencepiece":
109
+ tokens = self.processor.encode(string)
110
+ else:
111
+ raise RuntimeError
112
+ if bos or (bos is None and self.use_bos):
113
+ bos_id = self.bos_id
114
+ if bos_id is None:
115
+ raise NotImplementedError(
116
+ "This tokenizer does not have a defined a bos token"
117
+ )
118
+ if tokens[0] != bos_id:
119
+ tokens = [bos_id] + tokens
120
+ if tokens is None:
121
+ raise ValueError("`tokens` is None")
122
+
123
+ if eos and (not tokens or tokens[-1] != self.eos_id):
124
+ tokens = tokens + [self.eos_id]
125
+ if max_length > 0:
126
+ tokens = tokens[:max_length]
127
+ return torch.tensor(tokens, dtype=torch.int, device=device)
128
+
129
+ def decode(self, tensor: torch.Tensor) -> str:
130
+ tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
131
+ return self.processor.decode(tokens)
litgpt/utils.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Utility functions for training and inference."""
4
+ import inspect
5
+ import math
6
+ import os
7
+ import pickle
8
+ import shutil
9
+ import sys
10
+ from dataclasses import asdict, is_dataclass
11
+ from io import BytesIO
12
+ from pathlib import Path
13
+ from typing import (
14
+ TYPE_CHECKING,
15
+ Any,
16
+ Dict,
17
+ Iterable,
18
+ List,
19
+ Literal,
20
+ Mapping,
21
+ Optional,
22
+ TypeVar,
23
+ Union,
24
+ )
25
+
26
+ import lightning as L
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.utils._device
30
+ import yaml
31
+ from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
32
+ from lightning.fabric.strategies import FSDPStrategy
33
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
34
+ from lightning.pytorch.loggers import WandbLogger
35
+ from lightning.pytorch.cli import instantiate_class
36
+ from torch.serialization import normalize_storage_type
37
+ from typing_extensions import Self
38
+
39
+ if TYPE_CHECKING:
40
+ from litgpt import GPT, Config
41
+
42
+
43
+ def init_out_dir(out_dir: Path) -> Path:
44
+ if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ:
45
+ return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir
46
+ return out_dir
47
+
48
+
49
+ def find_resume_path(
50
+ resume: Union[bool, Literal["auto"], Path], out_dir: Path
51
+ ) -> Optional[Path]:
52
+ if not resume or isinstance(resume, Path):
53
+ return resume
54
+
55
+ resume_path = max(
56
+ out_dir.rglob("step-*/*.pth"),
57
+ key=(lambda p: int(p.parent.name.split("-")[1])),
58
+ default=None,
59
+ )
60
+ if resume == "auto":
61
+ return resume_path
62
+ if resume is True and resume_path is None:
63
+ raise FileNotFoundError(
64
+ f"You passed `--resume=True`, but no checkpont file was found in `--out_dir={out_dir}`."
65
+ )
66
+ return resume_path
67
+
68
+
69
+ def find_multiple(n: int, k: int) -> int:
70
+ assert k > 0
71
+ if n % k == 0:
72
+ return n
73
+ return n + k - (n % k)
74
+
75
+
76
+ def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
77
+ total = 0
78
+ for p in module.parameters():
79
+ if requires_grad is None or p.requires_grad == requires_grad:
80
+ if hasattr(p, "quant_state"):
81
+ # bitsandbytes 4bit layer support
82
+ total += math.prod(p.quant_state.shape)
83
+ else:
84
+ total += p.numel()
85
+ return total
86
+
87
+
88
+ def reset_parameters(module: nn.Module) -> None:
89
+ """Calls `reset_parameters` on the module and all its submodules."""
90
+ for mod in module.modules():
91
+ if callable(getattr(mod, "reset_parameters", None)):
92
+ mod.reset_parameters()
93
+
94
+
95
+ def check_valid_checkpoint_dir(
96
+ checkpoint_dir: Path,
97
+ model_filename: str = "lit_model.pth",
98
+ verbose: bool = True,
99
+ raise_error: bool = False,
100
+ ) -> None:
101
+ files = {
102
+ model_filename: (checkpoint_dir / model_filename).is_file(),
103
+ "model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(),
104
+ "tokenizer.json OR tokenizer.model": (
105
+ checkpoint_dir / "tokenizer.json"
106
+ ).is_file()
107
+ or (checkpoint_dir / "tokenizer.model").is_file(),
108
+ "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
109
+ }
110
+ if checkpoint_dir.is_dir():
111
+ if all(files.values()):
112
+ # we're good
113
+ return
114
+ problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
115
+ else:
116
+ problem = " is not a checkpoint directory"
117
+
118
+ # list locally available checkpoints
119
+ available = list(Path("checkpoints").glob("*/*"))
120
+ if available:
121
+ options = "\n".join([""] + [repr(str(p.resolve())) for p in available])
122
+ extra = f"\nYou have downloaded locally:{options}\n"
123
+ else:
124
+ extra = ""
125
+
126
+ if verbose:
127
+ error_message = (
128
+ f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
129
+ "\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n"
130
+ f"{extra}\nSee all download options by running:\n litgpt download"
131
+ )
132
+ print(error_message, file=sys.stderr)
133
+
134
+ if raise_error:
135
+ raise FileNotFoundError(
136
+ f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
137
+ )
138
+ else:
139
+ raise SystemExit(1)
140
+
141
+
142
+ class SavingProxyForStorage:
143
+ def __init__(self, obj, saver, protocol_version=5):
144
+ self.protocol_version = protocol_version
145
+ self.saver = saver
146
+ if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
147
+ raise TypeError(f"expected storage, not {type(obj)}")
148
+
149
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
150
+ if isinstance(obj, torch.storage.TypedStorage):
151
+ # PT upstream wants to deprecate this eventually...
152
+ storage = obj._untyped_storage
153
+ storage_type_str = obj._pickle_storage_type()
154
+ storage_type = getattr(torch, storage_type_str)
155
+ storage_numel = obj._size()
156
+ else:
157
+ storage = obj
158
+ storage_type = normalize_storage_type(type(obj))
159
+ storage_numel = storage.nbytes()
160
+
161
+ storage_key = saver._write_storage_and_return_key(storage)
162
+ location = torch.serialization.location_tag(storage)
163
+
164
+ self.storage_info = (
165
+ "storage",
166
+ storage_type,
167
+ storage_key,
168
+ location,
169
+ storage_numel,
170
+ )
171
+
172
+ def __reduce_ex__(self, protocol_version):
173
+ assert False, "this should be handled with out of band"
174
+
175
+
176
+ class SavingProxyForTensor:
177
+ def __init__(self, tensor, saver, protocol_version=5):
178
+ self.protocol_version = protocol_version
179
+ self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
180
+ if reduce_args[0] == torch._utils._rebuild_tensor_v2:
181
+ # for Tensors with Python attributes
182
+ (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
183
+ assert isinstance(
184
+ storage, torch.storage.TypedStorage
185
+ ), "Please check for updates"
186
+ storage_proxy = SavingProxyForStorage(
187
+ storage, saver, protocol_version=protocol_version
188
+ )
189
+ self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
190
+ else:
191
+ (storage, *other_reduce_args) = reduce_args
192
+ assert isinstance(
193
+ storage, torch.storage.TypedStorage
194
+ ), "Please check for updates"
195
+ storage_proxy = SavingProxyForStorage(
196
+ storage, saver, protocol_version=protocol_version
197
+ )
198
+ self.reduce_args = (storage_proxy, *other_reduce_args)
199
+
200
+ def __reduce_ex__(self, protocol_version):
201
+ if protocol_version != self.protocol_version:
202
+ raise RuntimeError(
203
+ f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
204
+ )
205
+ return self.reduce_ret_fn, self.reduce_args
206
+
207
+
208
+ class IncrementalPyTorchPickler(pickle.Pickler):
209
+ def __init__(self, saver, *args, **kwargs):
210
+ super().__init__(*args, **kwargs)
211
+ self.storage_dtypes = {}
212
+ self.saver = saver
213
+ self.id_map = {}
214
+
215
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
216
+ def persistent_id(self, obj):
217
+ # FIXME: the docs say that persistent_id should only return a string
218
+ # but torch store returns tuples. This works only in the binary protocol
219
+ # see
220
+ # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
221
+ # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
222
+ if isinstance(obj, SavingProxyForStorage):
223
+ return obj.storage_info
224
+
225
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
226
+ if isinstance(obj, torch.storage.TypedStorage):
227
+ # TODO: Once we decide to break serialization FC, this case
228
+ # can be deleted
229
+ storage = obj._untyped_storage
230
+ storage_dtype = obj.dtype
231
+ storage_type_str = obj._pickle_storage_type()
232
+ storage_type = getattr(torch, storage_type_str)
233
+ storage_numel = obj._size()
234
+
235
+ else:
236
+ storage = obj
237
+ storage_dtype = torch.uint8
238
+ storage_type = normalize_storage_type(type(obj))
239
+ storage_numel = storage.nbytes()
240
+
241
+ # If storage is allocated, ensure that any other saved storages
242
+ # pointing to the same data all have the same dtype. If storage is
243
+ # not allocated, don't perform this check
244
+ if storage.data_ptr() != 0:
245
+ if storage.data_ptr() in self.storage_dtypes:
246
+ if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
247
+ raise RuntimeError(
248
+ "Cannot save multiple tensors or storages that view the same data as different types"
249
+ )
250
+ else:
251
+ self.storage_dtypes[storage.data_ptr()] = storage_dtype
252
+
253
+ storage_key = self.id_map.get(storage._cdata)
254
+ if storage_key is None:
255
+ storage_key = self.saver._write_storage_and_return_key(storage)
256
+ self.id_map[storage._cdata] = storage_key
257
+ location = torch.serialization.location_tag(storage)
258
+
259
+ return ("storage", storage_type, storage_key, location, storage_numel)
260
+
261
+ return None
262
+
263
+
264
+ class incremental_save:
265
+ def __init__(self, name):
266
+ self.name = name
267
+ self.zipfile = torch._C.PyTorchFileWriter(str(name))
268
+ self.has_saved = False
269
+ self.next_key = 0
270
+
271
+ def __enter__(self):
272
+ return self
273
+
274
+ def store_early(self, tensor):
275
+ if isinstance(tensor, torch.Tensor):
276
+ return SavingProxyForTensor(tensor, self)
277
+ raise TypeError(f"can only store tensors early, not {type(tensor)}")
278
+
279
+ def save(self, obj):
280
+ if self.has_saved:
281
+ raise RuntimeError("have already saved")
282
+ # Write the pickle data for `obj`
283
+ data_buf = BytesIO()
284
+ pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
285
+ pickler.dump(obj)
286
+ data_value = data_buf.getvalue()
287
+ self.zipfile.write_record("data.pkl", data_value, len(data_value))
288
+ self.has_saved = True
289
+
290
+ def _write_storage_and_return_key(self, storage):
291
+ if self.has_saved:
292
+ raise RuntimeError("have already saved")
293
+ key = self.next_key
294
+ self.next_key += 1
295
+ name = f"data/{key}"
296
+ if storage.device.type != "cpu":
297
+ storage = storage.cpu()
298
+ num_bytes = storage.nbytes()
299
+ self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
300
+ return key
301
+
302
+ def __exit__(self, type, value, traceback):
303
+ self.zipfile.write_end_of_file()
304
+
305
+
306
+ T = TypeVar("T")
307
+
308
+
309
+ def chunked_cross_entropy(
310
+ logits: Union[torch.Tensor, List[torch.Tensor]],
311
+ targets: torch.Tensor,
312
+ chunk_size: int = 128,
313
+ ignore_index: int = -100,
314
+ ) -> torch.Tensor:
315
+ # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
316
+ # the memory usage in fine-tuning settings with low number of parameters.
317
+ # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
318
+ # the memory spike's magnitude
319
+
320
+ # lm_head was chunked (we are fine-tuning)
321
+ if isinstance(logits, list):
322
+ # don't want to chunk cross entropy
323
+ if chunk_size == 0:
324
+ logits = torch.cat(logits, dim=1)
325
+ logits = logits.reshape(-1, logits.size(-1))
326
+ targets = targets.reshape(-1)
327
+ return torch.nn.functional.cross_entropy(
328
+ logits, targets, ignore_index=ignore_index
329
+ )
330
+
331
+ # chunk cross entropy
332
+ logit_chunks = [
333
+ logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
334
+ ]
335
+ target_chunks = [
336
+ target_chunk.reshape(-1)
337
+ for target_chunk in targets.split(logits[0].size(1), dim=1)
338
+ ]
339
+ loss_chunks = [
340
+ torch.nn.functional.cross_entropy(
341
+ logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
342
+ )
343
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
344
+ ]
345
+ non_masked_elems = (targets != ignore_index).sum()
346
+ # See [non_masked_elems div note]
347
+ return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
348
+ torch.ones_like(non_masked_elems)
349
+ )
350
+
351
+ # no chunking at all
352
+ logits = logits.reshape(-1, logits.size(-1))
353
+ targets = targets.reshape(-1)
354
+ if chunk_size == 0:
355
+ return torch.nn.functional.cross_entropy(
356
+ logits, targets, ignore_index=ignore_index
357
+ )
358
+
359
+ # lm_head wasn't chunked, chunk cross entropy
360
+ logit_chunks = logits.split(chunk_size)
361
+ target_chunks = targets.split(chunk_size)
362
+ loss_chunks = [
363
+ torch.nn.functional.cross_entropy(
364
+ logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
365
+ )
366
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
367
+ ]
368
+ non_masked_elems = (targets != ignore_index).sum()
369
+ # [non_masked_elems div note]:
370
+ # max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that
371
+ # results in a python int which is then passed back to torch division. By using the
372
+ # `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.
373
+ return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
374
+ torch.ones_like(non_masked_elems)
375
+ )
376
+
377
+
378
+ def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
379
+ for checkpoint_name, attribute_name in mapping.items():
380
+ full_checkpoint_name = prefix + checkpoint_name
381
+ if full_checkpoint_name in state_dict:
382
+ full_attribute_name = prefix + attribute_name
383
+ state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
384
+ return state_dict
385
+
386
+
387
+ def get_default_supported_precision(training: bool) -> str:
388
+ """Return default precision that is supported by the hardware: either `bf16` or `16`.
389
+
390
+ Args:
391
+ training: `-mixed` or `-true` version of the precision to use
392
+
393
+ Returns:
394
+ default precision that is suitable for the task and is supported by the hardware
395
+ """
396
+ from lightning.fabric.accelerators import MPSAccelerator
397
+
398
+ if MPSAccelerator.is_available() or (
399
+ torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
400
+ ):
401
+ return "16-mixed" if training else "16-true"
402
+ return "bf16-mixed" if training else "bf16-true"
403
+
404
+
405
+ def load_checkpoint(
406
+ fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
407
+ ) -> None:
408
+ if isinstance(fabric.strategy, FSDPStrategy):
409
+ fabric.load_raw(checkpoint_path, model, strict=strict)
410
+ else:
411
+ state_dict = lazy_load(checkpoint_path)
412
+ state_dict = state_dict.get("model", state_dict)
413
+ model.load_state_dict(state_dict, strict=strict)
414
+
415
+
416
+ def flops_per_param(
417
+ max_seq_length: int, n_layer: int, n_embd: int, n_params: int
418
+ ) -> int:
419
+ flops_per_token = (
420
+ 2 * n_params
421
+ ) # each parameter is used for a MAC (2 FLOPS) per network operation
422
+ # this assumes that all samples have a fixed length equal to the block size
423
+ # which is most likely false during finetuning
424
+ flops_per_seq = flops_per_token * max_seq_length
425
+ attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
426
+ return flops_per_seq + attn_flops_per_seq
427
+
428
+
429
+ def estimate_flops(model: "GPT", training: bool) -> int:
430
+ """Measures estimated FLOPs for MFU.
431
+
432
+ Refs:
433
+ * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
434
+ * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
435
+ """
436
+ # using all parameters for this is a naive over estimation because not all model parameters actually contribute to
437
+ # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
438
+ # (~10%) compared to the measured FLOPs, making those lower but more realistic.
439
+ # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
440
+ n_trainable_params = num_parameters(model, requires_grad=True)
441
+ trainable_flops = flops_per_param(
442
+ model.max_seq_length,
443
+ model.config.n_layer,
444
+ model.config.n_embd,
445
+ n_trainable_params,
446
+ )
447
+ # forward + backward + gradients (assumes no gradient accumulation)
448
+ ops_per_step = 3 if training else 1
449
+ n_frozen_params = num_parameters(model, requires_grad=False)
450
+ frozen_flops = flops_per_param(
451
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
452
+ )
453
+ # forward + backward
454
+ frozen_ops_per_step = 2 if training else 1
455
+ return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
456
+
457
+
458
+ class CycleIterator:
459
+ """An iterator that cycles through an iterable indefinitely.
460
+
461
+ Example:
462
+ >>> iterator = CycleIterator([1, 2, 3])
463
+ >>> [next(iterator) for _ in range(5)]
464
+ [1, 2, 3, 1, 2]
465
+
466
+ Note:
467
+ Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable.
468
+ """
469
+
470
+ def __init__(self, iterable: Iterable) -> None:
471
+ self.iterable = iterable
472
+ self.epoch = 0
473
+ self._iterator = None
474
+
475
+ def __next__(self) -> Any:
476
+ if self._iterator is None:
477
+ self._iterator = iter(self.iterable)
478
+ try:
479
+ return next(self._iterator)
480
+ except StopIteration:
481
+ self._iterator = iter(self.iterable)
482
+ self.epoch += 1
483
+ return next(self._iterator)
484
+
485
+ def __iter__(self) -> Self:
486
+ return self
487
+
488
+
489
+ def copy_config_files(source_dir: Path, out_dir: Path) -> None:
490
+ """Copies the specified configuration and tokenizer files into the output directory."""
491
+
492
+ config_files = ["config.json", "generation_config.json", "model_config.yaml"]
493
+ tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"]
494
+
495
+ for file_name in config_files + tokenizer_files:
496
+ src_path = source_dir / file_name
497
+ if src_path.exists():
498
+ shutil.copy(src_path, out_dir)
499
+
500
+
501
+ def CLI(*args: Any, **kwargs: Any) -> Any:
502
+ from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options
503
+
504
+ set_docstring_parse_options(attribute_docstrings=True)
505
+ set_config_read_mode(urls_enabled=True)
506
+
507
+ return CLI(*args, **kwargs)
508
+
509
+
510
+ def capture_hparams() -> Dict[str, Any]:
511
+ """Captures the local variables ('hyperparameters') from where this function gets called."""
512
+ caller_frame = inspect.currentframe().f_back
513
+ locals_of_caller = caller_frame.f_locals
514
+ hparams = {}
515
+ for name, value in locals_of_caller.items():
516
+ if value is None or isinstance(value, (int, float, str, bool, Path)):
517
+ hparams[name] = value
518
+ elif is_dataclass(value):
519
+ hparams[name] = asdict(value)
520
+ else:
521
+ hparams[name] = str(value)
522
+ return hparams
523
+
524
+
525
+ def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
526
+ """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
527
+ from jsonargparse import capture_parser
528
+
529
+ # TODO: Make this more robust
530
+ # This hack strips away the subcommands from the top-level CLI
531
+ # to parse the file as if it was called as a script
532
+ known_commands = [
533
+ ("finetune_full",), # For subcommands, use `("finetune", "full")` etc
534
+ ("finetune_lora",),
535
+ ("finetune_adapter",),
536
+ ("finetune_adapter_v2",),
537
+ ("finetune",),
538
+ ("pretrain",),
539
+ ]
540
+ for known_command in known_commands:
541
+ unwanted = slice(1, 1 + len(known_command))
542
+ if tuple(sys.argv[unwanted]) == known_command:
543
+ sys.argv[unwanted] = []
544
+
545
+ parser = capture_parser(lambda: CLI(function))
546
+ config = parser.parse_args()
547
+ parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
548
+
549
+
550
+ def save_config(config: "Config", checkpoint_dir: Path) -> None:
551
+ config_dict = asdict(config)
552
+ with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp:
553
+ yaml.dump(config_dict, fp)
554
+
555
+
556
+ def parse_devices(devices: Union[str, int]) -> int:
557
+ if devices in (-1, "auto"):
558
+ return torch.cuda.device_count() or 1
559
+ if isinstance(devices, int) and devices > 0:
560
+ return devices
561
+ raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}")
562
+
563
+
564
+ def choose_logger(
565
+ logger_name: Literal["csv", "tensorboard", "wandb"],
566
+ out_dir: Path,
567
+ name: str,
568
+ log_interval: int = 1,
569
+ resume: Optional[bool] = None,
570
+ **kwargs: Any,
571
+ ):
572
+ if logger_name == "csv":
573
+ return CSVLogger(
574
+ root_dir=(out_dir / "logs"),
575
+ name="csv",
576
+ flush_logs_every_n_steps=log_interval,
577
+ **kwargs,
578
+ )
579
+ if logger_name == "tensorboard":
580
+ return TensorBoardLogger(
581
+ root_dir=(out_dir / "logs"), name="tensorboard", **kwargs
582
+ )
583
+ if logger_name == "wandb":
584
+ return WandbLogger(project=name, resume=resume, **kwargs)
585
+ raise ValueError(
586
+ f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'."
587
+ )
588
+
589
+
590
+ def get_argument_names(cls):
591
+ sig = inspect.signature(cls.__init__)
592
+ return {
593
+ name
594
+ for name, param in sig.parameters.items()
595
+ if param.kind
596
+ in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]
597
+ }
598
+
599
+
600
+ def instantiate_bnb_optimizer(optimizer, model_parameters):
601
+ if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (
602
+ isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")
603
+ ):
604
+ raise ValueError(
605
+ "The chosen quantization format only supports the AdamW optimizer."
606
+ )
607
+
608
+ import bitsandbytes as bnb
609
+
610
+ if isinstance(optimizer, str):
611
+ optimizer = bnb.optim.PagedAdamW(model_parameters)
612
+ else:
613
+ optim_args = get_argument_names(bnb.optim.PagedAdamW)
614
+ allowed_kwargs = {
615
+ key: optimizer["init_args"][key]
616
+ for key in optim_args & optimizer["init_args"].keys()
617
+ }
618
+ optimizer = bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs)
619
+ return optimizer
620
+
621
+
622
+ def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
623
+ if isinstance(optimizer, str):
624
+ optimizer_cls = getattr(torch.optim, optimizer)
625
+ optimizer = optimizer_cls(model_parameters, **kwargs)
626
+ else:
627
+ optimizer = dict(optimizer) # copy
628
+ optimizer["init_args"].update(kwargs)
629
+ optimizer = instantiate_class(model_parameters, optimizer)
630
+ return optimizer
631
+
632
+
633
+ def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
634
+ new_checkpoint_dir = "checkpoints" / checkpoint_dir
635
+ should_return_new_dir = (
636
+ not checkpoint_dir.is_dir()
637
+ and checkpoint_dir.parts[0] != "checkpoints"
638
+ and not checkpoint_dir.is_absolute()
639
+ and new_checkpoint_dir.exists()
640
+ )
641
+ return new_checkpoint_dir if should_return_new_dir else checkpoint_dir
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.3.1
2
+ torchvision==0.18.1
3
+ torchaudio==2.3.1
4
+ litgpt==0.4.3
5
+ snac==1.2.0
6
+ soundfile==0.12.1
7
+ openai-whisper
8
+ tokenizers==0.19.1
9
+ streamlit==1.37.1
10
+ # PyAudio==0.2.14
11
+ pydub==0.25.1
12
+ onnxruntime==1.19.0
13
+ # numpy==1.26.3
14
+ gradio==4.42.0
15
+ librosa==0.10.2.post1
16
+ flask==3.0.3
17
+ fire
utils/snac_utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import time
3
+ import numpy as np
4
+
5
+
6
+ class SnacConfig:
7
+ audio_vocab_size = 4096
8
+ padded_vocab_size = 4160
9
+ end_of_audio = 4097
10
+
11
+
12
+ snac_config = SnacConfig()
13
+
14
+
15
+ def get_time_str():
16
+ time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
17
+ return time_str
18
+
19
+
20
+ def layershift(input_id, layer, stride=4160, shift=152000):
21
+ return input_id + shift + layer * stride
22
+
23
+
24
+ def generate_audio_data(snac_tokens, snacmodel):
25
+ audio = reconstruct_tensors(snac_tokens)
26
+ with torch.inference_mode():
27
+ audio_hat = snacmodel.decode(audio)
28
+ audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
29
+ audio_data = audio_data.astype(np.int16)
30
+ audio_data = audio_data.tobytes()
31
+ return audio_data
32
+
33
+
34
+ def get_snac(list_output, index, nums_generate):
35
+
36
+ snac = []
37
+ start = index
38
+ for i in range(nums_generate):
39
+ snac.append("#")
40
+ for j in range(7):
41
+ snac.append(list_output[j][start - nums_generate - 5 + j + i])
42
+ return snac
43
+
44
+
45
+ def reconscruct_snac(output_list):
46
+ if len(output_list) == 8:
47
+ output_list = output_list[:-1]
48
+ output = []
49
+ for i in range(7):
50
+ output_list[i] = output_list[i][i + 1 :]
51
+ for i in range(len(output_list[-1])):
52
+ output.append("#")
53
+ for j in range(7):
54
+ output.append(output_list[j][i])
55
+ return output
56
+
57
+
58
+ def reconstruct_tensors(flattened_output):
59
+ """Reconstructs the list of tensors from the flattened output."""
60
+
61
+ def count_elements_between_hashes(lst):
62
+ try:
63
+ # Find the index of the first '#'
64
+ first_index = lst.index("#")
65
+ # Find the index of the second '#' after the first
66
+ second_index = lst.index("#", first_index + 1)
67
+ # Count the elements between the two indices
68
+ return second_index - first_index - 1
69
+ except ValueError:
70
+ # Handle the case where there aren't enough '#' symbols
71
+ return "List does not contain two '#' symbols"
72
+
73
+ def remove_elements_before_hash(flattened_list):
74
+ try:
75
+ # Find the index of the first '#'
76
+ first_hash_index = flattened_list.index("#")
77
+ # Return the list starting from the first '#'
78
+ return flattened_list[first_hash_index:]
79
+ except ValueError:
80
+ # Handle the case where there is no '#'
81
+ return "List does not contain the symbol '#'"
82
+
83
+ def list_to_torch_tensor(tensor1):
84
+ # Convert the list to a torch tensor
85
+ tensor = torch.tensor(tensor1)
86
+ # Reshape the tensor to have size (1, n)
87
+ tensor = tensor.unsqueeze(0)
88
+ return tensor
89
+
90
+ flattened_output = remove_elements_before_hash(flattened_output)
91
+ codes = []
92
+ tensor1 = []
93
+ tensor2 = []
94
+ tensor3 = []
95
+ tensor4 = []
96
+
97
+ n_tensors = count_elements_between_hashes(flattened_output)
98
+ if n_tensors == 7:
99
+ for i in range(0, len(flattened_output), 8):
100
+
101
+ tensor1.append(flattened_output[i + 1])
102
+ tensor2.append(flattened_output[i + 2])
103
+ tensor3.append(flattened_output[i + 3])
104
+ tensor3.append(flattened_output[i + 4])
105
+
106
+ tensor2.append(flattened_output[i + 5])
107
+ tensor3.append(flattened_output[i + 6])
108
+ tensor3.append(flattened_output[i + 7])
109
+ codes = [
110
+ list_to_torch_tensor(tensor1).cuda(),
111
+ list_to_torch_tensor(tensor2).cuda(),
112
+ list_to_torch_tensor(tensor3).cuda(),
113
+ ]
114
+
115
+ if n_tensors == 15:
116
+ for i in range(0, len(flattened_output), 16):
117
+
118
+ tensor1.append(flattened_output[i + 1])
119
+ tensor2.append(flattened_output[i + 2])
120
+ tensor3.append(flattened_output[i + 3])
121
+ tensor4.append(flattened_output[i + 4])
122
+ tensor4.append(flattened_output[i + 5])
123
+ tensor3.append(flattened_output[i + 6])
124
+ tensor4.append(flattened_output[i + 7])
125
+ tensor4.append(flattened_output[i + 8])
126
+
127
+ tensor2.append(flattened_output[i + 9])
128
+ tensor3.append(flattened_output[i + 10])
129
+ tensor4.append(flattened_output[i + 11])
130
+ tensor4.append(flattened_output[i + 12])
131
+ tensor3.append(flattened_output[i + 13])
132
+ tensor4.append(flattened_output[i + 14])
133
+ tensor4.append(flattened_output[i + 15])
134
+
135
+ codes = [
136
+ list_to_torch_tensor(tensor1).cuda(),
137
+ list_to_torch_tensor(tensor2).cuda(),
138
+ list_to_torch_tensor(tensor3).cuda(),
139
+ list_to_torch_tensor(tensor4).cuda(),
140
+ ]
141
+
142
+ return codes
143
+