lxysl committed
Commit bc752b1 · 1 Parent(s): 76b1e92

upload vita-1.5 app.py

This view is limited to 50 files because it contains too many changes. See the raw diff for the rest.
Files changed (50)
  1. app.py +479 -8
  2. vita/config/__init__.py +10 -0
  3. vita/config/dataset_config.py +8 -0
  4. vita/constants.py +14 -0
  5. vita/conversation.py +401 -0
  6. vita/model/__init__.py +5 -0
  7. vita/model/builder.py +287 -0
  8. vita/model/language_model/vita_fo_qwen2.py +227 -0
  9. vita/model/language_model/vita_mixtral.py +420 -0
  10. vita/model/language_model/vita_nemo.py +282 -0
  11. vita/model/language_model/vita_qwen2.py +304 -0
  12. vita/model/multimodal_encoder/builder.py +83 -0
  13. vita/model/multimodal_encoder/clip/clip_encoder.py +78 -0
  14. vita/model/multimodal_encoder/eva_clip/eva_clip_encoder.py +66 -0
  15. vita/model/multimodal_encoder/eva_clip/eva_clip_processors.py +69 -0
  16. vita/model/multimodal_encoder/eva_clip/eva_vit.py +982 -0
  17. vita/model/multimodal_encoder/internvit/configuration_intern_vit.py +125 -0
  18. vita/model/multimodal_encoder/internvit/flash_attention.py +101 -0
  19. vita/model/multimodal_encoder/internvit/internvit_encoder.py +105 -0
  20. vita/model/multimodal_encoder/internvit/modeling_intern_vit.py +394 -0
  21. vita/model/multimodal_encoder/siglip/siglip_encoder.py +149 -0
  22. vita/model/multimodal_encoder/whale/adapter.py +137 -0
  23. vita/model/multimodal_encoder/whale/cmvn.py +89 -0
  24. vita/model/multimodal_encoder/whale/init_model.py +192 -0
  25. vita/model/multimodal_encoder/whale/module/component/mamba.py +131 -0
  26. vita/model/multimodal_encoder/whale/module/component/subsampling.py +74 -0
  27. vita/model/multimodal_encoder/whale/module/component/transformer.py +428 -0
  28. vita/model/multimodal_encoder/whale/module/encoder/encoder.py +171 -0
  29. vita/model/multimodal_encoder/whale/module/layer/attention.py +571 -0
  30. vita/model/multimodal_encoder/whale/module/layer/conv1d.py +88 -0
  31. vita/model/multimodal_encoder/whale/module/layer/dtcblock.py +95 -0
  32. vita/model/multimodal_encoder/whale/module/layer/fsmn.py +129 -0
  33. vita/model/multimodal_encoder/whale/utils.py +146 -0
  34. vita/model/multimodal_projector/builder.py +185 -0
  35. vita/model/vita_arch.py +639 -0
  36. vita/model/vita_tts/adapter.py +157 -0
  37. vita/model/vita_tts/audioLLM.py +433 -0
  38. vita/model/vita_tts/decoder/decoder.py +367 -0
  39. vita/model/vita_tts/decoder/llm2tts.py +161 -0
  40. vita/model/vita_tts/decoder/ticodec/models.py +716 -0
  41. vita/model/vita_tts/decoder/ticodec/vqvae.py +57 -0
  42. vita/model/vita_tts/decoder/ticodec/vqvae_tester.py +37 -0
  43. vita/model/vita_tts/encoder/attention.py +459 -0
  44. vita/model/vita_tts/encoder/cmvn.py +107 -0
  45. vita/model/vita_tts/encoder/encoder.py +155 -0
  46. vita/model/vita_tts/encoder/subsampling.py +106 -0
  47. vita/model/vita_tts/encoder/transformer.py +285 -0
  48. vita/model/vita_tts/masks.py +195 -0
  49. vita/model/vita_tts/pipeline.py +131 -0
  50. vita/model/vita_tts/utils.py +48 -0
app.py CHANGED
@@ -1,14 +1,485 @@
 1     import gradio as gr
 2     import spaces
 3   - import torch
 4
 5   - zero = torch.Tensor([0]).cuda()
 6   - print(zero.device) # <-- 'cpu' 🤔
 7
 8     @spaces.GPU
 9   - def greet(n):
10   -     print(zero.device) # <-- 'cuda:0' 🤗
11   -     return f"Hello {zero + n} Tensor"
12
13   - demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
14   - demo.launch()
 
1
+ import torch
2
+ import os
3
+ import argparse
4
+ import numpy as np
5
+ import copy
6
  import gradio as gr
7
+ import re
8
+ import torchaudio
9
+ import io
10
+ import cv2
11
+ import math
12
  import spaces
13
+ from numba import jit
14
+ from huggingface_hub import snapshot_download
15
+
16
+ from vita.constants import DEFAULT_AUDIO_TOKEN, DEFAULT_IMAGE_TOKEN, MAX_IMAGE_LENGTH, MIN_IMAGE_LENGTH, IMAGE_TOKEN_INDEX, AUDIO_TOKEN_INDEX
17
+ from vita.conversation import conv_templates, SeparatorStyle
18
+ from vita.util.mm_utils import tokenizer_image_token, tokenizer_image_audio_token
19
+ from PIL import Image
20
+ from decord import VideoReader, cpu
21
+ from vita.model.builder import load_pretrained_model
22
+ from vita.model.vita_tts.decoder.llm2tts import llm2TTS
23
+ from vita.model.language_model.vita_qwen2 import VITAQwen2Config, VITAQwen2ForCausalLM
24
+
25
+ decoder_topk = 2
26
+ codec_chunk_size = 40
27
+ codec_padding_size = 10
28
+
29
+ PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛""„‟…‧﹏."
30
+
31
+ MODEL_NAME = "VITA-MLLM/VITA-1.5"
32
+ model_path = snapshot_download(MODEL_NAME, local_dir="VITA_ckpt")
33
+ tokenizer, model, feature_extractor, context_len = load_pretrained_model(
34
+ model_path, model_base=None, model_name="VITA-1.5", model_type="qwen2p5_instruct"
35
+ )
36
+ llm_embedding = model.get_input_embeddings().cuda()
37
+ tts = llm2TTS(os.path.join(model_path, 'vita_tts_ckpt/'))
38
+
39
+ @jit
40
+ def float_to_int16(audio: np.ndarray) -> np.ndarray:
41
+ am = int(math.ceil(float(np.abs(audio).max())) * 32768)
42
+ am = 32767 * 32768 // am
43
+ return np.multiply(audio, am).astype(np.int16)
44
+
45
+
46
+ def remove_special_characters(input_str):
47
+ # Remove special tokens
48
+ special_tokens = ['☞', '☟', '☜', '<unk>', '<|im_end|>']
49
+ for token in special_tokens:
50
+ input_str = input_str.replace(token, '')
51
+ return input_str
52
+
53
+
54
+ def replace_equation(sentence):
55
+ special_notations = {
56
+ "sin": " sine ",
57
+ "cos": " cosine ",
58
+ "tan": " tangent ",
59
+ "cot": " cotangent ",
60
+ "sec": " secant ",
61
+ "csc": " cosecant ",
62
+ "log": " logarithm ",
63
+ "exp": "e^",
64
+ "sqrt": "根号 ",
65
+ "abs": "绝对值 ",
66
+ }
67
+
68
+ special_operators = {
69
+ "+": "加",
70
+ "-": "减",
71
+ "*": "乘",
72
+ "/": "除",
73
+ "=": "等于",
74
+ '!=': '不等于',
75
+ '>': '大于',
76
+ '<': '小于',
77
+ '>=': '大于等于',
78
+ '<=': '小于等于',
79
+ }
80
+
81
+ greek_letters = {
82
+ "α": "alpha ",
83
+ "β": "beta ",
84
+ "γ": "gamma ",
85
+ "δ": "delta ",
86
+ "ε": "epsilon ",
87
+ "ζ": "zeta ",
88
+ "η": "eta ",
89
+ "θ": "theta ",
90
+ "ι": "iota ",
91
+ "κ": "kappa ",
92
+ "λ": "lambda ",
93
+ "μ": "mu ",
94
+ "ν": "nu ",
95
+ "ξ": "xi ",
96
+ "ο": "omicron ",
97
+ "π": "派 ",
98
+ "ρ": "rho ",
99
+ "σ": "sigma ",
100
+ "τ": "tau ",
101
+ "υ": "upsilon ",
102
+ "φ": "phi ",
103
+ "χ": "chi ",
104
+ "ψ": "psi ",
105
+ "ω": "omega "
106
+ }
107
+
108
+ sentence = sentence.replace('**', ' ')
109
+
110
+ sentence = re.sub(r'(?<![\d)])-(\d+)', r'负\1', sentence)
111
+
112
+ for key in special_notations:
113
+ sentence = sentence.replace(key, special_notations[key])
114
+ for key in special_operators:
115
+ sentence = sentence.replace(key, special_operators[key])
116
+ for key in greek_letters:
117
+ sentence = sentence.replace(key, greek_letters[key])
118
+
119
+
120
+ sentence = re.sub(r'\(?(\d+)\)?\((\d+)\)', r'\1乘\2', sentence)
121
+ sentence = re.sub(r'\(?(\w+)\)?\^\(?(\w+)\)?', r'\1的\2次方', sentence)
122
+
123
+ return sentence
124
+
125
+
126
+ def is_video(file_path):
127
+ video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'}
128
+ _, ext = os.path.splitext(file_path)
129
+ return ext.lower() in video_extensions
130
+
131
+ def is_image(file_path):
132
+ image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff'}
133
+ _, ext = os.path.splitext(file_path)
134
+ return ext.lower() in image_extensions
135
+
136
+ def is_wav(file_path):
137
+ wav_extensions = {'.wav'}
138
+ _, ext = os.path.splitext(file_path)
139
+ return ext.lower() in wav_extensions
140
+
141
+ def load_model_embemding(model_path):
142
+ config_path = os.path.join(model_path, 'origin_config.json')
143
+ config = VITAQwen2Config.from_pretrained(config_path)
144
+ model = VITAQwen2ForCausalLM.from_pretrained(model_path, config=config, low_cpu_mem_usage=True)
145
+ embedding = model.get_input_embeddings()
146
+ del model
147
+ return embedding
148
+
149
+ def split_into_sentences(text):
150
+ sentence_endings = re.compile(r'[,。?\n!?、,?.!]')
151
+ sentences = sentence_endings.split(text)
152
+ return [sentence.strip() for sentence in sentences if sentence.strip()]
153
+
154
+ def convert_webm_to_mp4(input_file, output_file):
155
+ try:
156
+ cap = cv2.VideoCapture(input_file)
157
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
158
+ out = cv2.VideoWriter(output_file, fourcc, 20.0, (int(cap.get(3)), int(cap.get(4))))
159
+
160
+ while cap.isOpened():
161
+ ret, frame = cap.read()
162
+ if not ret:
163
+ break
164
+ out.write(frame)
165
+
166
+ cap.release()
167
+ out.release()
168
+ except Exception as e:
169
+ print(f"Error: {e}")
170
+ raise
171
+
172
+
173
+ def _get_rawvideo_dec(video_path, max_frames=MAX_IMAGE_LENGTH, min_frames=MIN_IMAGE_LENGTH, video_framerate=1, s=None, e=None):
174
+ if s is None or e is None:
175
+ start_time, end_time = None, None
176
+ else:
177
+ start_time = int(s)
178
+ end_time = int(e)
179
+ start_time = max(start_time, 0)
180
+ end_time = max(end_time, 0)
181
+ if start_time > end_time:
182
+ start_time, end_time = end_time, start_time
183
+ elif start_time == end_time:
184
+ end_time = start_time + 1
185
+
186
+ if os.path.exists(video_path):
187
+ vreader = VideoReader(video_path, ctx=cpu(0))
188
+ else:
189
+ raise FileNotFoundError
190
+
191
+ fps = vreader.get_avg_fps()
192
+ f_start = 0 if start_time is None else int(start_time * fps)
193
+ f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
194
+ num_frames = f_end - f_start + 1
195
+
196
+ if num_frames > 0:
197
+ sample_fps = int(video_framerate)
198
+ t_stride = int(round(float(fps) / sample_fps))
199
+ all_pos = list(range(f_start, f_end + 1, t_stride))
200
+
201
+ if len(all_pos) > max_frames:
202
+ sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
203
+ elif len(all_pos) < min_frames:
204
+ sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)]
205
+ else:
206
+ sample_pos = all_pos
207
+
208
+ patch_images = [Image.fromarray(f).convert("RGB") for f in vreader.get_batch(sample_pos).asnumpy()]
209
+ return patch_images, len(patch_images)
210
+ else:
211
+ print(f"video path: {video_path} error.")
212
 
213
+ def _parse_text(text):
214
+ lines = text.split("\n")
215
+ lines = [line for line in lines if line != ""]
216
+ count = 0
217
+
218
+ for i, line in enumerate(lines):
219
+ if "```" in line:
220
+ count += 1
221
+ items = line.split("`")
222
+ if count % 2 == 1:
223
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
224
+ else:
225
+ lines[i] = "<br></code></pre>"
226
+ else:
227
+ if i > 0 and count % 2 == 1:
228
+ line = line.replace("`", r"\`")
229
+ line = line.replace("<", "&lt;")
230
+ line = line.replace(">", "&gt;")
231
+ line = line.replace(" ", "&nbsp;")
232
+ line = line.replace("*", "&ast;")
233
+ line = line.replace("_", "&lowbar;")
234
+ line = line.replace("-", "&#45;")
235
+ line = line.replace(".", "&#46;")
236
+ line = line.replace("!", "&#33;")
237
+ line = line.replace("(", "&#40;")
238
+ line = line.replace(")", "&#41;")
239
+ line = line.replace("$", "&#36;")
240
+ lines[i] = "<br>" + line
241
+
242
+ return "".join(lines)
243
+
244
+
245
+ @spaces.GPU
246
+ def predict(_chatbot, task_history):
247
+ chat_query = task_history[-1][0]
248
+ print(task_history)
249
+
250
+ conv_mode = "qwen2p5_instruct"
251
+ conv = conv_templates[conv_mode].copy()
252
+
253
+ all_audio_path = []
254
+ all_visual_tensor = []
255
+
256
+ qs = ''
257
+ input_mode = 'lang'
258
+ for i, (q, a) in enumerate(task_history):
259
+ if isinstance(q, (tuple, list)):
260
+ if is_image(q[0]):
261
+ images = [Image.open(q[0]).convert("RGB")]
262
+ all_visual_tensor.extend(images)
263
+ input_mode = 'image'
264
+ qs += DEFAULT_IMAGE_TOKEN * len(images) + '\n'
265
+ elif is_video(q[0]):
266
+ video_frames, slice_len = _get_rawvideo_dec(q[0])
267
+ all_visual_tensor.extend(video_frames)
268
+ input_mode = 'video'
269
+ qs += DEFAULT_IMAGE_TOKEN * slice_len + '\n'
270
+ elif is_wav(q[0]):
271
+ if a is not None and a.startswith('☜'):
272
+ continue
273
+ else:
274
+ all_audio_path.append(q[0])
275
+ new_q = qs + DEFAULT_AUDIO_TOKEN
276
+ qs = ''
277
+ conv.append_message(conv.roles[0], new_q)
278
+ conv.append_message(conv.roles[1], a)
279
+ else:
280
+ new_q = qs + q
281
+ qs = ''
282
+ conv.append_message(conv.roles[0], new_q)
283
+ conv.append_message(conv.roles[1], a)
284
+
285
+ prompt = conv.get_prompt(input_mode)
286
+
287
+ if all_audio_path != []:
288
+ input_ids = tokenizer_image_audio_token(
289
+ prompt, tokenizer,
290
+ image_token_index=IMAGE_TOKEN_INDEX,
291
+ audio_token_index=AUDIO_TOKEN_INDEX
292
+ )
293
+ audio_list = []
294
+ for single_audio_path in all_audio_path:
295
+ try:
296
+ audio, original_sr = torchaudio.load(single_audio_path)
297
+ target_sr = 16000
298
+ if original_sr != target_sr:
299
+ resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
300
+ audio = resampler(audio)
301
+ audio_features = feature_extractor(audio, sampling_rate=target_sr, return_tensors="pt")["input_features"]
302
+ audio_list.append(audio_features.squeeze(0))
303
+ except Exception as e:
304
+ print(f"Error processing {single_audio_path}: {e}")
305
+ else:
306
+ input_ids = tokenizer_image_token(
307
+ prompt, tokenizer,
308
+ image_token_index=IMAGE_TOKEN_INDEX
309
+ )
310
+
311
+ if all_visual_tensor == [] and all_audio_path == []:
312
+ datapromt = {
313
+ "prompt_token_ids": input_ids,
314
+ }
315
+ elif all_visual_tensor != [] and all_audio_path == []:
316
+ datapromt = {
317
+ "prompt_token_ids": input_ids,
318
+ "multi_modal_data": {
319
+ "image": all_visual_tensor
320
+ },
321
+ }
322
+ elif all_visual_tensor == [] and all_audio_path != []:
323
+ datapromt = {
324
+ "prompt_token_ids": input_ids,
325
+ "multi_modal_data": {
326
+ "audio": audio_list
327
+ },
328
+ }
329
+ else:
330
+ datapromt = {
331
+ "prompt_token_ids": input_ids,
332
+ "multi_modal_data": {
333
+ "image": all_visual_tensor,
334
+ "audio": audio_list
335
+ },
336
+ }
337
+
338
+ print(datapromt)
339
+
340
+ with torch.inference_mode():
341
+ output_ids = model.generate(
342
+ input_ids,
343
+ images=all_visual_tensor,
344
+ audios=audio_list,
345
+ do_sample=False,
346
+ temperature=0.01,
347
+ top_p=None,
348
+ num_beams=1,
349
+ output_scores=True,
350
+ return_dict_in_generate=True,
351
+ max_new_tokens=1024,
352
+ use_cache=True,
353
+ )
354
+
355
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0]
356
+ outputs = outputs.strip()
357
+
358
+ task_history[-1] = (chat_query, outputs)
359
+ remove_special_characters_output = remove_special_characters(outputs)
360
+ _chatbot[-1] = (chat_query, _parse_text(remove_special_characters_output))
361
+ print("query", chat_query)
362
+ print("task_history", task_history)
363
+ print(_chatbot)
364
+ print("answer: ", outputs)
365
+ yield _chatbot
366
+
367
+
368
+ def add_text(history, task_history, text):
369
+ task_text = text
370
+ if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
371
+ task_text = text[:-1]
372
+ history = history + [(_parse_text(text), None)]
373
+ task_history = task_history + [(task_text, None)]
374
+ return history, task_history, ""
375
+
376
+ def add_file(history, task_history, file):
377
+ history = history + [((file.name,), None)]
378
+ task_history = task_history + [((file.name,), None)]
379
+ return history, task_history
380
+
381
+ def add_audio(history, task_history, file):
382
+ print(file)
383
+ if file is None:
384
+ return history, task_history
385
+ history = history + [((file,), None)]
386
+ task_history = task_history + [((file,), None)]
387
+ return history, task_history
388
+
389
+ def add_video(history, task_history, file):
390
+ print(file)
391
+ if file is None:
392
+ return history, task_history
393
+ new_file_name = file.replace(".webm",".mp4")
394
+ if file.endswith(".webm"):
395
+ convert_webm_to_mp4(file, new_file_name)
396
+ task_history = task_history + [((new_file_name,), None)]
397
+ return history, task_history
398
+
399
+
400
+ def reset_user_input():
401
+ return gr.update(value="")
402
+
403
+ def reset_state(task_history):
404
+ task_history.clear()
405
+ return []
406
 
407
  @spaces.GPU
408
+ def stream_audio_output(history, task_history):
409
+ text = task_history[-1][-1]
410
+ if not text:
411
+ # import pdb;pdb.set_trace()
412
+ yield None,None
413
+ llm_resounse = replace_equation(remove_special_characters(text))
414
+ #print('tts_text', llm_resounse)
415
+ for idx, text in enumerate(split_into_sentences(llm_resounse)):
416
+ embeddings = llm_embedding(torch.tensor(tokenizer.encode(text)).cuda())
417
+ for seg in tts.run(embeddings.reshape(-1, 896).unsqueeze(0), decoder_topk,
418
+ None,
419
+ codec_chunk_size, codec_padding_size):
420
+ if idx == 0:
421
+ try:
422
+ split_idx = torch.nonzero(seg.abs() > 0.03, as_tuple=True)[-1][0]
423
+ seg = seg[:, :, split_idx:]
424
+ except:
425
+ print('Do not need to split')
426
+ pass
427
+
428
+ if seg is not None and len(seg) > 0:
429
+ seg = seg.to(torch.float32).cpu().numpy()
430
+ yield 24000, float_to_int16(seg).T
431
+
432
+
433
+ with gr.Blocks(title="VideoMLLM") as demo:
434
+ gr.Markdown("""<center><font size=8>VITA</center>""")
435
+ chatbot = gr.Chatbot(label='VITA', elem_classes="control-height", height=500)
436
+ query = gr.Textbox(lines=2, label='Text Input')
437
+ task_history = gr.State([])
438
+ with gr.Row():
439
+ add_text_button = gr.Button("Submit Text (提交文本)")
440
+ add_audio_button = gr.Button("Submit Audio (提交音频)")
441
+ with gr.Row():
442
+ with gr.Column(scale=2):
443
+ addfile_btn = gr.UploadButton("📁 Upload (上传文件[视频,图片])", file_types=["video", "image"])
444
+ video_input = gr.Video(sources=[ "webcam"], height=400, width=700, container=True, interactive=True, show_download_button=True, label="📹 Video Recording (视频录制)")
445
+
446
+ with gr.Column(scale=1):
447
+ empty_bin = gr.Button("🧹 Clear History (清除历史)")
448
+ record_btn = gr.Audio(sources=[ "microphone","upload"], type="filepath", label="🎤 Record or Upload Audio (录音或上传音频)", show_download_button=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
449
+ audio_output = gr.Audio(
450
+ label="Output Audio",
451
+ value=None,
452
+ format= "wav",
453
+ autoplay=True,
454
+ streaming=True,
455
+ interactive=False,
456
+ show_label=True,
457
+ waveform_options=gr.WaveformOptions(
458
+ sample_rate=24000,
459
+ ),
460
+ )
461
+
462
+
463
+ add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then(
464
+ reset_user_input, [], [query]
465
+ ).then(
466
+ predict, [chatbot, task_history], [chatbot], show_progress=True
467
+ ).then(
468
+ stream_audio_output,[chatbot, task_history], [audio_output],
469
+ )
470
+
471
+
472
+ video_input.stop_recording(add_video, [chatbot, task_history, video_input], [chatbot, task_history], show_progress=True)
473
+ empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
474
+ addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
475
+
476
+
477
+
478
+ add_audio_button.click(add_audio, [chatbot, task_history,record_btn], [chatbot, task_history], show_progress=True).then(
479
+ predict, [chatbot, task_history], [chatbot], show_progress=True
480
+ ).then(
481
+ stream_audio_output,[chatbot, task_history], [audio_output],
482
+ )
483
+
484
 
485
+ demo.launch(server_port=18806)
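
Note (not part of the commit): the new app.py wires everything through Gradio events — add_text/add_audio push user turns into task_history, predict rebuilds a qwen2p5_instruct conversation prompt and calls model.generate, and stream_audio_output feeds the decoded text through the TTS decoder sentence by sentence. The following is a minimal text-only sketch of that same inference path outside Gradio; the model name, template name, and helper signatures are taken from the diff above, while the question string and empty image/audio lists are illustrative.

# Sketch of the text-only inference path condensed from predict() above.
# Assumes a CUDA device, since load_pretrained_model defaults to device_map="auto".
import torch
from huggingface_hub import snapshot_download
from vita.constants import IMAGE_TOKEN_INDEX
from vita.conversation import conv_templates
from vita.model.builder import load_pretrained_model
from vita.util.mm_utils import tokenizer_image_token

model_path = snapshot_download("VITA-MLLM/VITA-1.5", local_dir="VITA_ckpt")
tokenizer, model, feature_extractor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name="VITA-1.5", model_type="qwen2p5_instruct"
)

conv = conv_templates["qwen2p5_instruct"].copy()
conv.append_message(conv.roles[0], "Introduce yourself in one sentence.")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt("lang")  # text-only turn, so the modality is "lang"

# Tokenization and the generate() call mirror predict() above; batching/device
# handling is left exactly as app.py does it.
input_ids = tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX)
all_visual_tensor = []  # no image/video in this sketch
audio_list = []         # no audio in this sketch
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=all_visual_tensor,
        audios=audio_list,
        do_sample=False,
        num_beams=1,
        max_new_tokens=256,
        return_dict_in_generate=True,
        use_cache=True,
    )
print(tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0])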
 
vita/config/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from .dataset_config import *
2
+
3
+ NaturalCap0 = [ShareGPT4V0]
4
+ NaturalCap = [ShareGPT4V]
5
+
6
+ DataConfig = {
7
+ "Pretrain_video": NaturalCap0,
8
+ }
9
+
10
+ NoPatchSets = ["khair", "jester"]
vita/config/dataset_config.py ADDED
@@ -0,0 +1,8 @@
1
+ AudioFolder = ""
2
+ FolderDict = {
3
+ #### NaturalCap
4
+ "sharegpt4": "",
5
+ }
6
+ #### NaturalCap
7
+ ShareGPT4V = {"chat_path": ""}
8
+ ShareGPT4V0 = {"chat_path": ""}
vita/constants.py ADDED
@@ -0,0 +1,14 @@
1
+ # Model Constants
2
+ MAX_IMAGE_LENGTH = 16 # 8#16#32#64
3
+ MIN_IMAGE_LENGTH = 4
4
+ IGNORE_INDEX = -100
5
+ IMAGE_TOKEN_INDEX = -200
6
+ AUDIO_TOKEN_INDEX = -500
7
+ DEFAULT_IMAGE_TOKEN = "<image>"
8
+ DEFAULT_VIDEO_TOKEN = "<video>"
9
+ DEFAULT_AUDIO_TOKEN = "<audio>"
10
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
11
+ LOGDIR = "gradio-logs"
12
+ WORKER_HEART_BEAT_INTERVAL = 15
13
+ DEFAULT_DATA_RATIO = 1.0#0.124#0.5 #0.2 #1.0
14
+ GLOBAL_WEIGHTS_PATH = "/path/to/model_weights"
vita/conversation.py ADDED
@@ -0,0 +1,401 @@
1
+ import dataclasses
2
+ from enum import Enum, auto
3
+ from typing import List
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+
9
+ TWO = auto()
10
+ PLAIN = auto()
11
+ Nemo = auto()
12
+ Qwen2p5Instruct = auto()
13
+ MixtralZh = auto()
14
+ MixtralTwo = auto()
15
+
16
+
17
+ @dataclasses.dataclass
18
+ class Conversation:
19
+ """A class that keeps all conversation history."""
20
+
21
+ system: str
22
+ roles: List[str]
23
+ messages: List[List[str]]
24
+ offset: int
25
+ sep_style: SeparatorStyle
26
+ sep: str = "###"
27
+ sep2: str = None
28
+ version: str = "Unknown"
29
+
30
+ skip_next: bool = False
31
+
32
+ def get_prompt(self, modality=None):
33
+ messages = self.messages
34
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
35
+ messages = self.messages.copy()
36
+ init_role, init_msg = messages[0].copy()
37
+ init_msg = init_msg[0].replace("<image>", "").strip()
38
+ if "mmtag" in self.version:
39
+ messages[0] = (init_role, init_msg)
40
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
41
+ messages.insert(1, (self.roles[1], "Received."))
42
+ else:
43
+ messages[0] = (init_role, "<image>\n" + init_msg)
44
+
45
+ if self.sep_style == SeparatorStyle.TWO:
46
+ seps = [self.sep, self.sep2]
47
+ ret = self.system + seps[0]
48
+ for i, (role, message) in enumerate(messages):
49
+ if message:
50
+ if type(message) is tuple:
51
+ message, _, _ = message
52
+ ret += role + ": " + message + seps[i % 2]
53
+ else:
54
+ ret += role + ":"
55
+
56
+ elif self.sep_style == SeparatorStyle.MixtralZh:
57
+ seps = [self.sep, self.sep2]
58
+ ret = "system:" + self.system + seps[0]
59
+ for i, (role, message) in enumerate(messages):
60
+ if message:
61
+ if type(message) is tuple:
62
+ message, _, _ = message
63
+ ret += "\n" + role + ":" + message + seps[i % 2]
64
+ else:
65
+ ret += "\n" + role + ":"
66
+
67
+ elif self.sep_style == SeparatorStyle.MixtralTwo:
68
+ seps = [self.sep, self.sep2]
69
+ has_image = False
70
+ for i, (role, message) in enumerate(messages):
71
+ if message and "<image>" in message:
72
+ has_image = True
73
+ break
74
+ if has_image:
75
+ assert modality == "image" or modality == "video"
76
+ if modality == "image":
77
+ self.system = self.system[0]
78
+ elif modality == "video":
79
+ self.system = self.system[1]
80
+ else:
81
+ raise ValueError
82
+ else:
83
+ assert modality == "lang"
84
+ self.system = self.system[2]
85
+ ret = "system:" + self.system + seps[0]
86
+ for i, (role, message) in enumerate(messages):
87
+ if message:
88
+ if type(message) is tuple:
89
+ message, _, _ = message
90
+ ret += "\n" + role + ":" + message + seps[i % 2]
91
+ else:
92
+ ret += "\n" + role + ":"
93
+
94
+ elif self.sep_style == SeparatorStyle.Nemo:
95
+ wrap_inst = lambda msg: f"[INST]{msg}[/INST]"
96
+ seps = [self.sep, self.sep2]
97
+ has_image = False
98
+ for i, (role, message) in enumerate(messages):
99
+ if message and "<image>" in message:
100
+ has_image = True
101
+ break
102
+ if has_image:
103
+ assert modality == "image" or modality == "video"
104
+ if modality == "image":
105
+ self.system = self.system[0]
106
+ elif modality == "video":
107
+ self.system = self.system[1]
108
+ else:
109
+ raise ValueError
110
+ else:
111
+ assert modality == "lang"
112
+ self.system = self.system[2]
113
+ ret = ""
114
+ for i, (role, message) in enumerate(messages):
115
+ if message:
116
+ if type(message) is tuple:
117
+ message, _, _ = message
118
+ if i == 0:
119
+ message = self.system + '\n' + message
120
+ if i % 2 == 0:
121
+ ret += wrap_inst(message)
122
+ else:
123
+ ret += message + seps[i % 2]
124
+ else:
125
+ ret += ""
126
+
127
+ elif self.sep_style == SeparatorStyle.Qwen2p5Instruct:
128
+ wrap_qa = lambda msg: f"<|im_start|>{msg}<|im_end|>\n"
129
+ wrap_qa2 = lambda msg: f"<|im_start|>{msg}<|im_end|>"
130
+ seps = [self.sep, self.sep2]
131
+ has_image = False
132
+ for i, (role, message) in enumerate(messages):
133
+ if message and "<image>" in message:
134
+ has_image = True
135
+ break
136
+ if has_image:
137
+ assert modality == "image" or modality == "video"
138
+ if modality == "image":
139
+ self.system = self.system[0]
140
+ elif modality == "video":
141
+ self.system = self.system[1]
142
+ else:
143
+ raise ValueError
144
+ else:
145
+ assert modality == "lang"
146
+ self.system = self.system[2]
147
+ ret = wrap_qa("system\n" + self.system)
148
+ for i, (role, message) in enumerate(messages):
149
+ if message:
150
+ if type(message) is tuple:
151
+ message, _, _ = message
152
+ if i < len(messages) - 1:
153
+ ret += wrap_qa(role + '\n' + message)
154
+ else:
155
+ ret += wrap_qa2(role + '\n' + message)
156
+ else:
157
+ ret += "<|im_start|>" + role + '\n'
158
+
159
+ elif self.sep_style == SeparatorStyle.PLAIN:
160
+ seps = [self.sep, self.sep2]
161
+ ret = self.system
162
+ for i, (role, message) in enumerate(messages):
163
+ if message:
164
+ if type(message) is tuple:
165
+ message, _, _ = message
166
+ ret += message + seps[i % 2]
167
+ else:
168
+ ret += ""
169
+ else:
170
+ raise ValueError(f"Invalid style: {self.sep_style}")
171
+
172
+ return ret
173
+
174
+ def append_message(self, role, message):
175
+ self.messages.append([role, message])
176
+
177
+ def get_images(self, return_pil=False):
178
+ images = []
179
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
180
+ if i % 2 == 0:
181
+ if type(msg) is tuple:
182
+ import base64
183
+ from io import BytesIO
184
+ from PIL import Image
185
+
186
+ msg, image, image_process_mode = msg
187
+ if image_process_mode == "Pad":
188
+
189
+ def expand2square(pil_img, background_color=(122, 116, 104)):
190
+ width, height = pil_img.size
191
+ if width == height:
192
+ return pil_img
193
+ elif width > height:
194
+ result = Image.new(pil_img.mode, (width, width), background_color)
195
+ result.paste(pil_img, (0, (width - height) // 2))
196
+ return result
197
+ else:
198
+ result = Image.new(pil_img.mode, (height, height), background_color)
199
+ result.paste(pil_img, ((height - width) // 2, 0))
200
+ return result
201
+
202
+ image = expand2square(image)
203
+ elif image_process_mode in ["Default", "Crop"]:
204
+ pass
205
+ elif image_process_mode == "Resize":
206
+ image = image.resize((336, 336))
207
+ else:
208
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
209
+
210
+ if return_pil:
211
+ images.append(image)
212
+ else:
213
+ buffered = BytesIO()
214
+ image.save(buffered, format="PNG")
215
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
216
+ images.append(img_b64_str)
217
+ return images
218
+
219
+ def to_gradio_chatbot(self):
220
+ ret = []
221
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
222
+ if i % 2 == 0:
223
+ if type(msg) is tuple:
224
+ import base64
225
+ from io import BytesIO
226
+
227
+ msg, image, image_process_mode = msg
228
+ max_hw, min_hw = max(image.size), min(image.size)
229
+ aspect_ratio = max_hw / min_hw
230
+ max_len, min_len = 800, 400
231
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
232
+ longest_edge = int(shortest_edge * aspect_ratio)
233
+ W, H = image.size
234
+ if H > W:
235
+ H, W = longest_edge, shortest_edge
236
+ else:
237
+ H, W = shortest_edge, longest_edge
238
+ image = image.resize((W, H))
239
+ buffered = BytesIO()
240
+ image.save(buffered, format="JPEG")
241
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
242
+ img_str = (
243
+ f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
244
+ )
245
+ msg = img_str + msg.replace("<image>", "").strip()
246
+ ret.append([msg, None])
247
+ else:
248
+ ret.append([msg, None])
249
+ else:
250
+ ret[-1][-1] = msg
251
+ return ret
252
+
253
+ def copy(self):
254
+ return Conversation(
255
+ system=self.system,
256
+ roles=self.roles,
257
+ messages=[[x, y] for x, y in self.messages],
258
+ offset=self.offset,
259
+ sep_style=self.sep_style,
260
+ sep=self.sep,
261
+ sep2=self.sep2,
262
+ version=self.version,
263
+ )
264
+
265
+ def dict(self):
266
+ if len(self.get_images()) > 0:
267
+ return {
268
+ "system": self.system,
269
+ "roles": self.roles,
270
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
271
+ "offset": self.offset,
272
+ "sep": self.sep,
273
+ "sep2": self.sep2,
274
+ }
275
+ return {
276
+ "system": self.system,
277
+ "roles": self.roles,
278
+ "messages": self.messages,
279
+ "offset": self.offset,
280
+ "sep": self.sep,
281
+ "sep2": self.sep2,
282
+ }
283
+
284
+
285
+ conv_mixtral_zh = Conversation(
286
+ system="你是一个人工智能机器人。\n- 你是研究社区开发的大语言模型。你的设计宗旨是有益、诚实且无害。\n- 你支持使用用户选择的多种语言流利地进行交流并解答用户的问题。\n- 如果用户更正你生成的错误答案,你会向用户致歉并与用户探讨正确的答案。",
287
+ roles=("user", "bot"),
288
+ version="mixtral_zh",
289
+ messages=(),
290
+ offset=0,
291
+ sep_style=SeparatorStyle.MixtralZh,
292
+ sep="</s>",
293
+ sep2="</s>",
294
+ )
295
+
296
+ conv_mixtral_two = Conversation(
297
+ system=[
298
+ "You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the image given by the user, and it is strictly forbidden to answer the question without the content of the image. Please note that you are seeing the image, not the video.",
299
+ "You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the video given by the user, and it is strictly forbidden to answer the question without the content of the video. Please note that you are seeing the video, not the image.",
300
+ "You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user.",
301
+ ],
302
+ roles=("user", "bot"),
303
+ version="mixtral_two",
304
+ messages=(),
305
+ offset=0,
306
+ sep_style=SeparatorStyle.MixtralTwo,
307
+ sep="</s>",
308
+ sep2="</s>",
309
+ )
310
+
311
+ conv_nemo = Conversation(
312
+ system=[
313
+ "You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the image given by the user, and it is strictly forbidden to answer the question without the content of the image. Please note that you are seeing the image, not the video.",
314
+ "You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the video given by the user, and it is strictly forbidden to answer the question without the content of the video. Please note that you are seeing the video, not the image.",
315
+ "You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user.",
316
+ ],
317
+ roles=("USER", "ASSISTANT"),
318
+ version="nemo",
319
+ messages=(),
320
+ offset=0,
321
+ sep_style=SeparatorStyle.Nemo,
322
+ sep="[/INST]",
323
+ sep2="</s>",
324
+ )
325
+
326
+ conv_qwen2p5_instruct = Conversation(
327
+ system=[
328
+ "You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the image given by the user, and it is strictly forbidden to answer the question without the content of the image. Please note that you are seeing the image, not the video.",
329
+ "You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the video given by the user, and it is strictly forbidden to answer the question without the content of the video. Please note that you are seeing the video, not the image.",
330
+ "You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user.",
331
+ ],
332
+ roles=("user", "assistant"),
333
+ version="qwen2p5_instruct",
334
+ messages=(),
335
+ offset=0,
336
+ sep_style=SeparatorStyle.Qwen2p5Instruct,
337
+ sep="<|im_start|>",
338
+ sep2="<|im_start|>",
339
+ )
340
+
341
+ conv_phi3 = Conversation(
342
+ system="A chat between a curious user and an artificial intelligence assistant. "
343
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
344
+ roles=("USER", "ASSISTANT"),
345
+ version="phi3",
346
+ messages=(),
347
+ offset=0,
348
+ sep_style=SeparatorStyle.TWO,
349
+ sep=" ",
350
+ sep2="<|endoftext|>",
351
+ )
352
+
353
+ conv_minicpm = Conversation(
354
+ system="A chat between a curious user and an artificial intelligence assistant. "
355
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
356
+ roles=("USER", "ASSISTANT"),
357
+ version="minicpm",
358
+ messages=(),
359
+ offset=0,
360
+ sep_style=SeparatorStyle.TWO,
361
+ sep=" ",
362
+ sep2="</s>",
363
+ )
364
+
365
+ conv_llama = Conversation(
366
+ system="A chat between a curious user and an artificial intelligence assistant. "
367
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
368
+ roles=("USER", "ASSISTANT"),
369
+ version="llama",
370
+ messages=(),
371
+ offset=0,
372
+ sep_style=SeparatorStyle.TWO,
373
+ sep=" ",
374
+ sep2="<|end_of_text|>",
375
+ )
376
+
377
+ conv_plain = Conversation(
378
+ system="",
379
+ roles=("", ""),
380
+ messages=(),
381
+ offset=0,
382
+ sep_style=SeparatorStyle.PLAIN,
383
+ sep="\n",
384
+ )
385
+
386
+ default_conversation = conv_mixtral_two
387
+ conv_templates = {
388
+ "default": conv_mixtral_two,
389
+ "nemo": conv_nemo,
390
+ "qwen2p5_instruct": conv_qwen2p5_instruct,
391
+ "mixtral_zh": conv_mixtral_zh,
392
+ "mixtral_two": conv_mixtral_two,
393
+ "phi3": conv_phi3,
394
+ "plain": conv_plain,
395
+ "minicpm": conv_minicpm,
396
+ "llama": conv_llama,
397
+ }
398
+
399
+ if __name__ == "__main__":
400
+ print(default_conversation.get_prompt())
401
+
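
For quick reference (a sketch, not part of the commit): with the Qwen2p5Instruct separator style defined above, a single text-only turn is wrapped in <|im_start|>/<|im_end|> markers, and the modality argument selects which of the three system prompts is used ("lang" picks the third one).

# Sketch: what conv_qwen2p5_instruct produces for a text-only ("lang") turn.
from vita.conversation import conv_templates

conv = conv_templates["qwen2p5_instruct"].copy()
conv.append_message(conv.roles[0], "Hello")
conv.append_message(conv.roles[1], None)   # open assistant slot
print(conv.get_prompt("lang"))
# Output shape:
# <|im_start|>system
# You are an AI robot and your name is VITA. ...<|im_end|>
# <|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant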
vita/model/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .language_model.vita_mixtral import VITAMixtralConfig, VITAMixtralForCausalLM
2
+ from .language_model.vita_nemo import VITAMistralConfig, VITAMistralForCausalLM
3
+ from .language_model.vita_qwen2 import VITAQwen2Config, VITAQwen2ForCausalLM
4
+ from .language_model.vita_fo_qwen2 import VITAFOQwen2Config, VITAFOQwen2ForCausalLM
5
+
vita/model/builder.py ADDED
@@ -0,0 +1,287 @@
1
+ import os
2
+ import warnings
3
+
4
+ import torch
5
+ from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig, logging
6
+
7
+ from vita.constants import GLOBAL_WEIGHTS_PATH
8
+ from vita.model import *
9
+
10
+ logging.set_verbosity_error()
11
+ warnings.filterwarnings("ignore")
12
+
13
+
14
+ def load_pretrained_model(
15
+ model_path,
16
+ model_base,
17
+ model_name,
18
+ model_type,
19
+ load_8bit=False,
20
+ load_4bit=False,
21
+ device_map="auto",
22
+ device="cuda",
23
+ **kwargs,
24
+ ):
25
+ if model_type not in {"mixtral-8x7b", "nemo", "qwen2p5_instruct", "qwen2p5_fo_instruct"}:
26
+ raise ValueError(f"Unknown Model Type {model_type}")
27
+
28
+ kwargs = {"device_map": device_map, **kwargs}
29
+
30
+ if device != "cuda":
31
+ kwargs["device_map"] = {"": device}
32
+
33
+ if load_8bit:
34
+ kwargs["load_in_8bit"] = True
35
+ elif load_4bit:
36
+ kwargs["load_in_4bit"] = True
37
+ kwargs["quantization_config"] = BitsAndBytesConfig(
38
+ load_in_4bit=True,
39
+ bnb_4bit_compute_dtype=torch.float16,
40
+ bnb_4bit_use_double_quant=True,
41
+ bnb_4bit_quant_type="nf4",
42
+ )
43
+ else:
44
+ kwargs["torch_dtype"] = torch.float16
45
+
46
+ # Load VITA model
47
+ if "lora" in model_name.lower() and model_base is None:
48
+ warnings.warn(
49
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument."
50
+ )
51
+ if "lora" in model_name.lower() and model_base is not None:
52
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
53
+
54
+ print("Loading VITA from base model...")
55
+ if model_type == "mixtral-8x7b":
56
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
57
+ model = VITAMixtralForCausalLM.from_pretrained(
58
+ model_path, low_cpu_mem_usage=True, **kwargs
59
+ )
60
+
61
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
62
+ if model.lm_head.weight.shape[0] != token_num:
63
+ model.lm_head.weight = torch.nn.Parameter(
64
+ torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)
65
+ )
66
+ model.model.embed_tokens.weight = torch.nn.Parameter(
67
+ torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)
68
+ )
69
+
70
+ print("Loading additional VITA weights...")
71
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
72
+ non_lora_trainables = torch.load(
73
+ os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu"
74
+ )
75
+ else:
76
+ # this is probably from HF Hub
77
+ from huggingface_hub import hf_hub_download
78
+
79
+ def load_from_hf(repo_id, filename, subfolder=None):
80
+ cache_file = hf_hub_download(
81
+ repo_id=repo_id, filename=filename, subfolder=subfolder
82
+ )
83
+ return torch.load(cache_file, map_location="cpu")
84
+
85
+ non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
86
+
87
+ non_lora_trainables = {
88
+ (k[11:] if k.startswith("base_model.") else k): v
89
+ for k, v in non_lora_trainables.items()
90
+ }
91
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
92
+ non_lora_trainables = {
93
+ (k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()
94
+ }
95
+ model.load_state_dict(non_lora_trainables, strict=False)
96
+
97
+ from peft import PeftModel
98
+
99
+ print("Loading LoRA weights...")
100
+ model = PeftModel.from_pretrained(model, model_path)
101
+ print("Merging LoRA weights...")
102
+ model = model.merge_and_unload()
103
+ print("Model is loaded...")
104
+ elif model_base is not None:
105
+ # this may be mm projector only
106
+ print("Loading VITA from base model...")
107
+
108
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
109
+ if model_type == "mixtral-8x7b":
110
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
111
+ model = VITAMixtralForCausalLM.from_pretrained(
112
+ model_base, low_cpu_mem_usage=True, **kwargs
113
+ )
114
+
115
+ # load vision encoder
116
+ from types import SimpleNamespace
117
+ model_args = {
118
+ "vision_tower": f"{GLOBAL_WEIGHTS_PATH}/InternViT-300M-448px",
119
+ "pretrain_mm_mlp_adapter": None,
120
+ "mm_projector_type": "mlp2x_gelu",
121
+ }
122
+ model_args = SimpleNamespace(**model_args)
123
+ model.get_model().initialize_vision_modules(model_args=model_args)
124
+
125
+ # load audio encoder
126
+ from types import SimpleNamespace
127
+ model_args = {
128
+ 'audio_encoder': f"{GLOBAL_WEIGHTS_PATH}/audio-encoder-2wh_zh_en_audioset_Mixtral-8x7B_New-base-tunning",
129
+ 'freeze_audio_encoder': True,
130
+ 'freeze_audio_encoder_adapter': True
131
+ }
132
+ model_args = SimpleNamespace(**model_args)
133
+ model.get_model().initialize_audio_modules(model_args=model_args)
134
+ audio_encoder = model.get_audio_encoder()
135
+ device = torch.device('cuda:0')
136
+ audio_encoder = audio_encoder.to(device)
137
+
138
+ mm_projector_weights = torch.load(
139
+ os.path.join(model_path, "mm_projector.bin"), map_location="cpu"
140
+ )
141
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
142
+ model.load_state_dict(mm_projector_weights, strict=False)
143
+ model.model.mm_projector.to(device="cuda", dtype=torch.float16)
144
+ model.model.vision_tower.to(device="cuda", dtype=torch.float16)
145
+ else:
146
+ if model_type == "mixtral-8x7b":
147
+ # import pdb; pdb.set_trace()
148
+ device_map = {
149
+ "model.embed_tokens": 0,
150
+ "model.layers.0": 0,
151
+ "model.layers.1": 0,
152
+ "model.layers.2": 0,
153
+ "model.layers.3": 0,
154
+ "model.layers.4": 0,
155
+ "model.layers.5": 0,
156
+ "model.layers.6": 0,
157
+ "model.layers.7": 0,
158
+ "model.layers.8": 0,
159
+ "model.layers.9": 0,
160
+ "model.layers.10": 0,
161
+ "model.layers.11": 0,
162
+ "model.layers.12": 0,
163
+ "model.layers.13": 0,
164
+ "model.layers.14": 0,
165
+ "model.layers.15": 0,
166
+ "model.layers.16": 1,
167
+ "model.layers.17": 1,
168
+ "model.layers.18": 1,
169
+ "model.layers.19": 1,
170
+ "model.layers.20": 1,
171
+ "model.layers.21": 1,
172
+ "model.layers.22": 1,
173
+ "model.layers.23": 1,
174
+ "model.layers.24": 1,
175
+ "model.layers.25": 1,
176
+ "model.layers.26": 1,
177
+ "model.layers.27": 1,
178
+ "model.layers.28": 1,
179
+ "model.layers.29": 1,
180
+ "model.layers.30": 1,
181
+ "model.layers.31": 1,
182
+ "model.norm": 1,
183
+ "model.vision_tower": 1,
184
+ "model.mm_projector": 1,
185
+ "model.audio_encoder": 1,
186
+ "lm_head": 1,
187
+ }
188
+ device_map["model.audio_encoder"] = 0
189
+ kwargs.update(device_map=device_map)
190
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
191
+ model = VITAMixtralForCausalLM.from_pretrained(
192
+ model_path, low_cpu_mem_usage=True, **kwargs
193
+ )
194
+ # model.hf_device_map
195
+ elif model_type == "nemo":
196
+ # import pdb; pdb.set_trace()
197
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
198
+ model = VITAMistralForCausalLM.from_pretrained(
199
+ model_path, low_cpu_mem_usage=True, **kwargs
200
+ )
201
+ elif model_type == "qwen2p5_instruct":
202
+ # import pdb; pdb.set_trace()
203
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
204
+ model = VITAQwen2ForCausalLM.from_pretrained(
205
+ model_path, low_cpu_mem_usage=True, **kwargs
206
+ )
207
+ elif model_type == "qwen2p5_fo_instruct":
208
+ # import pdb; pdb.set_trace()
209
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
210
+ model = VITAFOQwen2ForCausalLM.from_pretrained(
211
+ model_path, low_cpu_mem_usage=True, **kwargs
212
+ )
213
+
214
+ model.resize_token_embeddings(len(tokenizer))
215
+
216
+ vision_tower = model.get_vision_tower()
217
+ if not vision_tower.is_loaded:
218
+ vision_tower.load_model()
219
+
220
+ num_params = sum(p.numel() for p in vision_tower.parameters())
221
+ print("the number of vision encoder params: {}M".format(num_params / 1024 / 1024))
222
+
223
+ if getattr(model.config, "unfreeze_vision_tower", False):
224
+ if "lora" in model_name.lower():
225
+ assert model_base is not None
226
+ vision_non_lora_trainables = {
227
+ k[19:]: v
228
+ for k, v in non_lora_trainables.items()
229
+ if k.startswith("model.vision_tower.")
230
+ }
231
+ vision_tower.load_state_dict(vision_non_lora_trainables, strict=False)
232
+ else:
233
+ assert model_base is None
234
+ from safetensors.torch import load_file
235
+
236
+ vision_weights = {}
237
+ for file_name in os.listdir(model_path):
238
+ if file_name.endswith("safetensors"):
239
+ vision_weights.update(
240
+ {
241
+ k[19:]: v
242
+ for k, v in load_file(os.path.join(model_path, file_name)).items()
243
+ if k.startswith("model.vision_tower.")
244
+ }
245
+ )
246
+ vision_tower.load_state_dict(vision_weights, strict=True)
247
+
248
+ # import pdb; pdb.set_trace()
249
+ # if (not getattr(model.config, "freeze_audio_encoder", True)) and (not getattr(model.config, "freeze_audio_encoder_adapter", True)):
250
+ # from safetensors.torch import load_file
251
+ # audio_weights = {}
252
+ # for file_name in os.listdir(model_path):
253
+ # if file_name.endswith('safetensors'):
254
+ # audio_weights.update(
255
+ # {k[20:]: v for k, v in load_file(os.path.join(model_path, file_name)).items() if
256
+ # k.startswith('model.audio_encoder.')})
257
+ # audio_encoder.load_state_dict(audio_weights, strict=True)
258
+ # audio_encoder.eval()
259
+ # import pdb; pdb.set_trace()
260
+
261
+ # import pdb; pdb.set_trace()
262
+ # from safetensors.torch import load_file
263
+ # audio_weights = {}
264
+ # for file_name in os.listdir(model_path):
265
+ # if file_name.endswith('safetensors'):
266
+ # audio_weights.update(
267
+ # {k[20:]: v for k, v in load_file(os.path.join(model_path, file_name)).items() if
268
+ # k.startswith('model.audio_encoder.')})
269
+ # import pdb; pdb.set_trace()
270
+
271
+ vision_tower.to(dtype=torch.float16)
272
+ image_processor = vision_tower.image_processor
273
+
274
+ #import pdb; pdb.set_trace()
275
+ if hasattr(model.config, "max_sequence_length"):
276
+ context_len = model.config.max_sequence_length
277
+ else:
278
+ context_len = 2048
279
+
280
+ if model.generation_config.pad_token_id is None:
281
+ model.generation_config.pad_token_id = model.generation_config.eos_token_id
282
+
283
+ if model_type == "phi-3":
284
+ model.generation_config.eos_token_id = tokenizer.eos_token_id
285
+
286
+ return tokenizer, model, image_processor, context_len
287
+
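
Besides the full-precision path used by app.py, the loader above exposes 8-bit and 4-bit options through bitsandbytes. A hedged sketch of a quantized load (the local checkpoint path is illustrative; flags are the ones defined in load_pretrained_model above):

# Sketch only: 4-bit loading via the flags in load_pretrained_model.
from vita.model.builder import load_pretrained_model

tokenizer, model, image_processor, context_len = load_pretrained_model(
    "VITA_ckpt",                      # assumed local checkpoint directory
    model_base=None,
    model_name="VITA-1.5",
    model_type="qwen2p5_instruct",
    load_4bit=True,                   # installs a BitsAndBytesConfig (nf4, fp16 compute)
)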
vita/model/language_model/vita_fo_qwen2.py ADDED
@@ -0,0 +1,227 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from PIL import Image
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers import (
8
+ AutoConfig,
9
+ AutoModelForCausalLM,
10
+ Qwen2Config,
11
+ Qwen2ForCausalLM,
12
+ Qwen2Model,
13
+ )
14
+ from transformers.cache_utils import Cache, DynamicCache
15
+ from transformers.modeling_outputs import CausalLMOutputWithPast, MoeCausalLMOutputWithPast
16
+ from transformers.generation.utils import GenerateOutput
17
+
18
+ from ..vita_arch import VITAMetaForCausalLM, VITAMetaModel
19
+ from ...constants import IGNORE_INDEX
20
+ from .vita_qwen2 import custom_forward
21
+
22
+
23
+ Qwen2ForCausalLM.forward = custom_forward
24
+
25
+
26
+ class VITAFOQwen2Config(Qwen2Config):
27
+ model_type = "vita-fo-Qwen2"
28
+
29
+
30
+ class VITAFOQwen2Model(VITAMetaModel, Qwen2Model):
31
+ config_class = VITAFOQwen2Config
32
+
33
+ def __init__(self, config: Qwen2Config):
34
+ super(VITAFOQwen2Model, self).__init__(config)
35
+
36
+
37
+ class VITAFOQwen2ForCausalLM(Qwen2ForCausalLM, VITAMetaForCausalLM):
38
+ config_class = VITAFOQwen2Config
39
+
40
+ def __init__(self, config):
41
+ super(Qwen2ForCausalLM, self).__init__(config)
42
+ self.model = VITAFOQwen2Model(config)
43
+ self.vocab_size = config.vocab_size
44
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
45
+ self.predict_usr_state = 0#2
46
+ if self.predict_usr_state:
47
+ self.predictor_head = torch.nn.Linear(config.hidden_size, self.predict_usr_state + 1) # +1 for the dummy class
48
+ else:
49
+ self.predictor_head = None
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.model
55
+
56
+ def forward(
57
+ self,
58
+ input_ids: torch.LongTensor = None,
59
+ attention_mask: Optional[torch.Tensor] = None,
60
+ position_ids: Optional[torch.LongTensor] = None,
61
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
62
+ inputs_embeds: Optional[torch.FloatTensor] = None,
63
+ labels: Optional[torch.LongTensor] = None,
64
+ use_cache: Optional[bool] = None,
65
+ output_attentions: Optional[bool] = None,
66
+ output_hidden_states: Optional[bool] = None,
67
+ images: Optional[torch.FloatTensor] = None,
68
+ audios: Optional[dict] = None,
69
+ sf_masks: Optional[torch.Tensor] = None,
70
+ return_dict: Optional[bool] = None,
71
+ cache_position: Optional[torch.LongTensor] = None,
72
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
73
+ if inputs_embeds is None:
74
+ (
75
+ input_ids,
76
+ position_ids,
77
+ attention_mask,
78
+ past_key_values,
79
+ inputs_embeds,
80
+ labels,
81
+ ) = self.prepare_inputs_labels_for_multimodal(
82
+ input_ids, position_ids, attention_mask, past_key_values, labels, images, audios, sf_masks
83
+ )
84
+ if labels is not None:
85
+ state_labels = labels
86
+ labels = torch.where(labels>=0, labels, IGNORE_INDEX)
87
+ output_hidden_states = True
88
+ outputs = super().forward(
89
+ input_ids=input_ids,
90
+ attention_mask=attention_mask,
91
+ position_ids=position_ids,
92
+ past_key_values=past_key_values,
93
+ inputs_embeds=inputs_embeds,
94
+ labels=labels,
95
+ use_cache=use_cache,
96
+ output_attentions=output_attentions,
97
+ output_hidden_states=output_hidden_states,
98
+ return_dict=return_dict,
99
+ cache_position=cache_position,
100
+ )
101
+ # state loss
102
+ if self.predictor_head is not None:
103
+ state_logits = self.predictor_head(outputs[2][-1]).view(-1, self.predict_usr_state+1) # +1 for the dummy class
104
+ if labels is not None:
105
+ loss = outputs[0]
106
+ weight = torch.Tensor([1, 5, 1]).to(torch.bfloat16).to(inputs_embeds.device)
107
+ loss_fct = torch.nn.CrossEntropyLoss(weight=weight)
108
+ s_labels= torch.where(
109
+ state_labels < IGNORE_INDEX,
110
+ IGNORE_INDEX-state_labels-1,
111
+ IGNORE_INDEX).view(-1)
112
+ #assert all(label in [0, 1, IGNORE_INDEX] for label in s_labels), "s_labels must contain only 0, 1, or -100"
113
+ state_loss = loss_fct(state_logits, s_labels)
114
+ loss = loss + state_loss
115
+ outputs['loss'] = loss
116
+ return outputs
117
+
118
+ @torch.no_grad()
119
+ def generate(
120
+ self,
121
+ inputs: Optional[torch.Tensor] = None,
122
+ images: Optional[torch.Tensor] = None,
123
+ audios: Optional[torch.Tensor] = None,
124
+ sf_masks: Optional[torch.Tensor] = None,
125
+ **kwargs,
126
+ ) -> Union[GenerateOutput, torch.LongTensor]:
127
+ position_ids = kwargs.pop("position_ids", None)
128
+ attention_mask = kwargs.pop("attention_mask", None)
129
+ if "inputs_embeds" in kwargs:
130
+ raise NotImplementedError("`inputs_embeds` is not supported")
131
+
132
+ if images is not None or audios is not None:
133
+ (
134
+ inputs,
135
+ position_ids,
136
+ attention_mask,
137
+ _,
138
+ inputs_embeds,
139
+ _
140
+ ) = self.prepare_inputs_labels_for_multimodal(
141
+ inputs,
142
+ position_ids,
143
+ attention_mask,
144
+ None,
145
+ None,
146
+ images,
147
+ audios,
148
+ sf_masks,
149
+ )
150
+ else:
151
+ inputs_embeds = self.get_model().embed_tokens(inputs)
152
+
153
+ return super().generate(
154
+ position_ids=position_ids,
155
+ attention_mask=attention_mask,
156
+ inputs_embeds=inputs_embeds,
157
+ **kwargs
158
+ )
159
+
160
+ def prepare_inputs_for_generation(
161
+ self,
162
+ input_ids,
163
+ past_key_values=None,
164
+ inputs_embeds=None,
165
+ attention_mask=None,
166
+ **kwargs,
167
+ ):
168
+ images = kwargs.pop("images", None)
169
+ audios = kwargs.pop("audios", None)
170
+ sf_masks = kwargs.pop("sf_masks", None)
171
+
172
+ _inputs = super().prepare_inputs_for_generation(
173
+ input_ids,
174
+ past_key_values=past_key_values,
175
+ inputs_embeds=inputs_embeds,
176
+ attention_mask=attention_mask,
177
+ **kwargs,
178
+ )
179
+
180
+ if images is not None:
181
+ _inputs["images"] = images
182
+ if audios is not None:
183
+ _inputs["audios"] = audios
184
+ if sf_masks is not None:
185
+ _inputs["sf_masks"] = sf_masks
186
+ return _inputs
187
+
188
+ def expand2square(self, pil_img, background_color):
189
+ width, height = pil_img.size
190
+ if width == height:
191
+ return pil_img
192
+ elif width > height:
193
+ result = Image.new(pil_img.mode, (width, width), background_color)
194
+ result.paste(pil_img, (0, (width - height) // 2))
195
+ return result
196
+ else:
197
+ result = Image.new(pil_img.mode, (height, height), background_color)
198
+ result.paste(pil_img, ((height - width) // 2, 0))
199
+ return result
200
+
201
+ def process_images(self, images, model_cfg):
202
+ vision_tower = self.get_vision_tower()
203
+ if not vision_tower.is_loaded:
204
+ vision_tower.load_model()
205
+ image_processor = vision_tower.image_processor
206
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
207
+ new_images = []
208
+ if image_aspect_ratio == "pad":
209
+ for image in images:
210
+ image = self.expand2square(
211
+ image, tuple(int(x * 255) for x in image_processor.image_mean)
212
+ )
213
+ image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
214
+ new_images.append(image)
215
+ else:
216
+ return image_processor(images, return_tensors="pt")["pixel_values"]
217
+ if all(x.shape == new_images[0].shape for x in new_images):
218
+ new_images = torch.stack(new_images, dim=0)
219
+ return new_images
220
+
221
+
222
+ AutoConfig.register("vita-fo-Qwen2", VITAFOQwen2Config)
223
+ AutoModelForCausalLM.register(VITAFOQwen2Config, VITAFOQwen2ForCausalLM)
224
+
225
+
226
+
227
+
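A minimal loading sketch for the registration above, assuming a checkpoint whose config.json declares model_type "vita-fo-Qwen2"; the path below is a placeholder, not a real repository id, and this module must be imported first so the Auto classes resolve to VITAFOQwen2ForCausalLM.

from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "path/to/vita-1.5-fo-qwen2-checkpoint"  # placeholder path
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(ckpt)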
vita/model/language_model/vita_mixtral.py ADDED
@@ -0,0 +1,420 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from PIL import Image
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers import (
8
+ AutoConfig,
9
+ AutoModelForCausalLM,
10
+ MixtralConfig,
11
+ MixtralForCausalLM,
12
+ MixtralModel,
13
+ )
14
+ from transformers.cache_utils import Cache, DynamicCache
15
+ from transformers.modeling_outputs import CausalLMOutputWithPast, MoeCausalLMOutputWithPast
16
+
17
+ from ..vita_arch import VITAMetaForCausalLM, VITAMetaModel
18
+
19
+
20
+ def load_balancing_loss_func(
21
+ gate_logits: torch.Tensor,
22
+ num_experts: Optional[int] = None,
23
+ top_k=2,
24
+ attention_mask: Optional[torch.Tensor] = None,
25
+ ) -> float:
26
+ r"""
27
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
28
+
29
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
30
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
31
+ experts is too unbalanced.
32
+
33
+ Args:
34
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
35
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
36
+ shape [batch_size X sequence_length, num_experts].
37
+ attention_mask (`torch.Tensor`, None):
38
+ The attention_mask used in forward function
39
+ shape [batch_size X sequence_length] if not None.
40
+ num_experts (`int`, *optional*):
41
+ Number of experts
42
+
43
+ Returns:
44
+ The auxiliary loss.
45
+ """
46
+ if gate_logits is None or not isinstance(gate_logits, tuple):
47
+ return 0
48
+
49
+ if isinstance(gate_logits, tuple):
50
+ compute_device = gate_logits[0].device
51
+ concatenated_gate_logits = torch.cat(
52
+ [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
53
+ )
54
+
55
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
56
+
57
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
58
+
59
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
60
+
61
+ if attention_mask is None:
62
+ # Compute the percentage of tokens routed to each experts
63
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
64
+
65
+ # Compute the average probability of routing to these experts
66
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
67
+ else:
68
+ batch_size, sequence_length = attention_mask.shape
69
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
70
+
71
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
72
+ expert_attention_mask = (
73
+ attention_mask[None, :, :, None, None]
74
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
75
+ .reshape(-1, top_k, num_experts)
76
+ .to(compute_device)
77
+ )
78
+
79
+ # Compute the percentage of tokens routed to each experts
80
+ tokens_per_expert = torch.sum(
81
+ expert_mask.float() * expert_attention_mask, dim=0
82
+ ) / torch.sum(expert_attention_mask, dim=0)
83
+
84
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
85
+ router_per_expert_attention_mask = (
86
+ attention_mask[None, :, :, None]
87
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
88
+ .reshape(-1, num_experts)
89
+ .to(compute_device)
90
+ )
91
+
92
+ # Compute the average probability of routing to these experts
93
+ router_prob_per_expert = torch.sum(
94
+ routing_weights * router_per_expert_attention_mask, dim=0
95
+ ) / torch.sum(router_per_expert_attention_mask, dim=0)
96
+
97
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
98
+ return overall_loss * num_experts
99
+
100
+
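A quick sanity-check sketch for the helper above, with made-up shapes (2 layers, 6 tokens, 4 experts, top-2 routing); it only illustrates the call signature.

import torch

gate_logits = tuple(torch.randn(6, 4) for _ in range(2))  # one router-logit tensor per layer
aux = load_balancing_loss_func(gate_logits, num_experts=4, top_k=2)
print(aux)  # scalar penalty; it grows as routing concentrates on a few experts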
101
+ def custom_forward(
102
+ self,
103
+ input_ids: torch.LongTensor = None,
104
+ attention_mask: Optional[torch.Tensor] = None,
105
+ position_ids: Optional[torch.LongTensor] = None,
106
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
107
+ inputs_embeds: Optional[torch.FloatTensor] = None,
108
+ labels: Optional[torch.LongTensor] = None,
109
+ use_cache: Optional[bool] = None,
110
+ output_attentions: Optional[bool] = None,
111
+ output_hidden_states: Optional[bool] = None,
112
+ output_router_logits: Optional[bool] = None,
113
+ return_dict: Optional[bool] = None,
114
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
115
+ r"""
116
+ Args:
117
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
118
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
119
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
120
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
121
+
122
+ Returns:
123
+
124
+ Example:
125
+
126
+ ```python
127
+ >>> from transformers import AutoTokenizer, MixtralForCausalLM
128
+
129
+ >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
130
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
131
+
132
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
133
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
134
+
135
+ >>> # Generate
136
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
137
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
138
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
139
+ ```"""
140
+
141
+ output_attentions = (
142
+ output_attentions if output_attentions is not None else self.config.output_attentions
143
+ )
144
+ output_router_logits = (
145
+ output_router_logits
146
+ if output_router_logits is not None
147
+ else self.config.output_router_logits
148
+ )
149
+
150
+ output_hidden_states = (
151
+ output_hidden_states
152
+ if output_hidden_states is not None
153
+ else self.config.output_hidden_states
154
+ )
155
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
156
+
157
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
158
+ outputs = self.model(
159
+ input_ids=input_ids,
160
+ attention_mask=attention_mask,
161
+ position_ids=position_ids,
162
+ past_key_values=past_key_values,
163
+ inputs_embeds=inputs_embeds,
164
+ use_cache=use_cache,
165
+ output_attentions=output_attentions,
166
+ output_hidden_states=output_hidden_states,
167
+ output_router_logits=output_router_logits,
168
+ return_dict=return_dict,
169
+ )
170
+
171
+ hidden_states = outputs[0]
172
+ logits = self.lm_head(hidden_states)
173
+ # logits = logits.float()
174
+
175
+ loss = None
176
+ if labels is not None:
177
+ # Shift so that tokens < n predict n
178
+ shift_logits = logits[..., :-1, :].contiguous()
179
+ shift_labels = labels[..., 1:].contiguous()
180
+ # Flatten the tokens
181
+ loss_fct = CrossEntropyLoss()
182
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
183
+ shift_labels = shift_labels.view(-1)
184
+ # Enable model parallelism
185
+ shift_labels = shift_labels.to(shift_logits.device)
186
+ loss = loss_fct(shift_logits, shift_labels)
187
+
188
+ aux_loss = None
189
+ if output_router_logits:
190
+ aux_loss = load_balancing_loss_func(
191
+ outputs.router_logits if return_dict else outputs[-1],
192
+ self.num_experts,
193
+ self.num_experts_per_tok,
194
+ attention_mask,
195
+ )
196
+ if labels is not None:
197
+ loss += self.router_aux_loss_coef * aux_loss.to(
198
+ loss.device
199
+ ) # make sure to reside in the same device
200
+
201
+ if not return_dict:
202
+ output = (logits,) + outputs[1:]
203
+ if output_router_logits:
204
+ output = (aux_loss,) + output
205
+ return (loss,) + output if loss is not None else output
206
+
207
+ return MoeCausalLMOutputWithPast(
208
+ loss=loss,
209
+ aux_loss=aux_loss,
210
+ logits=logits,
211
+ past_key_values=outputs.past_key_values,
212
+ hidden_states=outputs.hidden_states,
213
+ attentions=outputs.attentions,
214
+ router_logits=outputs.router_logits,
215
+ )
216
+
217
+
218
+ MixtralForCausalLM.forward = custom_forward
219
+
220
+
221
+ class VITAMixtralConfig(MixtralConfig):
222
+ model_type = "vita-mixtral"
223
+
224
+
225
+ class VITAMixtralModel(VITAMetaModel, MixtralModel):
226
+ config_class = VITAMixtralConfig
227
+
228
+ def __init__(self, config: MixtralConfig):
229
+ super(VITAMixtralModel, self).__init__(config)
230
+
231
+
232
+ class VITAMixtralForCausalLM(MixtralForCausalLM, VITAMetaForCausalLM):
233
+ config_class = VITAMixtralConfig
234
+
235
+ def __init__(self, config):
236
+ super(MixtralForCausalLM, self).__init__(config)
237
+ self.model = VITAMixtralModel(config)
238
+ self.vocab_size = config.vocab_size
239
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
240
+ self.router_aux_loss_coef = config.router_aux_loss_coef
241
+ self.num_experts = config.num_local_experts
242
+ self.num_experts_per_tok = config.num_experts_per_tok
243
+ # Initialize weights and apply final processing
244
+ self.post_init()
245
+
246
+ def get_model(self):
247
+ return self.model
248
+
249
+ def forward(
250
+ self,
251
+ input_ids: torch.LongTensor = None,
252
+ attention_mask: Optional[torch.Tensor] = None,
253
+ position_ids: Optional[torch.LongTensor] = None,
254
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
255
+ inputs_embeds: Optional[torch.FloatTensor] = None,
256
+ labels: Optional[torch.LongTensor] = None,
257
+ use_cache: Optional[bool] = None,
258
+ output_attentions: Optional[bool] = None,
259
+ output_hidden_states: Optional[bool] = None,
260
+ images: Optional[torch.FloatTensor] = None,
261
+ audios: Optional[dict] = None,
262
+ sf_masks: Optional[torch.Tensor] = None,
263
+ output_router_logits: Optional[bool] = None,
264
+ return_dict: Optional[bool] = None,
265
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
266
+ if inputs_embeds is None:
267
+ (
268
+ input_ids,
269
+ position_ids,
270
+ attention_mask,
271
+ past_key_values,
272
+ inputs_embeds,
273
+ labels,
274
+ ) = self.prepare_inputs_labels_for_multimodal(
275
+ input_ids, position_ids, attention_mask, past_key_values, labels, images, audios, sf_masks
276
+ )
277
+
278
+ return super().forward(
279
+ input_ids=input_ids,
280
+ attention_mask=attention_mask,
281
+ position_ids=position_ids,
282
+ past_key_values=past_key_values,
283
+ inputs_embeds=inputs_embeds,
284
+ labels=labels,
285
+ use_cache=use_cache,
286
+ output_attentions=output_attentions,
287
+ output_hidden_states=output_hidden_states,
288
+ output_router_logits=output_router_logits,
289
+ return_dict=return_dict,
290
+ )
291
+
292
+ def prepare_inputs_for_generation_original(
293
+ self,
294
+ input_ids,
295
+ past_key_values=None,
296
+ attention_mask=None,
297
+ inputs_embeds=None,
298
+ output_router_logits=False,
299
+ **kwargs,
300
+ ):
301
+ # Omit tokens covered by past_key_values
302
+ if past_key_values is not None:
303
+ if isinstance(past_key_values, Cache):
304
+ cache_length = past_key_values.get_seq_length()
305
+ past_length = past_key_values.seen_tokens
306
+ max_cache_length = past_key_values.get_max_length()
307
+ else:
308
+ cache_length = past_length = past_key_values[0][0].shape[2]
309
+ max_cache_length = None
310
+
311
+ # Keep only the unprocessed tokens:
312
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
313
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
314
+ # input)
315
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
316
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
317
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
318
+ # input_ids based on the past_length.
319
+ elif past_length < input_ids.shape[1]:
320
+ input_ids = input_ids[:, past_length:]
321
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
322
+ else:
323
+ remove_prefix_length = input_ids.shape[1] - 1
324
+ input_ids = input_ids[:, remove_prefix_length:]
325
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
326
+ if (
327
+ max_cache_length is not None
328
+ and attention_mask is not None
329
+ and cache_length + input_ids.shape[1] > max_cache_length
330
+ ):
331
+ attention_mask = attention_mask[:, -max_cache_length:]
332
+
333
+ position_ids = kwargs.get("position_ids", None)
334
+ if attention_mask is not None and position_ids is None:
335
+ # create position_ids on the fly for batch generation
336
+ position_ids = attention_mask.long().cumsum(-1) - 1
337
+ position_ids.masked_fill_(attention_mask == 0, 1)
338
+ if past_key_values:
339
+ position_ids = position_ids[:, -input_ids.shape[1] :]
340
+
341
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
342
+ if inputs_embeds is not None and past_key_values is None:
343
+ model_inputs = {"inputs_embeds": inputs_embeds}
344
+ else:
345
+ model_inputs = {"input_ids": input_ids}
346
+
347
+ model_inputs.update(
348
+ {
349
+ "position_ids": position_ids,
350
+ "past_key_values": past_key_values,
351
+ "use_cache": kwargs.get("use_cache"),
352
+ "attention_mask": attention_mask,
353
+ "output_router_logits": output_router_logits,
354
+ }
355
+ )
356
+ return model_inputs
357
+
358
+ def prepare_inputs_for_generation(
359
+ self,
360
+ input_ids,
361
+ past_key_values=None,
362
+ inputs_embeds=None,
363
+ attention_mask=None,
364
+ output_router_logits=False,
365
+ **kwargs,
366
+ ):
367
+ images = kwargs.pop("images", None)
368
+ audios = kwargs.pop("audios", None)
369
+
370
+ _inputs = self.prepare_inputs_for_generation_original(
371
+ input_ids,
372
+ past_key_values=past_key_values,
373
+ inputs_embeds=inputs_embeds,
374
+ attention_mask=attention_mask,
375
+ output_router_logits=output_router_logits,
376
+ **kwargs,
377
+ )
378
+
379
+ if images is not None:
380
+ _inputs["images"] = images
381
+ if audios is not None:
382
+ _inputs["audios"] = audios
383
+ return _inputs
384
+
385
+ def expand2square(self, pil_img, background_color):
386
+ width, height = pil_img.size
387
+ if width == height:
388
+ return pil_img
389
+ elif width > height:
390
+ result = Image.new(pil_img.mode, (width, width), background_color)
391
+ result.paste(pil_img, (0, (width - height) // 2))
392
+ return result
393
+ else:
394
+ result = Image.new(pil_img.mode, (height, height), background_color)
395
+ result.paste(pil_img, ((height - width) // 2, 0))
396
+ return result
397
+
398
+ def process_images(self, images, model_cfg):
399
+ vision_tower = self.get_vision_tower()
400
+ if not vision_tower.is_loaded:
401
+ vision_tower.load_model()
402
+ image_processor = vision_tower.image_processor
403
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
404
+ new_images = []
405
+ if image_aspect_ratio == "pad":
406
+ for image in images:
407
+ image = self.expand2square(
408
+ image, tuple(int(x * 255) for x in image_processor.image_mean)
409
+ )
410
+ image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
411
+ new_images.append(image)
412
+ else:
413
+ return image_processor(images, return_tensors="pt")["pixel_values"]
414
+ if all(x.shape == new_images[0].shape for x in new_images):
415
+ new_images = torch.stack(new_images, dim=0)
416
+ return new_images
417
+
418
+
419
+ AutoConfig.register("vita-mixtral", VITAMixtralConfig)
420
+ AutoModelForCausalLM.register(VITAMixtralConfig, VITAMixtralForCausalLM)
vita/model/language_model/vita_nemo.py ADDED
@@ -0,0 +1,282 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from PIL import Image
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers import (
8
+ AutoConfig,
9
+ AutoModelForCausalLM,
10
+ MistralConfig,
11
+ MistralForCausalLM,
12
+ MistralModel,
13
+ )
14
+ from transformers.cache_utils import Cache, DynamicCache
15
+ from transformers.modeling_outputs import CausalLMOutputWithPast, MoeCausalLMOutputWithPast
16
+ from transformers.generation.utils import GenerateOutput
17
+
18
+ from ..vita_arch import VITAMetaForCausalLM, VITAMetaModel
19
+
20
+
21
+ def custom_forward(
22
+ self,
23
+ input_ids: torch.LongTensor = None,
24
+ attention_mask: Optional[torch.Tensor] = None,
25
+ position_ids: Optional[torch.LongTensor] = None,
26
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
27
+ inputs_embeds: Optional[torch.FloatTensor] = None,
28
+ labels: Optional[torch.LongTensor] = None,
29
+ use_cache: Optional[bool] = None,
30
+ output_attentions: Optional[bool] = None,
31
+ output_hidden_states: Optional[bool] = None,
32
+ return_dict: Optional[bool] = None,
33
+ cache_position: Optional[torch.LongTensor] = None,
34
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
35
+ r"""
36
+ Args:
37
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
38
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
39
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
40
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
41
+
42
+ Returns:
43
+
44
+ Example:
45
+
46
+ ```python
47
+ >>> from transformers import AutoTokenizer, MistralForCausalLM
48
+
49
+ >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
50
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
51
+
52
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
53
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
54
+
55
+ >>> # Generate
56
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
57
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
58
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
59
+ ```"""
60
+
61
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
62
+ output_hidden_states = (
63
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
64
+ )
65
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
66
+
67
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
68
+ outputs = self.model(
69
+ input_ids=input_ids,
70
+ attention_mask=attention_mask,
71
+ position_ids=position_ids,
72
+ past_key_values=past_key_values,
73
+ inputs_embeds=inputs_embeds,
74
+ use_cache=use_cache,
75
+ output_attentions=output_attentions,
76
+ output_hidden_states=output_hidden_states,
77
+ return_dict=return_dict,
78
+ cache_position=cache_position,
79
+ )
80
+
81
+ hidden_states = outputs[0]
82
+ logits = self.lm_head(hidden_states)
83
+ # logits = logits.float()
84
+
85
+ loss = None
86
+ if labels is not None:
87
+ # Shift so that tokens < n predict n
88
+ shift_logits = logits[..., :-1, :].contiguous()
89
+ shift_labels = labels[..., 1:].contiguous()
90
+ # Flatten the tokens
91
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
92
+ shift_labels = shift_labels.view(-1)
93
+ # Ensure tensors are on the same device
94
+ shift_labels = shift_labels.to(shift_logits.device)
95
+ loss_fct = CrossEntropyLoss()
96
+ loss = loss_fct(shift_logits, shift_labels)
97
+
98
+ if not return_dict:
99
+ output = (logits,) + outputs[1:]
100
+ return (loss,) + output if loss is not None else output
101
+
102
+ return CausalLMOutputWithPast(
103
+ loss=loss,
104
+ logits=logits,
105
+ past_key_values=outputs.past_key_values,
106
+ hidden_states=outputs.hidden_states,
107
+ attentions=outputs.attentions,
108
+ )
109
+
110
+ MistralForCausalLM.forward = custom_forward
111
+
112
+
113
+ class VITAMistralConfig(MistralConfig):
114
+ model_type = "vita-Mistral"
115
+
116
+
117
+ class VITAMistralModel(VITAMetaModel, MistralModel):
118
+ config_class = VITAMistralConfig
119
+
120
+ def __init__(self, config: MistralConfig):
121
+ super(VITAMistralModel, self).__init__(config)
122
+
123
+
124
+ class VITAMistralForCausalLM(MistralForCausalLM, VITAMetaForCausalLM):
125
+ config_class = VITAMistralConfig
126
+
127
+ def __init__(self, config):
128
+ super(MistralForCausalLM, self).__init__(config)
129
+ self.model = VITAMistralModel(config)
130
+ self.vocab_size = config.vocab_size
131
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
132
+
133
+ # Initialize weights and apply final processing
134
+ self.post_init()
135
+
136
+ def get_model(self):
137
+ return self.model
138
+
139
+ def forward(
140
+ self,
141
+ input_ids: torch.LongTensor = None,
142
+ attention_mask: Optional[torch.Tensor] = None,
143
+ position_ids: Optional[torch.LongTensor] = None,
144
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
145
+ inputs_embeds: Optional[torch.FloatTensor] = None,
146
+ labels: Optional[torch.LongTensor] = None,
147
+ use_cache: Optional[bool] = None,
148
+ output_attentions: Optional[bool] = None,
149
+ output_hidden_states: Optional[bool] = None,
150
+ images: Optional[torch.FloatTensor] = None,
151
+ audios: Optional[dict] = None,
152
+ return_dict: Optional[bool] = None,
153
+ cache_position: Optional[torch.LongTensor] = None,
154
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
155
+ if inputs_embeds is None:
156
+ (
157
+ input_ids,
158
+ position_ids,
159
+ attention_mask,
160
+ past_key_values,
161
+ inputs_embeds,
162
+ labels,
163
+ ) = self.prepare_inputs_labels_for_multimodal(
164
+ input_ids, position_ids, attention_mask, past_key_values, labels, images, audios
165
+ )
166
+
167
+ return super().forward(
168
+ input_ids=input_ids,
169
+ attention_mask=attention_mask,
170
+ position_ids=position_ids,
171
+ past_key_values=past_key_values,
172
+ inputs_embeds=inputs_embeds,
173
+ labels=labels,
174
+ use_cache=use_cache,
175
+ output_attentions=output_attentions,
176
+ output_hidden_states=output_hidden_states,
177
+ return_dict=return_dict,
178
+ cache_position=cache_position,
179
+ )
180
+
181
+ @torch.no_grad()
182
+ def generate(
183
+ self,
184
+ inputs: Optional[torch.Tensor] = None,
185
+ images: Optional[torch.Tensor] = None,
186
+ audios: Optional[torch.Tensor] = None,
187
+ **kwargs,
188
+ ) -> Union[GenerateOutput, torch.LongTensor]:
189
+ position_ids = kwargs.pop("position_ids", None)
190
+ attention_mask = kwargs.pop("attention_mask", None)
191
+ if "inputs_embeds" in kwargs:
192
+ raise NotImplementedError("`inputs_embeds` is not supported")
193
+
194
+ if images is not None or audios is not None:
195
+ (
196
+ inputs,
197
+ position_ids,
198
+ attention_mask,
199
+ _,
200
+ inputs_embeds,
201
+ _
202
+ ) = self.prepare_inputs_labels_for_multimodal(
203
+ inputs,
204
+ position_ids,
205
+ attention_mask,
206
+ None,
207
+ None,
208
+ images,
209
+ audios
210
+ )
211
+ else:
212
+ inputs_embeds = self.get_model().embed_tokens(inputs)
213
+
214
+ return super().generate(
215
+ position_ids=position_ids,
216
+ attention_mask=attention_mask,
217
+ inputs_embeds=inputs_embeds,
218
+ **kwargs
219
+ )
220
+
221
+ def prepare_inputs_for_generation(
222
+ self,
223
+ input_ids,
224
+ past_key_values=None,
225
+ inputs_embeds=None,
226
+ attention_mask=None,
227
+ **kwargs,
228
+ ):
229
+ images = kwargs.pop("images", None)
230
+ audios = kwargs.pop("audios", None)
231
+
232
+ _inputs = super().prepare_inputs_for_generation(
233
+ input_ids,
234
+ past_key_values=past_key_values,
235
+ inputs_embeds=inputs_embeds,
236
+ attention_mask=attention_mask,
237
+ **kwargs,
238
+ )
239
+
240
+ if images is not None:
241
+ _inputs["images"] = images
242
+ if audios is not None:
243
+ _inputs["audios"] = audios
244
+ return _inputs
245
+
246
+ def expand2square(self, pil_img, background_color):
247
+ width, height = pil_img.size
248
+ if width == height:
249
+ return pil_img
250
+ elif width > height:
251
+ result = Image.new(pil_img.mode, (width, width), background_color)
252
+ result.paste(pil_img, (0, (width - height) // 2))
253
+ return result
254
+ else:
255
+ result = Image.new(pil_img.mode, (height, height), background_color)
256
+ result.paste(pil_img, ((height - width) // 2, 0))
257
+ return result
258
+
259
+ def process_images(self, images, model_cfg):
260
+ vision_tower = self.get_vision_tower()
261
+ if not vision_tower.is_loaded:
262
+ vision_tower.load_model()
263
+ image_processor = vision_tower.image_processor
264
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
265
+ new_images = []
266
+ if image_aspect_ratio == "pad":
267
+ for image in images:
268
+ image = self.expand2square(
269
+ image, tuple(int(x * 255) for x in image_processor.image_mean)
270
+ )
271
+ image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
272
+ new_images.append(image)
273
+ else:
274
+ return image_processor(images, return_tensors="pt")["pixel_values"]
275
+ if all(x.shape == new_images[0].shape for x in new_images):
276
+ new_images = torch.stack(new_images, dim=0)
277
+ return new_images
278
+
279
+
280
+ AutoConfig.register("vita-Mistral", VITAMistralConfig)
281
+ AutoModelForCausalLM.register(VITAMistralConfig, VITAMistralForCausalLM)
282
+
vita/model/language_model/vita_qwen2.py ADDED
@@ -0,0 +1,304 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from PIL import Image
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers import (
8
+ AutoConfig,
9
+ AutoModelForCausalLM,
10
+ Qwen2Config,
11
+ Qwen2ForCausalLM,
12
+ Qwen2Model,
13
+ )
14
+ from transformers.cache_utils import Cache, DynamicCache
15
+ from transformers.modeling_outputs import CausalLMOutputWithPast, MoeCausalLMOutputWithPast
16
+ from transformers.generation.utils import GenerateOutput
17
+
18
+ from ..vita_arch import VITAMetaForCausalLM, VITAMetaModel
19
+
20
+
21
+ def custom_forward(
22
+ self,
23
+ input_ids: torch.LongTensor = None,
24
+ attention_mask: Optional[torch.Tensor] = None,
25
+ position_ids: Optional[torch.LongTensor] = None,
26
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
27
+ inputs_embeds: Optional[torch.FloatTensor] = None,
28
+ labels: Optional[torch.LongTensor] = None,
29
+ use_cache: Optional[bool] = None,
30
+ output_attentions: Optional[bool] = None,
31
+ output_hidden_states: Optional[bool] = None,
32
+ return_dict: Optional[bool] = None,
33
+ cache_position: Optional[torch.LongTensor] = None,
34
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
35
+ r"""
36
+ Args:
37
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
38
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
39
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
40
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
41
+
42
+ Returns:
43
+
44
+ Example:
45
+
46
+ ```python
47
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
48
+
49
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
50
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
51
+
52
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
53
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
54
+
55
+ >>> # Generate
56
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
57
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
58
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
59
+ ```"""
60
+
61
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
62
+ output_hidden_states = (
63
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
64
+ )
65
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
66
+
67
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
68
+ outputs = self.model(
69
+ input_ids=input_ids,
70
+ attention_mask=attention_mask,
71
+ position_ids=position_ids,
72
+ past_key_values=past_key_values,
73
+ inputs_embeds=inputs_embeds,
74
+ use_cache=use_cache,
75
+ output_attentions=output_attentions,
76
+ output_hidden_states=output_hidden_states,
77
+ return_dict=return_dict,
78
+ cache_position=cache_position,
79
+ )
80
+
81
+ hidden_states = outputs[0]
82
+ logits = self.lm_head(hidden_states)
83
+ # logits = logits.float()
84
+
85
+ loss = None
86
+ if labels is not None:
87
+ # Shift so that tokens < n predict n
88
+ shift_logits = logits[..., :-1, :].contiguous()
89
+ shift_labels = labels[..., 1:].contiguous()
90
+ # Flatten the tokens
91
+ loss_fct = CrossEntropyLoss()
92
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
93
+ shift_labels = shift_labels.view(-1)
94
+ # Enable model parallelism
95
+ shift_labels = shift_labels.to(shift_logits.device)
96
+ loss = loss_fct(shift_logits, shift_labels)
97
+
98
+ if not return_dict:
99
+ output = (logits,) + outputs[1:]
100
+ return (loss,) + output if loss is not None else output
101
+
102
+ #import pdb; pdb.set_trace()
103
+ return CausalLMOutputWithPast(
104
+ loss=loss,
105
+ logits=logits,
106
+ past_key_values=outputs.past_key_values,
107
+ hidden_states=outputs.hidden_states,
108
+ attentions=outputs.attentions,
109
+ )
110
+
111
+
112
+ Qwen2ForCausalLM.forward = custom_forward
113
+
114
+
115
+ class VITAQwen2Config(Qwen2Config):
116
+ model_type = "vita-Qwen2"
117
+
118
+
119
+ class VITAQwen2Model(VITAMetaModel, Qwen2Model):
120
+ config_class = VITAQwen2Config
121
+
122
+ def __init__(self, config: Qwen2Config):
123
+ super(VITAQwen2Model, self).__init__(config)
124
+
125
+
126
+ class VITAQwen2ForCausalLM(Qwen2ForCausalLM, VITAMetaForCausalLM):
127
+ config_class = VITAQwen2Config
128
+
129
+ def __init__(self, config):
130
+ super(Qwen2ForCausalLM, self).__init__(config)
131
+ self.model = VITAQwen2Model(config)
132
+ self.vocab_size = config.vocab_size
133
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
134
+
135
+ # Initialize weights and apply final processing
136
+ self.post_init()
137
+
138
+ def get_model(self):
139
+ return self.model
140
+
141
+ def forward(
142
+ self,
143
+ input_ids: torch.LongTensor = None,
144
+ attention_mask: Optional[torch.Tensor] = None,
145
+ position_ids: Optional[torch.LongTensor] = None,
146
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
147
+ inputs_embeds: Optional[torch.FloatTensor] = None,
148
+ labels: Optional[torch.LongTensor] = None,
149
+ use_cache: Optional[bool] = None,
150
+ output_attentions: Optional[bool] = None,
151
+ output_hidden_states: Optional[bool] = None,
152
+ images: Optional[torch.FloatTensor] = None,
153
+ audios: Optional[dict] = None,
154
+ sf_masks: Optional[torch.Tensor] = None,
155
+ return_dict: Optional[bool] = None,
156
+ cache_position: Optional[torch.LongTensor] = None,
157
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
158
+ if inputs_embeds is None:
159
+ (
160
+ input_ids,
161
+ position_ids,
162
+ attention_mask,
163
+ past_key_values,
164
+ inputs_embeds,
165
+ labels,
166
+ ) = self.prepare_inputs_labels_for_multimodal(
167
+ input_ids, position_ids, attention_mask, past_key_values, labels, images, audios, sf_masks
168
+ )
169
+
170
+ return super().forward(
171
+ input_ids=input_ids,
172
+ attention_mask=attention_mask,
173
+ position_ids=position_ids,
174
+ past_key_values=past_key_values,
175
+ inputs_embeds=inputs_embeds,
176
+ labels=labels,
177
+ use_cache=use_cache,
178
+ output_attentions=output_attentions,
179
+ output_hidden_states=output_hidden_states,
180
+ return_dict=return_dict,
181
+ cache_position=cache_position,
182
+ )
183
+
184
+ @torch.no_grad()
185
+ def generate(
186
+ self,
187
+ inputs: Optional[torch.Tensor] = None,
188
+ images: Optional[torch.Tensor] = None,
189
+ audios: Optional[torch.Tensor] = None,
190
+ sf_masks: Optional[torch.Tensor] = None,
191
+ shared_v_pid_stride: Optional[int] = None,
192
+ **kwargs,
193
+ ) -> Union[GenerateOutput, torch.LongTensor]:
194
+ position_ids = kwargs.pop("position_ids", None)
195
+ attention_mask = kwargs.pop("attention_mask", None)
196
+ if "inputs_embeds" in kwargs:
197
+ raise NotImplementedError("`inputs_embeds` is not supported")
198
+
199
+ if images is not None or audios is not None:
200
+ (
201
+ inputs,
202
+ position_ids,
203
+ attention_mask,
204
+ _,
205
+ inputs_embeds,
206
+ _
207
+ ) = self.prepare_inputs_labels_for_multimodal(
208
+ inputs,
209
+ position_ids,
210
+ attention_mask,
211
+ None,
212
+ None,
213
+ images,
214
+ audios,
215
+ sf_masks,
216
+ shared_v_pid_stride,
217
+ )
218
+ else:
219
+ inputs_embeds = self.get_model().embed_tokens(inputs)
220
+
221
+ return super().generate(
222
+ position_ids=position_ids,
223
+ attention_mask=attention_mask,
224
+ inputs_embeds=inputs_embeds,
225
+ **kwargs
226
+ )
227
+
228
+ def prepare_inputs_for_generation(
229
+ self,
230
+ input_ids,
231
+ past_key_values=None,
232
+ inputs_embeds=None,
233
+ attention_mask=None,
234
+ **kwargs,
235
+ ):
236
+ images = kwargs.pop("images", None)
237
+ audios = kwargs.pop("audios", None)
238
+ sf_masks = kwargs.pop("sf_masks", None)
239
+
240
+ _inputs = super().prepare_inputs_for_generation(
241
+ input_ids,
242
+ past_key_values=past_key_values,
243
+ inputs_embeds=inputs_embeds,
244
+ attention_mask=attention_mask,
245
+ **kwargs,
246
+ )
247
+
248
+ # import pdb; pdb.set_trace()
249
+ position_ids = _inputs["position_ids"]
250
+ cache_position = _inputs["cache_position"]
251
+ if cache_position.shape[-1] == 1 and position_ids.shape[-1] > 1:
252
+ new_position_ids = torch.zeros((position_ids.shape[0],1), dtype=position_ids.dtype, device=position_ids.device)
253
+ new_position_ids[:, 0] = position_ids[0,-1] + cache_position[-1] + 1 - position_ids.shape[-1]
254
+ position_ids = new_position_ids
255
+ _inputs["position_ids"] = position_ids
256
+ # import pdb; pdb.set_trace()
257
+
258
+ if images is not None:
259
+ _inputs["images"] = images
260
+ if audios is not None:
261
+ _inputs["audios"] = audios
262
+ if sf_masks is not None:
263
+ _inputs["sf_masks"] = sf_masks
264
+ return _inputs
265
+
266
+ def expand2square(self, pil_img, background_color):
267
+ width, height = pil_img.size
268
+ if width == height:
269
+ return pil_img
270
+ elif width > height:
271
+ result = Image.new(pil_img.mode, (width, width), background_color)
272
+ result.paste(pil_img, (0, (width - height) // 2))
273
+ return result
274
+ else:
275
+ result = Image.new(pil_img.mode, (height, height), background_color)
276
+ result.paste(pil_img, ((height - width) // 2, 0))
277
+ return result
278
+
279
+ def process_images(self, images, model_cfg):
280
+ vision_tower = self.get_vision_tower()
281
+ if not vision_tower.is_loaded:
282
+ vision_tower.load_model()
283
+ image_processor = vision_tower.image_processor
284
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
285
+ new_images = []
286
+ if image_aspect_ratio == "pad":
287
+ for image in images:
288
+ image = self.expand2square(
289
+ image, tuple(int(x * 255) for x in image_processor.image_mean)
290
+ )
291
+ image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
292
+ new_images.append(image)
293
+ else:
294
+ return image_processor(images, return_tensors="pt")["pixel_values"]
295
+ if all(x.shape == new_images[0].shape for x in new_images):
296
+ new_images = torch.stack(new_images, dim=0)
297
+ return new_images
298
+
299
+
300
+ AutoConfig.register("vita-Qwen2", VITAQwen2Config)
301
+ AutoModelForCausalLM.register(VITAQwen2Config, VITAQwen2ForCausalLM)
302
+
303
+
304
+
vita/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,83 @@
1
+ import os
2
+
3
+ import yaml
4
+ import torch
5
+ from transformers.utils.hub import get_file_from_repo
6
+
7
+ from .clip.clip_encoder import CLIPVisionTower
8
+ from .eva_clip.eva_clip_encoder import EvaClipVisionTower
9
+ from .internvit.internvit_encoder import InternViTVisionTower
10
+ from .siglip.siglip_encoder import SiglipVisionTower, SiglipVisionTowerS2
11
+ from .whale.init_model import init_model
12
+
13
+
14
+ def build_vision_tower(vision_tower_cfg, **kwargs):
15
+ vision_tower = getattr(
16
+ vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)
17
+ )
18
+ use_s2 = getattr(vision_tower_cfg, "use_s2", False)
19
+
20
+ if "sig" in vision_tower.lower():
21
+ if use_s2:
22
+ return SiglipVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
23
+ else:
24
+ return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
25
+ elif "eva" in vision_tower.lower():
26
+ if use_s2:
27
+ raise ValueError("S2 is currently not supported for EVA-CLIP")
28
+ else:
29
+ return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
30
+
31
+ elif "clip" in vision_tower.lower():
32
+ if use_s2:
33
+ raise ValueError("S2 is currently not supported for CLIP")
34
+ else:
35
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
36
+ elif "internvit" in vision_tower.lower():
37
+ if use_s2:
38
+ raise ValueError("S2 is currently not supported for InternViT")
39
+ else:
40
+ return InternViTVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
41
+
42
+ else:
43
+ raise ValueError(f"Unknown vision tower: {vision_tower}")
44
+
45
+
46
+ def build_audio_encoder(audio_encoder_config, **kwargs):
47
+ with open(get_file_from_repo(audio_encoder_config.mm_audio_encoder, "train.yaml"), "r") as fin:
48
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
49
+
50
+ configs["cmvn_file"] = get_file_from_repo(audio_encoder_config.mm_audio_encoder, "global_cmvn")
51
+
52
+ configs["model_conf"]["freeze_encoder"] = getattr(
53
+ audio_encoder_config, "freeze_audio_encoder", True
54
+ )
55
+ configs["model_conf"]["freeze_adpter"] = getattr(
56
+ audio_encoder_config, "freeze_audio_encoder_adapter", True
57
+ )
58
+ configs["model_conf"]["audio_prompt_finetune"] = getattr(
59
+ audio_encoder_config, "audio_prompt_finetune", False
60
+ )
61
+ configs["model_conf"]["audio_prompt_num"] = getattr(
62
+ audio_encoder_config, "audio_prompt_num", 0
63
+ )
64
+
65
+ audio_encoder = init_model(configs)
66
+
67
+ checkpoint = torch.load(get_file_from_repo(audio_encoder_config.mm_audio_encoder, "final.pt"), map_location="cpu")
68
+ model_dict = audio_encoder.state_dict()
69
+ for key in model_dict.keys():
70
+ if key in checkpoint.keys():
71
+ if model_dict[key].shape == checkpoint[key].shape:
72
+ model_dict[key] = checkpoint[key]
73
+ else:
74
+ print(
75
+ "Key {} has different shape, {} VS {}".format(
76
+ key, model_dict[key].shape, checkpoint[key].shape
77
+ )
78
+ )
79
+ else:
80
+ print("Key {} is not in the resume model".format(key))
81
+ audio_encoder.load_state_dict(model_dict)
82
+
83
+ return audio_encoder
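`build_audio_encoder` loads the `final.pt` weights with a shape-checked partial copy: only keys present in both state dicts with matching shapes are taken from the checkpoint, everything else keeps its initialization. Extracted as a standalone helper (names here are illustrative, not part of the code above), the pattern is:

import torch

def load_matching_weights(model, ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model_dict = model.state_dict()
    for key, tensor in model_dict.items():
        if key in checkpoint and checkpoint[key].shape == tensor.shape:
            model_dict[key] = checkpoint[key]  # take the checkpoint value
        else:
            print(f"keeping initialized weights for {key}")
    model.load_state_dict(model_dict)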
vita/model/multimodal_encoder/clip/clip_encoder.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
4
+
5
+
6
+ class CLIPVisionTower(nn.Module):
7
+ def __init__(self, vision_tower, args, delay_load=False):
8
+ super().__init__()
9
+
10
+ self.is_loaded = False
11
+
12
+ self.vision_tower_name = vision_tower
13
+ self.select_layer = -2
14
+
15
+ if not delay_load:
16
+ self.load_model()
17
+ else:
18
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
19
+
20
+ def load_model(self):
21
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
22
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
23
+ self.vision_tower.requires_grad_(False)
24
+
25
+ self.is_loaded = True
26
+
27
+ def feature_select(self, image_forward_outs):
28
+ image_features = image_forward_outs.hidden_states[self.select_layer]
29
+
30
+ image_features = image_features[:, 1:]
31
+
32
+ return image_features
33
+
34
+ @torch.no_grad()
35
+ def forward(self, images):
36
+ if type(images) is list:
37
+ image_features = []
38
+ for image in images:
39
+ image_forward_out = self.vision_tower(
40
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
41
+ output_hidden_states=True,
42
+ )
43
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
44
+ image_features.append(image_feature)
45
+ else:
46
+ image_forward_outs = self.vision_tower(
47
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
48
+ )
49
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
50
+
51
+ return image_features
52
+
53
+ @property
54
+ def dummy_feature(self):
55
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
56
+
57
+ @property
58
+ def dtype(self):
59
+ return self.vision_tower.dtype
60
+
61
+ @property
62
+ def device(self):
63
+ return self.vision_tower.device
64
+
65
+ @property
66
+ def config(self):
67
+ if self.is_loaded:
68
+ return self.vision_tower.config
69
+ else:
70
+ return self.cfg_only
71
+
72
+ @property
73
+ def hidden_size(self):
74
+ return self.config.hidden_size
75
+
76
+ @property
77
+ def num_patches(self):
78
+ return (self.config.image_size // self.config.patch_size) ** 2
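For reference, the `num_patches` property is plain patch-grid arithmetic; for a ViT-L/14 tower at 336 px (a common configuration, assumed here for illustration):

image_size, patch_size = 336, 14
num_patches = (image_size // patch_size) ** 2  # 24 * 24 = 576 patch tokens per image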
vita/model/multimodal_encoder/eva_clip/eva_clip_encoder.py ADDED
@@ -0,0 +1,66 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .eva_clip_processors import EvaClipImageTrainProcessor
5
+ from .eva_vit import Eva2LargePlusEncoder
6
+
7
+
8
+ class EvaClipVisionTower(nn.Module):
9
+ def __init__(self, vision_tower, args, delay_load=False):
10
+ super().__init__()
11
+
12
+ self.is_loaded = False
13
+
14
+ self.vision_tower_path = vision_tower
15
+ self.config = VisionTowerConfig()
16
+
17
+ if not delay_load:
18
+ self.load_model()
19
+ else:
20
+ self.cfg_only = self.config
21
+
22
+ def load_model(self):
23
+ self.image_processor = EvaClipImageTrainProcessor(self.config.image_size)
24
+ self.vision_tower = Eva2LargePlusEncoder(self.vision_tower_path)
25
+ self.vision_tower.requires_grad_(False)
26
+
27
+ self.is_loaded = True
28
+
29
+ @torch.no_grad()
30
+ def forward(self, images):
31
+ if type(images) is list:
32
+ image_features = []
33
+ for image in images:
34
+ image_feature = self.vision_tower(
35
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0)
36
+ ).to(image.dtype)
37
+ image_features.append(image_feature)
38
+ else:
39
+ image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(
40
+ images.dtype
41
+ )
42
+
43
+ return image_features
44
+
45
+ @property
46
+ def dtype(self):
47
+ return self.vision_tower.dtype
48
+
49
+ @property
50
+ def device(self):
51
+ return self.vision_tower.device
52
+
53
+ @property
54
+ def hidden_size(self):
55
+ return self.config.hidden_size
56
+
57
+ @property
58
+ def num_patches(self):
59
+ return (self.config.image_size // self.config.patch_size) ** 2
60
+
61
+
62
+ class VisionTowerConfig:
63
+ def __init__(self):
64
+ self.image_size = 336
65
+ self.patch_size = 14
66
+ self.hidden_size = 1024
vita/model/multimodal_encoder/eva_clip/eva_clip_processors.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
3
+ """
4
+
5
+ from PIL import Image
6
+ from transformers.image_processing_utils import BatchFeature
7
+ from transformers.image_transforms import convert_to_rgb
8
+
9
+ from torchvision import transforms
10
+ from torchvision.transforms.functional import InterpolationMode
11
+
12
+
13
+ class BaseProcessor:
14
+ def __init__(self):
15
+ self.transform = lambda x: x
16
+ return
17
+
18
+ def __call__(self, item):
19
+ return self.transform(item)
20
+
21
+
22
+ class EvaClipImageBaseProcessor(BaseProcessor):
23
+ def __init__(self, mean=None, std=None):
24
+ self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
25
+ self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
26
+
27
+ self.normalize = transforms.Normalize(self.mean, self.std)
28
+
29
+ @property
30
+ def image_mean(self):
31
+ return self.mean
32
+
33
+
34
+ class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
35
+ def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
36
+ super().__init__(mean=mean, std=std)
37
+
38
+ self.transform = transforms.Compose(
39
+ [
40
+ convert_to_rgb,
41
+ transforms.Resize(
42
+ image_size,
43
+ interpolation=InterpolationMode.BICUBIC,
44
+ ),
45
+ transforms.CenterCrop(image_size),
46
+ transforms.ToTensor(),
47
+ self.normalize,
48
+ ]
49
+ )
50
+
51
+ self.image_size = image_size
52
+
53
+ def preprocess(self, images, return_tensors):
54
+ if isinstance(images, Image.Image):
55
+ images = [images]
56
+ else:
57
+ assert isinstance(images, list)
58
+
59
+ transformed_images = [self.transform(image).numpy() for image in images]
60
+ data = {"pixel_values": transformed_images}
61
+
62
+ return BatchFeature(data=data, tensor_type=return_tensors)
63
+
64
+ def __call__(self, item):
65
+ return self.transform(item)
66
+
67
+ @property
68
+ def crop_size(self):
69
+ return {"height": self.image_size, "width": self.image_size}
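A small usage sketch of the processor above, assuming it is importable from this module; the input image is synthetic.

from PIL import Image

proc = EvaClipImageTrainProcessor(image_size=336)
batch = proc.preprocess([Image.new("RGB", (500, 400))], return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 336, 336])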
vita/model/multimodal_encoder/eva_clip/eva_vit.py ADDED
@@ -0,0 +1,982 @@
1
+ """
2
+ # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
3
+ """
4
+
5
+ import logging
6
+
7
+ # --------------------------------------------------------
8
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
9
+ # --------------------------------------------------------
10
+ import math
11
+ import os
12
+ from dataclasses import dataclass
13
+ from functools import partial
14
+ from math import pi
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from einops import rearrange, repeat
21
+ from torch import nn as nn
22
+
23
+ import xformers.ops as xops
24
+
25
+
26
+ def broadcat(tensors, dim=-1):
27
+ num_tensors = len(tensors)
28
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
29
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
30
+ shape_len = list(shape_lens)[0]
31
+ dim = (dim + shape_len) if dim < 0 else dim
32
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
33
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
34
+ assert all(
35
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
36
+ ), "invalid dimensions for broadcastable concatentation"
37
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
38
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
39
+ expanded_dims.insert(dim, (dim, dims[dim]))
40
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
41
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
42
+ return torch.cat(tensors, dim=dim)
43
+
44
+
45
+ def rotate_half(x):
46
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
47
+ x1, x2 = x.unbind(dim=-1)
48
+ x = torch.stack((-x2, x1), dim=-1)
49
+ return rearrange(x, "... d r -> ... (d r)")
50
+
51
+
52
+ class VisionRotaryEmbeddingFast(nn.Module):
53
+ def __init__(
54
+ self,
55
+ dim,
56
+ pt_seq_len,
57
+ ft_seq_len=None,
58
+ custom_freqs=None,
59
+ freqs_for="lang",
60
+ theta=10000,
61
+ max_freq=10,
62
+ num_freqs=1,
63
+ patch_dropout=0.0,
64
+ ):
65
+ super().__init__()
66
+ if custom_freqs:
67
+ freqs = custom_freqs
68
+ elif freqs_for == "lang":
69
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
70
+ elif freqs_for == "pixel":
71
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
72
+ elif freqs_for == "constant":
73
+ freqs = torch.ones(num_freqs).float()
74
+ else:
75
+ raise ValueError(f"unknown modality {freqs_for}")
76
+
77
+ if ft_seq_len is None:
78
+ ft_seq_len = pt_seq_len
79
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
80
+
81
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
82
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
83
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
84
+
85
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
86
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
87
+
88
+ self.patch_dropout = patch_dropout
89
+
90
+ self.register_buffer("freqs_cos", freqs_cos)
91
+ self.register_buffer("freqs_sin", freqs_sin)
92
+
93
+ logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
94
+
95
+ def forward(self, t, patch_indices_keep=None):
96
+ if patch_indices_keep is not None:
97
+ batch = t.size()[0]
98
+ batch_indices = torch.arange(batch)
99
+ batch_indices = batch_indices[..., None]
100
+
101
+ freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
102
+ freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
103
+
104
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
105
+ freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
106
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
107
+ freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
108
+
109
+ return t * freqs_cos + rotate_half(t) * freqs_sin
110
+
111
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
112
+
113
+
114
+ class LayerNorm(nn.LayerNorm):
115
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
116
+
117
+ def forward(self, x: torch.Tensor):
118
+ orig_type = x.dtype
119
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
120
+ return x.to(orig_type)
121
+
122
+
123
+ class PatchDropout(nn.Module):
124
+ """
125
+ https://arxiv.org/abs/2212.00794
126
+ """
127
+
128
+ def __init__(self, prob, exclude_first_token=True):
129
+ super().__init__()
130
+ assert 0 <= prob < 1.0
131
+ self.prob = prob
132
+ self.exclude_first_token = exclude_first_token # exclude CLS token
133
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
134
+
135
+ def forward(self, x):
136
+ if not self.training or self.prob == 0.0:
137
+ return x
138
+
139
+ if self.exclude_first_token:
140
+ cls_tokens, x = x[:, :1], x[:, 1:]
141
+ else:
142
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
143
+
144
+ batch = x.size()[0]
145
+ num_tokens = x.size()[1]
146
+
147
+ batch_indices = torch.arange(batch)
148
+ batch_indices = batch_indices[..., None]
149
+
150
+ keep_prob = 1 - self.prob
151
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
152
+
153
+ rand = torch.randn(batch, num_tokens)
154
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
155
+
156
+ x = x[batch_indices, patch_indices_keep]
157
+
158
+ if self.exclude_first_token:
159
+ x = torch.cat((cls_tokens, x), dim=1)
160
+
161
+ if self.training and os.getenv("RoPE") == "1":
162
+ return x, patch_indices_keep
163
+
164
+ return x
165
+
166
+
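A quick shape sketch of the `PatchDropout` module defined above, assuming the class is importable as written. In training mode with `prob=0.5`, roughly half of the non-CLS tokens are kept (here `max(1, int(16 * 0.5)) = 8`), and the CLS token is re-attached in front; with the `RoPE` environment variable unset, only the token tensor is returned.

```python
# Sketch of PatchDropout behaviour (assumes the class defined above).
import torch

drop = PatchDropout(prob=0.5, exclude_first_token=True)
drop.train()                   # dropout only applies in training mode
x = torch.randn(2, 1 + 16, 8)  # batch=2, CLS + 16 patch tokens, dim=8
out = drop(x)                  # RoPE env var unset -> tensor only
print(out.shape)               # torch.Size([2, 9, 8]): CLS + 8 kept patches
```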
167
+ try:
168
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
169
+ except:
170
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
171
+
172
+ if os.getenv("ENV_TYPE") == "deepspeed":
173
+ try:
174
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
175
+ except:
176
+ from torch.utils.checkpoint import checkpoint
177
+ else:
178
+ from torch.utils.checkpoint import checkpoint
179
+
180
+
181
+ class DropPath(nn.Module):
182
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
183
+
184
+ def __init__(self, drop_prob=None):
185
+ super(DropPath, self).__init__()
186
+ self.drop_prob = drop_prob
187
+
188
+ def forward(self, x):
189
+ return drop_path(x, self.drop_prob, self.training)
190
+
191
+ def extra_repr(self) -> str:
192
+ return "p={}".format(self.drop_prob)
193
+
194
+
195
+ class Mlp(nn.Module):
196
+ def __init__(
197
+ self,
198
+ in_features,
199
+ hidden_features=None,
200
+ out_features=None,
201
+ act_layer=nn.GELU,
202
+ norm_layer=nn.LayerNorm,
203
+ drop=0.0,
204
+ subln=False,
205
+ ):
206
+ super().__init__()
207
+ out_features = out_features or in_features
208
+ hidden_features = hidden_features or in_features
209
+ self.fc1 = nn.Linear(in_features, hidden_features)
210
+ self.act = act_layer()
211
+
212
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
213
+
214
+ self.fc2 = nn.Linear(hidden_features, out_features)
215
+ self.drop = nn.Dropout(drop)
216
+
217
+ def forward(self, x):
218
+ x = self.fc1(x)
219
+ x = self.act(x)
220
+ # x = self.drop(x)
221
+ # dropout here is commented out to match the original BERT implementation
222
+ x = self.ffn_ln(x)
223
+
224
+ x = self.fc2(x)
225
+ x = self.drop(x)
226
+ return x
227
+
228
+
229
+ class SwiGLU(nn.Module):
230
+ def __init__(
231
+ self,
232
+ in_features,
233
+ hidden_features=None,
234
+ out_features=None,
235
+ act_layer=nn.SiLU,
236
+ drop=0.0,
237
+ norm_layer=nn.LayerNorm,
238
+ subln=False,
239
+ ):
240
+ super().__init__()
241
+ out_features = out_features or in_features
242
+ hidden_features = hidden_features or in_features
243
+
244
+ self.w1 = nn.Linear(in_features, hidden_features)
245
+ self.w2 = nn.Linear(in_features, hidden_features)
246
+
247
+ self.act = act_layer()
248
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
249
+ self.w3 = nn.Linear(hidden_features, out_features)
250
+
251
+ self.drop = nn.Dropout(drop)
252
+
253
+ def forward(self, x):
254
+ x1 = self.w1(x)
255
+ x2 = self.w2(x)
256
+ hidden = self.act(x1) * x2
257
+ x = self.ffn_ln(hidden)
258
+ x = self.w3(x)
259
+ x = self.drop(x)
260
+ return x
261
+
262
+
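For reference, a minimal shape check of the `SwiGLU` block above (a sketch, assuming the class as defined): the hidden activation is `SiLU(w1 x) * (w2 x)`, optionally sub-layer-normalized, then projected back to the input width.

```python
# Shape sketch for the SwiGLU feed-forward block defined above.
import torch

mlp = SwiGLU(in_features=64, hidden_features=128)  # subln=False -> ffn_ln is Identity
x = torch.randn(2, 10, 64)                         # (batch, tokens, dim)
y = mlp(x)                                         # SiLU(w1 x) * (w2 x) -> w3
print(y.shape)                                     # torch.Size([2, 10, 64])
```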
263
+ class Attention(nn.Module):
264
+ def __init__(
265
+ self,
266
+ dim,
267
+ num_heads=8,
268
+ qkv_bias=False,
269
+ qk_scale=None,
270
+ attn_drop=0.0,
271
+ proj_drop=0.0,
272
+ window_size=None,
273
+ attn_head_dim=None,
274
+ xattn=False,
275
+ rope=None,
276
+ subln=False,
277
+ norm_layer=nn.LayerNorm,
278
+ ):
279
+ super().__init__()
280
+ self.num_heads = num_heads
281
+ head_dim = dim // num_heads
282
+ if attn_head_dim is not None:
283
+ head_dim = attn_head_dim
284
+ all_head_dim = head_dim * self.num_heads
285
+ self.scale = qk_scale or head_dim**-0.5
286
+
287
+ self.subln = subln
288
+ if self.subln:
289
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
290
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
291
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
292
+ else:
293
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
294
+
295
+ if qkv_bias:
296
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
297
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
298
+ else:
299
+ self.q_bias = None
300
+ self.v_bias = None
301
+
302
+ if window_size:
303
+ self.window_size = window_size
304
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
305
+ self.relative_position_bias_table = nn.Parameter(
306
+ torch.zeros(self.num_relative_distance, num_heads)
307
+ ) # 2*Wh-1 * 2*Ww-1, nH
308
+ # cls to token & token to cls & cls to cls
309
+
310
+ # get pair-wise relative position index for each token inside the window
311
+ coords_h = torch.arange(window_size[0])
312
+ coords_w = torch.arange(window_size[1])
313
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
314
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
315
+ relative_coords = (
316
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
317
+ ) # 2, Wh*Ww, Wh*Ww
318
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
319
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
320
+ relative_coords[:, :, 1] += window_size[1] - 1
321
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
322
+ relative_position_index = torch.zeros(
323
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
324
+ )
325
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
326
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
327
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
328
+ relative_position_index[0, 0] = self.num_relative_distance - 1
329
+
330
+ self.register_buffer("relative_position_index", relative_position_index)
331
+ else:
332
+ self.window_size = None
333
+ self.relative_position_bias_table = None
334
+ self.relative_position_index = None
335
+
336
+ self.attn_drop = nn.Dropout(attn_drop)
337
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
338
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
339
+ self.proj = nn.Linear(all_head_dim, dim)
340
+ self.proj_drop = nn.Dropout(proj_drop)
341
+ self.xattn = xattn
342
+ self.xattn_drop = attn_drop
343
+
344
+ self.rope = rope
345
+
346
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
347
+ B, N, C = x.shape
348
+ if self.subln:
349
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
350
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
351
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
352
+
353
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
354
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
355
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
356
+ else:
357
+
358
+ qkv_bias = None
359
+ if self.q_bias is not None:
360
+ qkv_bias = torch.cat(
361
+ (self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)
362
+ )
363
+
364
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
365
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
366
+ 2, 0, 3, 1, 4
367
+ ) # 3, B, num_heads, N, C
368
+ q, k, v = qkv[0], qkv[1], qkv[2]
369
+
370
+ if self.rope:
371
+ # slightly faster implementation: rotate patch tokens only, keep CLS untouched
372
+ q_t = q[:, :, 1:, :]
373
+ ro_q_t = self.rope(q_t)
374
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
375
+
376
+ k_t = k[:, :, 1:, :]
377
+ ro_k_t = self.rope(k_t)
378
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
379
+
380
+ if self.xattn:
381
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
382
+ k = k.permute(0, 2, 1, 3)
383
+ v = v.permute(0, 2, 1, 3)
384
+
385
+ x = xops.memory_efficient_attention(
386
+ q,
387
+ k,
388
+ v,
389
+ p=self.xattn_drop,
390
+ scale=self.scale,
391
+ )
392
+ x = x.reshape(B, N, -1)
393
+ x = self.inner_attn_ln(x)
394
+ x = self.proj(x)
395
+ x = self.proj_drop(x)
396
+ else:
397
+ q = q * self.scale
398
+ attn = q @ k.transpose(-2, -1)
399
+
400
+ if self.relative_position_bias_table is not None:
401
+ relative_position_bias = self.relative_position_bias_table[
402
+ self.relative_position_index.view(-1)
403
+ ].view(
404
+ self.window_size[0] * self.window_size[1] + 1,
405
+ self.window_size[0] * self.window_size[1] + 1,
406
+ -1,
407
+ ) # Wh*Ww,Wh*Ww,nH
408
+ relative_position_bias = relative_position_bias.permute(
409
+ 2, 0, 1
410
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
411
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
412
+
413
+ if rel_pos_bias is not None:
414
+ attn = attn + rel_pos_bias.type_as(attn)
415
+
416
+ if attn_mask is not None:
417
+ attn_mask = attn_mask.bool()
418
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
419
+
420
+ attn = attn.softmax(dim=-1)
421
+ attn = self.attn_drop(attn)
422
+
423
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
424
+ x = self.inner_attn_ln(x)
425
+ x = self.proj(x)
426
+ x = self.proj_drop(x)
427
+ return x
428
+
429
+
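The window-relative position index built inside `Attention` above (and repeated in `RelativePositionBias` below) can be verified on a tiny window. The following standalone sketch reuses the same arithmetic for a 2x2 window: the 16 patch-to-patch relations land in [0, 8], and the last three table rows are reserved for cls-to-token, token-to-cls, and cls-to-cls.

```python
# Standalone sketch of the relative-position-index construction for a 2x2 window.
import torch

window_size = (2, 2)
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3  # 12

coords = torch.stack(torch.meshgrid([torch.arange(2), torch.arange(2)]))   # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                                   # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # 2, 4, 4
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += window_size[0] - 1   # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1

index = torch.zeros((5, 5), dtype=relative_coords.dtype)
index[1:, 1:] = relative_coords.sum(-1)           # patch-to-patch entries in [0, 8]
index[0, 0:] = num_relative_distance - 3          # cls -> token
index[0:, 0] = num_relative_distance - 2          # token -> cls
index[0, 0] = num_relative_distance - 1           # cls -> cls
print(index)
```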
430
+ class Block(nn.Module):
431
+ def __init__(
432
+ self,
433
+ dim,
434
+ num_heads,
435
+ mlp_ratio=4.0,
436
+ qkv_bias=False,
437
+ qk_scale=None,
438
+ drop=0.0,
439
+ attn_drop=0.0,
440
+ drop_path=0.0,
441
+ init_values=None,
442
+ act_layer=nn.GELU,
443
+ norm_layer=nn.LayerNorm,
444
+ window_size=None,
445
+ attn_head_dim=None,
446
+ xattn=False,
447
+ rope=None,
448
+ postnorm=False,
449
+ subln=False,
450
+ naiveswiglu=False,
451
+ ):
452
+ super().__init__()
453
+ self.norm1 = norm_layer(dim)
454
+ self.attn = Attention(
455
+ dim,
456
+ num_heads=num_heads,
457
+ qkv_bias=qkv_bias,
458
+ qk_scale=qk_scale,
459
+ attn_drop=attn_drop,
460
+ proj_drop=drop,
461
+ window_size=window_size,
462
+ attn_head_dim=attn_head_dim,
463
+ xattn=xattn,
464
+ rope=rope,
465
+ subln=subln,
466
+ norm_layer=norm_layer,
467
+ )
468
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
469
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
470
+ self.norm2 = norm_layer(dim)
471
+ mlp_hidden_dim = int(dim * mlp_ratio)
472
+
473
+ if naiveswiglu:
474
+ self.mlp = SwiGLU(
475
+ in_features=dim,
476
+ hidden_features=mlp_hidden_dim,
477
+ subln=subln,
478
+ norm_layer=norm_layer,
479
+ )
480
+ else:
481
+ self.mlp = Mlp(
482
+ in_features=dim,
483
+ hidden_features=mlp_hidden_dim,
484
+ act_layer=act_layer,
485
+ subln=subln,
486
+ drop=drop,
487
+ )
488
+
489
+ if init_values is not None and init_values > 0:
490
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
491
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
492
+ else:
493
+ self.gamma_1, self.gamma_2 = None, None
494
+
495
+ self.postnorm = postnorm
496
+
497
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
498
+ if self.gamma_1 is None:
499
+ if self.postnorm:
500
+ x = x + self.drop_path(
501
+ self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
502
+ )
503
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
504
+ else:
505
+ x = x + self.drop_path(
506
+ self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
507
+ )
508
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
509
+ else:
510
+ if self.postnorm:
511
+ x = x + self.drop_path(
512
+ self.gamma_1
513
+ * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
514
+ )
515
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
516
+ else:
517
+ x = x + self.drop_path(
518
+ self.gamma_1
519
+ * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
520
+ )
521
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
522
+ return x
523
+
524
+
525
+ class PatchEmbed(nn.Module):
526
+ """Image to Patch Embedding"""
527
+
528
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
529
+ super().__init__()
530
+ img_size = to_2tuple(img_size)
531
+ patch_size = to_2tuple(patch_size)
532
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
533
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
534
+ self.img_size = img_size
535
+ self.patch_size = patch_size
536
+ self.num_patches = num_patches
537
+
538
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
539
+
540
+ def forward(self, x, **kwargs):
541
+ B, C, H, W = x.shape
542
+ # FIXME look at relaxing size constraints
543
+ assert (
544
+ H == self.img_size[0] and W == self.img_size[1]
545
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
546
+ x = self.proj(x).flatten(2).transpose(1, 2)
547
+ return x
548
+
549
+
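A small shape sketch for `PatchEmbed`, assuming the class above: a 336x336 image with 14x14 patches yields a 24x24 grid, i.e. 576 tokens.

```python
# Shape sketch for PatchEmbed: 336/14 = 24 patches per side -> 576 tokens.
import torch

embed = PatchEmbed(img_size=336, patch_size=14, in_chans=3, embed_dim=1024)
x = torch.randn(1, 3, 336, 336)
tokens = embed(x)
print(embed.num_patches, tokens.shape)  # 576 torch.Size([1, 576, 1024])
```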
550
+ class RelativePositionBias(nn.Module):
551
+ def __init__(self, window_size, num_heads):
552
+ super().__init__()
553
+ self.window_size = window_size
554
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
555
+ self.relative_position_bias_table = nn.Parameter(
556
+ torch.zeros(self.num_relative_distance, num_heads)
557
+ ) # 2*Wh-1 * 2*Ww-1, nH
558
+ # cls to token & token to cls & cls to cls
559
+
560
+ # get pair-wise relative position index for each token inside the window
561
+ coords_h = torch.arange(window_size[0])
562
+ coords_w = torch.arange(window_size[1])
563
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
564
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
565
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
566
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
567
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
568
+ relative_coords[:, :, 1] += window_size[1] - 1
569
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
570
+ relative_position_index = torch.zeros(
571
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
572
+ )
573
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
574
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
575
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
576
+ relative_position_index[0, 0] = self.num_relative_distance - 1
577
+
578
+ self.register_buffer("relative_position_index", relative_position_index)
579
+
580
+ def forward(self):
581
+ relative_position_bias = self.relative_position_bias_table[
582
+ self.relative_position_index.view(-1)
583
+ ].view(
584
+ self.window_size[0] * self.window_size[1] + 1,
585
+ self.window_size[0] * self.window_size[1] + 1,
586
+ -1,
587
+ ) # Wh*Ww,Wh*Ww,nH
588
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
589
+
590
+
591
+ class EVAVisionTransformer(nn.Module):
592
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
593
+
594
+ def __init__(
595
+ self,
596
+ img_size=224,
597
+ patch_size=16,
598
+ in_chans=3,
599
+ num_classes=1000,
600
+ embed_dim=768,
601
+ depth=12,
602
+ num_heads=12,
603
+ mlp_ratio=4.0,
604
+ qkv_bias=False,
605
+ qk_scale=None,
606
+ drop_rate=0.0,
607
+ attn_drop_rate=0.0,
608
+ drop_path_rate=0.0,
609
+ norm_layer=nn.LayerNorm,
610
+ init_values=None,
611
+ patch_dropout=0.0,
612
+ use_abs_pos_emb=True,
613
+ use_rel_pos_bias=False,
614
+ use_shared_rel_pos_bias=False,
615
+ rope=False,
616
+ use_mean_pooling=True,
617
+ init_scale=0.001,
618
+ grad_checkpointing=False,
619
+ xattn=False,
620
+ postnorm=False,
621
+ pt_hw_seq_len=16,
622
+ intp_freq=False,
623
+ naiveswiglu=False,
624
+ subln=False,
625
+ ):
626
+ super().__init__()
627
+ self.image_size = img_size
628
+ self.num_classes = num_classes
629
+ self.num_features = (
630
+ self.embed_dim
631
+ ) = embed_dim # num_features for consistency with other models
632
+
633
+ self.patch_embed = PatchEmbed(
634
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
635
+ )
636
+ num_patches = self.patch_embed.num_patches
637
+
638
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
639
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
640
+ if use_abs_pos_emb:
641
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
642
+ else:
643
+ self.pos_embed = None
644
+ self.pos_drop = nn.Dropout(p=drop_rate)
645
+
646
+ if use_shared_rel_pos_bias:
647
+ self.rel_pos_bias = RelativePositionBias(
648
+ window_size=self.patch_embed.patch_shape, num_heads=num_heads
649
+ )
650
+ else:
651
+ self.rel_pos_bias = None
652
+
653
+ if rope:
654
+ half_head_dim = embed_dim // num_heads // 2
655
+ hw_seq_len = img_size // patch_size
656
+ self.rope = VisionRotaryEmbeddingFast(
657
+ dim=half_head_dim,
658
+ pt_seq_len=pt_hw_seq_len,
659
+ ft_seq_len=hw_seq_len if intp_freq else None,
660
+ # patch_dropout=patch_dropout
661
+ )
662
+ else:
663
+ self.rope = None
664
+
665
+ self.naiveswiglu = naiveswiglu
666
+
667
+ dpr = [
668
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
669
+ ] # stochastic depth decay rule
670
+ self.use_rel_pos_bias = use_rel_pos_bias
671
+ self.blocks = nn.ModuleList(
672
+ [
673
+ Block(
674
+ dim=embed_dim,
675
+ num_heads=num_heads,
676
+ mlp_ratio=mlp_ratio,
677
+ qkv_bias=qkv_bias,
678
+ qk_scale=qk_scale,
679
+ drop=drop_rate,
680
+ attn_drop=attn_drop_rate,
681
+ drop_path=dpr[i],
682
+ norm_layer=norm_layer,
683
+ init_values=init_values,
684
+ window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
685
+ xattn=xattn,
686
+ rope=self.rope,
687
+ postnorm=postnorm,
688
+ subln=subln,
689
+ naiveswiglu=naiveswiglu,
690
+ )
691
+ for i in range(depth)
692
+ ]
693
+ )
694
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
695
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
696
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
697
+
698
+ if self.pos_embed is not None:
699
+ trunc_normal_(self.pos_embed, std=0.02)
700
+
701
+ trunc_normal_(self.cls_token, std=0.02)
702
+ # trunc_normal_(self.mask_token, std=.02)
703
+
704
+ self.apply(self._init_weights)
705
+ self.fix_init_weight()
706
+
707
+ if isinstance(self.head, nn.Linear):
708
+ trunc_normal_(self.head.weight, std=0.02)
709
+ self.head.weight.data.mul_(init_scale)
710
+ self.head.bias.data.mul_(init_scale)
711
+
712
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
713
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
714
+
715
+ self.grad_checkpointing = grad_checkpointing
716
+
717
+ def fix_init_weight(self):
718
+ def rescale(param, layer_id):
719
+ param.div_(math.sqrt(2.0 * layer_id))
720
+
721
+ for layer_id, layer in enumerate(self.blocks):
722
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
723
+ if self.naiveswiglu:
724
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
725
+ else:
726
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
727
+
728
+ def get_cast_dtype(self) -> torch.dtype:
729
+ return self.blocks[0].mlp.fc2.weight.dtype
730
+
731
+ def _init_weights(self, m):
732
+ if isinstance(m, nn.Linear):
733
+ trunc_normal_(m.weight, std=0.02)
734
+ if m.bias is not None:
735
+ nn.init.constant_(m.bias, 0)
736
+ elif isinstance(m, nn.LayerNorm):
737
+ nn.init.constant_(m.bias, 0)
738
+ nn.init.constant_(m.weight, 1.0)
739
+
740
+ def get_num_layers(self):
741
+ return len(self.blocks)
742
+
743
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
744
+ assert unlocked_groups == 0, "partial locking not currently supported for this model"
745
+ for param in self.parameters():
746
+ param.requires_grad = False
747
+
748
+ @torch.jit.ignore
749
+ def set_grad_checkpointing(self, enable=True):
750
+ self.grad_checkpointing = enable
751
+
752
+ @torch.jit.ignore
753
+ def no_weight_decay(self):
754
+ return {"pos_embed", "cls_token"}
755
+
756
+ def get_classifier(self):
757
+ return self.head
758
+
759
+ def reset_classifier(self, num_classes, global_pool=""):
760
+ self.num_classes = num_classes
761
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
762
+
763
+ def forward_features(self, x, return_all_features=False):
764
+
765
+ x = self.patch_embed(x)
766
+ batch_size, seq_len, _ = x.size()
767
+
768
+ cls_tokens = self.cls_token.expand(
769
+ batch_size, -1, -1
770
+ ) # stole cls_tokens impl from Phil Wang, thanks
771
+ x = torch.cat((cls_tokens, x), dim=1)
772
+ if self.pos_embed is not None:
773
+ x = x + self.pos_embed
774
+ x = self.pos_drop(x)
775
+
776
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
777
+ if os.getenv("RoPE") == "1":
778
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
779
+ x, patch_indices_keep = self.patch_dropout(x)
780
+ self.rope.forward = partial(
781
+ self.rope.forward, patch_indices_keep=patch_indices_keep
782
+ )
783
+ else:
784
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
785
+ x = self.patch_dropout(x)
786
+ else:
787
+ x = self.patch_dropout(x)
788
+
789
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
790
+ for i, blk in enumerate(self.blocks):
791
+ if i == len(self.blocks) - 1:
792
+ continue
793
+ if self.grad_checkpointing:
794
+ x = checkpoint(blk, x, rel_pos_bias)
795
+ else:
796
+ x = blk(x, rel_pos_bias=rel_pos_bias)
797
+
798
+ if not return_all_features:
799
+ x = self.norm(x)
800
+ if self.fc_norm is not None:
801
+ return self.fc_norm(x.mean(1))
802
+ else:
803
+ return x[:, 0]
804
+ return x
805
+
806
+ def forward(self, x, return_all_features=False):
807
+ if return_all_features:
808
+ return self.forward_features(x, return_all_features)
809
+ x = self.forward_features(x)
810
+ x = self.head(x)
811
+ return x
812
+
813
+
814
+ def load_state_dict(
815
+ checkpoint_path: str,
816
+ map_location: str = "cpu",
817
+ model_key: str = "model|module|state_dict",
818
+ is_openai: bool = False,
819
+ skip_list: list = [],
820
+ ):
821
+ if is_openai:
822
+ model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
823
+ state_dict = model.state_dict()
824
+ for key in ["input_resolution", "context_length", "vocab_size"]:
825
+ state_dict.pop(key, None)
826
+ else:
827
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
828
+ for mk in model_key.split("|"):
829
+ if isinstance(checkpoint, dict) and mk in checkpoint:
830
+ state_dict = checkpoint[mk]
831
+ break
832
+ else:
833
+ state_dict = checkpoint
834
+ if next(iter(state_dict.items()))[0].startswith("module"):
835
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
836
+
837
+ for k in skip_list:
838
+ if k in list(state_dict.keys()):
839
+ logging.info(f"Removing key {k} from pretrained checkpoint")
840
+ del state_dict[k]
841
+
842
+ if os.getenv("RoPE") == "1":
843
+ for k in list(state_dict.keys()):
844
+ if "freqs_cos" in k or "freqs_sin" in k:
845
+ del state_dict[k]
846
+ return state_dict
847
+
848
+
849
+ def load_clip_visual_state_dict(
850
+ checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []
851
+ ):
852
+ state_dict = load_state_dict(
853
+ checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list
854
+ )
855
+
856
+ for k in list(state_dict.keys()):
857
+ if not k.startswith("visual."):
858
+ del state_dict[k]
859
+ for k in list(state_dict.keys()):
860
+ if k.startswith("visual."):
861
+ new_k = k[7:]
862
+ state_dict[new_k] = state_dict[k]
863
+ del state_dict[k]
864
+ return state_dict
865
+
866
+
867
+ try:
868
+ from apex.normalization import FusedLayerNorm
869
+ except:
870
+ FusedLayerNorm = LayerNorm
871
+ print(
872
+ "Please build and install Nvidia apex package with option '--cuda_ext' according to https://github.com/NVIDIA/apex#from-source ."
873
+ )
874
+
875
+
876
+ @dataclass
877
+ class CLIPVisionCfg:
878
+ layers: Union[Tuple[int, int, int, int], int] = 12
879
+ width: int = 768
880
+ head_width: int = 64
881
+ mlp_ratio: float = 4.0
882
+ patch_size: int = 16
883
+ image_size: Union[Tuple[int, int], int] = 224
884
+ ls_init_value: Optional[float] = None # layer scale initial value
885
+ patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
886
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
887
+ drop_path_rate: Optional[float] = None # drop path rate
888
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
889
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
890
+ timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
891
+ timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '')
892
+ timm_proj_bias: bool = False # enable bias final projection
893
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
894
+ qkv_bias: bool = True
895
+ fusedLN: bool = False
896
+ xattn: bool = False
897
+ postnorm: bool = False
898
+ rope: bool = False
899
+ pt_hw_seq_len: int = 16 # 224/14
900
+ intp_freq: bool = False
901
+ naiveswiglu: bool = False
902
+ subln: bool = False
903
+
904
+
905
+ def _build_vision_tower(vision_tower_path: str, embed_dim: int, vision_cfg: CLIPVisionCfg):
906
+ if isinstance(vision_cfg, dict):
907
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
908
+
909
+ if vision_cfg.eva_model_name:
910
+ vision_heads = vision_cfg.width // vision_cfg.head_width
911
+ norm_layer = LayerNorm
912
+
913
+ visual = EVAVisionTransformer(
914
+ img_size=vision_cfg.image_size,
915
+ patch_size=vision_cfg.patch_size,
916
+ num_classes=embed_dim,
917
+ use_mean_pooling=vision_cfg.global_average_pool, # False
918
+ init_values=vision_cfg.ls_init_value,
919
+ patch_dropout=vision_cfg.patch_dropout,
920
+ embed_dim=vision_cfg.width,
921
+ depth=vision_cfg.layers,
922
+ num_heads=vision_heads,
923
+ mlp_ratio=vision_cfg.mlp_ratio,
924
+ qkv_bias=vision_cfg.qkv_bias,
925
+ drop_path_rate=vision_cfg.drop_path_rate,
926
+ norm_layer=partial(FusedLayerNorm, eps=1e-6)
927
+ if vision_cfg.fusedLN
928
+ else partial(norm_layer, eps=1e-6),
929
+ xattn=vision_cfg.xattn,
930
+ rope=vision_cfg.rope,
931
+ postnorm=vision_cfg.postnorm,
932
+ pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
933
+ intp_freq=vision_cfg.intp_freq,
934
+ naiveswiglu=vision_cfg.naiveswiglu,
935
+ subln=vision_cfg.subln,
936
+ )
937
+
938
+ state_dict = load_clip_visual_state_dict(vision_tower_path)
939
+ incompatible_keys = visual.load_state_dict(state_dict, strict=False)
940
+ print("EVA-CLIP incompatible_keys:", incompatible_keys)
941
+
942
+ return visual
943
+
944
+
945
+ class Eva2LargePlusEncoder(nn.Module):
946
+ def __init__(self, vision_tower_path):
947
+ super(Eva2LargePlusEncoder, self).__init__()
948
+ self.config = {
949
+ "embed_dim": 768,
950
+ "vision_cfg": {
951
+ "image_size": 336,
952
+ "layers": 24,
953
+ "width": 1024,
954
+ "drop_path_rate": 0,
955
+ "head_width": 64,
956
+ "mlp_ratio": 2.6667,
957
+ "patch_size": 14,
958
+ "eva_model_name": "eva-clip-l-14-336",
959
+ "xattn": True,
960
+ "fusedLN": True,
961
+ "rope": True,
962
+ "pt_hw_seq_len": 16,
963
+ "intp_freq": True,
964
+ "naiveswiglu": True,
965
+ "subln": True,
966
+ },
967
+ }
968
+
969
+ self.config["vision_tower_path"] = vision_tower_path
970
+ self.model = _build_vision_tower(**self.config)
971
+
972
+ def forward(self, image, **kwargs):
973
+ encode = self.model(image, return_all_features=True)[:, 1:, :]
974
+ return encode
975
+
976
+ @property
977
+ def dtype(self):
978
+ return list(self.parameters())[-1].dtype
979
+
980
+ @property
981
+ def device(self):
982
+ return list(self.parameters())[-1].device
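A hypothetical usage sketch of `Eva2LargePlusEncoder`; the checkpoint path below is a placeholder and must point to an EVA-CLIP-L-14-336 state dict containing `visual.*` weights. Note that with `xattn=True` the attention path relies on xformers being importable in this module, and FusedLayerNorm falls back to the plain `LayerNorm` when apex is not installed.

```python
# Hypothetical usage; the checkpoint path is a placeholder, not shipped with this repo.
import torch

encoder = Eva2LargePlusEncoder(vision_tower_path="/path/to/EVA02_CLIP_L_336_psz14.pt")
images = torch.randn(1, 3, 336, 336)
with torch.no_grad():
    feats = encoder(images)  # forward() drops the CLS token
print(feats.shape)           # expected: torch.Size([1, 576, 1024])
```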
vita/model/multimodal_encoder/internvit/configuration_intern_vit.py ADDED
@@ -0,0 +1,125 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import logging
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
+ class InternVisionConfig(PretrainedConfig):
16
+ r"""
17
+ This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
18
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
19
+
20
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
21
+ documentation from [`PretrainedConfig`] for more information.
22
+
23
+ Args:
24
+ num_channels (`int`, *optional*, defaults to 3):
25
+ Number of color channels in the input images (e.g., 3 for RGB).
26
+ patch_size (`int`, *optional*, defaults to 14):
27
+ The size (resolution) of each patch.
28
+ image_size (`int`, *optional*, defaults to 224):
29
+ The size (resolution) of each image.
30
+ qkv_bias (`bool`, *optional*, defaults to `False`):
31
+ Whether to add a bias to the queries and values in the self-attention layers.
32
+ hidden_size (`int`, *optional*, defaults to 3200):
33
+ Dimensionality of the encoder layers and the pooler layer.
34
+ num_attention_heads (`int`, *optional*, defaults to 25):
35
+ Number of attention heads for each attention layer in the Transformer encoder.
36
+ intermediate_size (`int`, *optional*, defaults to 12800):
37
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
38
+ qk_normalization (`bool`, *optional*, defaults to `True`):
39
+ Whether to normalize the queries and keys in the self-attention layers.
40
+ num_hidden_layers (`int`, *optional*, defaults to 48):
41
+ Number of hidden layers in the Transformer encoder.
42
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
43
+ Whether to use flash attention mechanism.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
45
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
46
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
47
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
48
+ The epsilon used by the layer normalization layers.
49
+ dropout (`float`, *optional*, defaults to 0.0):
50
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
52
+ Dropout rate for stochastic depth.
53
+ attention_dropout (`float`, *optional*, defaults to 0.0):
54
+ The dropout ratio for the attention probabilities.
55
+ initializer_range (`float`, *optional*, defaults to 0.02):
56
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
57
+ initializer_factor (`float`, *optional*, defaults to 0.1):
58
+ A factor for layer scale.
59
+ """
60
+
61
+ model_type = "intern_vit_6b"
62
+
63
+ def __init__(
64
+ self,
65
+ num_channels=3,
66
+ patch_size=14,
67
+ image_size=224,
68
+ qkv_bias=False,
69
+ hidden_size=3200,
70
+ num_attention_heads=25,
71
+ intermediate_size=12800,
72
+ qk_normalization=True,
73
+ num_hidden_layers=48,
74
+ use_flash_attn=True,
75
+ hidden_act="gelu",
76
+ norm_type="rms_norm",
77
+ layer_norm_eps=1e-6,
78
+ dropout=0.0,
79
+ drop_path_rate=0.0,
80
+ attention_dropout=0.0,
81
+ initializer_range=0.02,
82
+ initializer_factor=0.1,
83
+ **kwargs,
84
+ ):
85
+ super().__init__(**kwargs)
86
+
87
+ self.hidden_size = hidden_size
88
+ self.intermediate_size = intermediate_size
89
+ self.dropout = dropout
90
+ self.drop_path_rate = drop_path_rate
91
+ self.num_hidden_layers = num_hidden_layers
92
+ self.num_attention_heads = num_attention_heads
93
+ self.num_channels = num_channels
94
+ self.patch_size = patch_size
95
+ self.image_size = image_size
96
+ self.initializer_range = initializer_range
97
+ self.initializer_factor = initializer_factor
98
+ self.attention_dropout = attention_dropout
99
+ self.layer_norm_eps = layer_norm_eps
100
+ self.hidden_act = hidden_act
101
+ self.norm_type = norm_type
102
+ self.qkv_bias = qkv_bias
103
+ self.qk_normalization = qk_normalization
104
+ self.use_flash_attn = use_flash_attn
105
+
106
+ @classmethod
107
+ def from_pretrained(
108
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
109
+ ) -> "PretrainedConfig":
110
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
111
+
112
+ if "vision_config" in config_dict:
113
+ config_dict = config_dict["vision_config"]
114
+
115
+ if (
116
+ "model_type" in config_dict
117
+ and hasattr(cls, "model_type")
118
+ and config_dict["model_type"] != cls.model_type
119
+ ):
120
+ logger.warning(
121
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
122
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
123
+ )
124
+
125
+ return cls.from_dict(config_dict, **kwargs)
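A short sketch of constructing `InternVisionConfig` directly with its defaults (the InternViT-6B geometry), assuming only that transformers is installed; overriding `image_size` to 448 gives a 32x32 patch grid.

```python
# Sketch: build the config and inspect the resulting geometry.
config = InternVisionConfig(image_size=448, patch_size=14)
print(config.model_type)                              # "intern_vit_6b"
print(config.hidden_size, config.num_hidden_layers)   # 3200 48
print((config.image_size // config.patch_size) ** 2)  # 1024 patch tokens per image
```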
vita/model/multimodal_encoder/internvit/flash_attention.py ADDED
@@ -0,0 +1,101 @@
1
+ # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+
6
+ from flash_attn.bert_padding import pad_input, unpad_input
7
+
8
+ try: # v1
9
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
10
+ except: # v2
11
+ from flash_attn.flash_attn_interface import (
12
+ flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func,
13
+ )
14
+
15
+
16
+ class FlashAttention(nn.Module):
17
+ """Implement the scaled dot product attention with softmax.
18
+ Arguments
19
+ ---------
20
+ softmax_scale: The temperature to use for the softmax attention.
21
+ (default: 1/sqrt(d_keys) where d_keys is computed at
22
+ runtime)
23
+ attention_dropout: The dropout rate to apply to the attention
24
+ (default: 0.0)
25
+ """
26
+
27
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
28
+ super().__init__()
29
+ self.softmax_scale = softmax_scale
30
+ self.dropout_p = attention_dropout
31
+
32
+ def forward(
33
+ self,
34
+ qkv,
35
+ key_padding_mask=None,
36
+ causal=False,
37
+ cu_seqlens=None,
38
+ max_s=None,
39
+ need_weights=False,
40
+ ):
41
+ """Implements the multihead softmax attention.
42
+ Arguments
43
+ ---------
44
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
45
+ if unpadded: (nnz, 3, h, d)
46
+ key_padding_mask: a bool tensor of shape (B, S)
47
+ """
48
+ assert not need_weights
49
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
50
+ assert qkv.is_cuda
51
+
52
+ if cu_seqlens is None:
53
+ batch_size = qkv.shape[0]
54
+ seqlen = qkv.shape[1]
55
+ if key_padding_mask is None:
56
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
57
+ max_s = seqlen
58
+ cu_seqlens = torch.arange(
59
+ 0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
60
+ )
61
+ output = flash_attn_unpadded_qkvpacked_func(
62
+ qkv,
63
+ cu_seqlens,
64
+ max_s,
65
+ self.dropout_p if self.training else 0.0,
66
+ softmax_scale=self.softmax_scale,
67
+ causal=causal,
68
+ )
69
+ output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
70
+ else:
71
+ nheads = qkv.shape[-2]
72
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
73
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
74
+ x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
75
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
76
+ x_unpad,
77
+ cu_seqlens,
78
+ max_s,
79
+ self.dropout_p if self.training else 0.0,
80
+ softmax_scale=self.softmax_scale,
81
+ causal=causal,
82
+ )
83
+ output = rearrange(
84
+ pad_input(
85
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
86
+ ),
87
+ "b s (h d) -> b s h d",
88
+ h=nheads,
89
+ )
90
+ else:
91
+ assert max_s is not None
92
+ output = flash_attn_unpadded_qkvpacked_func(
93
+ qkv,
94
+ cu_seqlens,
95
+ max_s,
96
+ self.dropout_p if self.training else 0.0,
97
+ softmax_scale=self.softmax_scale,
98
+ causal=causal,
99
+ )
100
+
101
+ return output, None
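A sketch of calling the wrapper above with a packed qkv tensor. This assumes a CUDA device, fp16 or bf16 inputs, and the flash_attn package (v1 or v2), which the asserts in `forward` enforce.

```python
# Sketch: dense path (no key_padding_mask) of the FlashAttention wrapper above.
import torch

attn = FlashAttention(attention_dropout=0.0)
B, S, H, D = 2, 197, 16, 64
qkv = torch.randn(B, S, 3, H, D, device="cuda", dtype=torch.float16)
out, _ = attn(qkv)      # cu_seqlens is built internally for the unpadded kernel
print(out.shape)        # torch.Size([2, 197, 16, 64])
```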
vita/model/multimodal_encoder/internvit/internvit_encoder.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoConfig, AutoModel, CLIPImageProcessor
4
+
5
+ from .modeling_intern_vit import InternVisionModel
6
+
7
+
8
+ class InternViTVisionTower(nn.Module):
9
+ def __init__(self, vision_tower, args, delay_load=False):
10
+ super().__init__()
11
+
12
+ self.is_loaded = False
13
+
14
+ self.vision_tower_name = vision_tower
15
+ self.select_layer = -1
16
+ self.scale_pix_shuffle = 0.5
17
+
18
+ if not delay_load:
19
+ self.load_model()
20
+ else:
21
+ self.cfg_only = AutoConfig.from_pretrained(
22
+ self.vision_tower_name, trust_remote_code=True
23
+ )
24
+
25
+ def load_model(self):
26
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
27
+ self.vision_tower = InternVisionModel.from_pretrained(
28
+ self.vision_tower_name, trust_remote_code=True
29
+ )
30
+ self.vision_tower.requires_grad_(False)
31
+
32
+ self.is_loaded = True
33
+
34
+ def feature_select(self, image_forward_outs):
35
+ image_features = image_forward_outs.hidden_states[self.select_layer]
36
+
37
+ image_features = image_features[:, 1:]
38
+
39
+ return image_features
40
+
41
+ def pixel_shuffle(self, x, scale_factor=0.5):
42
+ n, w, h, c = x.size()
43
+ # N, W, H, C --> N, W, H * scale, C // scale
44
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
45
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
46
+ x = x.permute(0, 2, 1, 3).contiguous()
47
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
48
+ x = x.view(
49
+ n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor))
50
+ )
51
+ x = x.permute(0, 2, 1, 3).contiguous()
52
+ return x
53
+
54
+ #@torch.no_grad()
55
+ def forward(self, images):
56
+ if type(images) is list:
57
+ image_features = []
58
+ for image in images:
59
+ image_forward_out = self.vision_tower(
60
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
61
+ output_hidden_states=True,
62
+ )
63
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
64
+ image_features.append(image_feature)
65
+ else:
66
+ image_forward_outs = self.vision_tower(
67
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
68
+ )
69
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
70
+ h = w = int(image_features.shape[1] ** 0.5)
71
+ assert image_features.shape[1] == h * w
72
+ image_features = image_features.reshape(image_features.shape[0], h, w, -1)
73
+ image_features = self.pixel_shuffle(image_features * self.scale_pix_shuffle)
74
+ image_features = image_features.reshape(
75
+ image_features.shape[0], -1, image_features.shape[-1]
76
+ )
77
+
78
+ return image_features
79
+
80
+ @property
81
+ def dummy_feature(self):
82
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
83
+
84
+ @property
85
+ def dtype(self):
86
+ return self.vision_tower.dtype
87
+
88
+ @property
89
+ def device(self):
90
+ return self.vision_tower.device
91
+
92
+ @property
93
+ def config(self):
94
+ if self.is_loaded:
95
+ return self.vision_tower.config
96
+ else:
97
+ return self.cfg_only
98
+
99
+ @property
100
+ def hidden_size(self):
101
+ return self.config.hidden_size * (int(1 / self.scale_pix_shuffle) ** 2)
102
+
103
+ @property
104
+ def num_patches(self):
105
+ return (self.config.image_size // self.config.patch_size) ** 2
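A standalone sketch of the pixel-shuffle bookkeeping used by `InternViTVisionTower` (scale_factor 0.5): a 32x32 grid of C-dim tokens becomes a 16x16 grid of 4C-dim tokens, i.e. 4x fewer tokens fed to the projector. The 32-per-side grid assumes a 448px InternViT with 14px patches.

```python
# Standalone copy of the pixel_shuffle arithmetic above, to check shapes.
import torch


def pixel_shuffle(x, scale_factor=0.5):
    n, w, h, c = x.size()
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor ** 2)))
    return x.permute(0, 2, 1, 3).contiguous()


x = torch.randn(1, 32, 32, 1024)   # 448/14 = 32 patches per side
y = pixel_shuffle(x)
print(y.shape)                     # torch.Size([1, 16, 16, 4096])
print(y.shape[1] * y.shape[2])     # 256 tokens per image after shuffling
```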
vita/model/multimodal_encoder/internvit/modeling_intern_vit.py ADDED
@@ -0,0 +1,394 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from einops import rearrange
12
+ from torch import nn
13
+ from transformers.activations import ACT2FN
14
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.utils import logging
17
+
18
+ from timm.models.layers import DropPath
19
+
20
+ from .configuration_intern_vit import InternVisionConfig
21
+
22
+ try:
23
+ from .flash_attention import FlashAttention
24
+
25
+ has_flash_attn = True
26
+ except:
27
+ print("FlashAttention is not installed.")
28
+ has_flash_attn = False
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ class InternRMSNorm(nn.Module):
35
+ def __init__(self, hidden_size, eps=1e-6):
36
+ super().__init__()
37
+ self.weight = nn.Parameter(torch.ones(hidden_size))
38
+ self.variance_epsilon = eps
39
+
40
+ def forward(self, hidden_states):
41
+ input_dtype = hidden_states.dtype
42
+ hidden_states = hidden_states.to(torch.float32)
43
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
44
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
45
+ return self.weight * hidden_states.to(input_dtype)
46
+
47
+
48
+ try:
49
+ from apex.normalization import FusedRMSNorm
50
+
51
+ InternRMSNorm = FusedRMSNorm # noqa
52
+
53
+ logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm")
54
+ except ImportError:
55
+ # using the normal InternRMSNorm
56
+ pass
57
+ except Exception:
58
+ logger.warning("discovered apex but it failed to load, falling back to InternRMSNorm")
59
+ pass
60
+
61
+
62
+ NORM2FN = {
63
+ "rms_norm": InternRMSNorm,
64
+ "layer_norm": nn.LayerNorm,
65
+ }
66
+
67
+
68
+ class InternVisionEmbeddings(nn.Module):
69
+ def __init__(self, config: InternVisionConfig):
70
+ super().__init__()
71
+ self.config = config
72
+ self.embed_dim = config.hidden_size
73
+ self.image_size = config.image_size
74
+ self.patch_size = config.patch_size
75
+
76
+ self.class_embedding = nn.Parameter(
77
+ torch.randn(1, 1, self.embed_dim),
78
+ )
79
+
80
+ self.patch_embedding = nn.Conv2d(
81
+ in_channels=3,
82
+ out_channels=self.embed_dim,
83
+ kernel_size=self.patch_size,
84
+ stride=self.patch_size,
85
+ )
86
+
87
+ self.num_patches = (self.image_size // self.patch_size) ** 2
88
+ self.num_positions = self.num_patches + 1
89
+
90
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
91
+
92
+ def _get_pos_embed(self, pos_embed, H, W):
93
+ target_dtype = pos_embed.dtype
94
+ pos_embed = (
95
+ pos_embed.float()
96
+ .reshape(1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1)
97
+ .permute(0, 3, 1, 2)
98
+ )
99
+ pos_embed = (
100
+ F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
101
+ .reshape(1, -1, H * W)
102
+ .permute(0, 2, 1)
103
+ .to(target_dtype)
104
+ )
105
+ return pos_embed
106
+
107
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
108
+ target_dtype = self.patch_embedding.weight.dtype
109
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
110
+ batch_size, _, height, width = patch_embeds.shape
111
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
112
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
113
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
114
+ position_embedding = torch.cat(
115
+ [
116
+ self.position_embedding[:, :1, :],
117
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height, width),
118
+ ],
119
+ dim=1,
120
+ )
121
+ embeddings = embeddings + position_embedding.to(target_dtype)
122
+ return embeddings
123
+
124
+
125
+ class InternAttention(nn.Module):
126
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
127
+
128
+ def __init__(self, config: InternVisionConfig):
129
+ super().__init__()
130
+ self.config = config
131
+ self.embed_dim = config.hidden_size
132
+ self.num_heads = config.num_attention_heads
133
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
134
+ if config.use_flash_attn and not has_flash_attn:
135
+ print("Warning: Flash Attention is not available, use_flash_attn is set to False.")
136
+ self.head_dim = self.embed_dim // self.num_heads
137
+ if self.head_dim * self.num_heads != self.embed_dim:
138
+ raise ValueError(
139
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
140
+ f" {self.num_heads})."
141
+ )
142
+
143
+ self.scale = self.head_dim**-0.5
144
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
145
+ self.attn_drop = nn.Dropout(config.attention_dropout)
146
+ self.proj_drop = nn.Dropout(config.dropout)
147
+
148
+ self.qk_normalization = config.qk_normalization
149
+
150
+ if self.qk_normalization:
151
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
152
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
153
+
154
+ if self.use_flash_attn:
155
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
156
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
157
+
158
+ def _naive_attn(self, x):
159
+ B, N, C = x.shape
160
+ qkv = (
161
+ self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
162
+ )
163
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
164
+
165
+ if self.qk_normalization:
166
+ B_, H_, N_, D_ = q.shape
167
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
168
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
169
+
170
+ attn = (q * self.scale) @ k.transpose(-2, -1)
171
+ attn = attn.softmax(dim=-1)
172
+ attn = self.attn_drop(attn)
173
+
174
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
175
+ x = self.proj(x)
176
+ x = self.proj_drop(x)
177
+ return x
178
+
179
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
180
+ qkv = self.qkv(x)
181
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
182
+
183
+ if self.qk_normalization:
184
+ q, k, v = qkv.unbind(2)
185
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
186
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
187
+ qkv = torch.stack([q, k, v], dim=2)
188
+
189
+ context, _ = self.inner_attn(
190
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
191
+ )
192
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
193
+ outs = self.proj_drop(outs)
194
+ return outs
195
+
196
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
197
+ x = (
198
+ self._naive_attn(hidden_states)
199
+ if not self.use_flash_attn
200
+ else self._flash_attn(hidden_states)
201
+ )
202
+ return x
203
+
204
+
205
+ class InternMLP(nn.Module):
206
+ def __init__(self, config: InternVisionConfig):
207
+ super().__init__()
208
+ self.config = config
209
+ self.act = ACT2FN[config.hidden_act]
210
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
211
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
212
+
213
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
214
+ hidden_states = self.fc1(hidden_states)
215
+ hidden_states = self.act(hidden_states)
216
+ hidden_states = self.fc2(hidden_states)
217
+ return hidden_states
218
+
219
+
220
+ class InternVisionEncoderLayer(nn.Module):
221
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
222
+ super().__init__()
223
+ self.embed_dim = config.hidden_size
224
+ self.intermediate_size = config.intermediate_size
225
+ self.norm_type = config.norm_type
226
+
227
+ self.attn = InternAttention(config)
228
+ self.mlp = InternMLP(config)
229
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
230
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
231
+
232
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
233
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
234
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
235
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
236
+
237
+ def forward(
238
+ self,
239
+ hidden_states: torch.Tensor,
240
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
241
+ """
242
+ Args:
243
+ hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
244
+ """
245
+ hidden_states = hidden_states + self.drop_path1(
246
+ self.attn(self.norm1(hidden_states)) * self.ls1
247
+ )
248
+
249
+ hidden_states = hidden_states + self.drop_path2(
250
+ self.mlp(self.norm2(hidden_states)) * self.ls2
251
+ )
252
+
253
+ return hidden_states
254
+
255
+
256
+ class InternVisionEncoder(nn.Module):
257
+ """
258
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
259
+ [`InternEncoderLayer`].
260
+
261
+ Args:
262
+ config (`InternConfig`):
263
+ The corresponding vision configuration for the `InternEncoder`.
264
+ """
265
+
266
+ def __init__(self, config: InternVisionConfig):
267
+ super().__init__()
268
+ self.config = config
269
+ # stochastic depth decay rule
270
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
271
+ self.layers = nn.ModuleList(
272
+ [InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]
273
+ )
274
+ self.gradient_checkpointing = True
275
+
276
+ def forward(
277
+ self,
278
+ inputs_embeds,
279
+ output_hidden_states: Optional[bool] = None,
280
+ return_dict: Optional[bool] = None,
281
+ ) -> Union[Tuple, BaseModelOutput]:
282
+ r"""
283
+ Args:
284
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
285
+ Embedded representation of the inputs. Should be float, not int tokens.
286
+ output_hidden_states (`bool`, *optional*):
287
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
288
+ for more detail.
289
+ return_dict (`bool`, *optional*):
290
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
291
+ """
292
+ output_hidden_states = (
293
+ output_hidden_states
294
+ if output_hidden_states is not None
295
+ else self.config.output_hidden_states
296
+ )
297
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
298
+
299
+ encoder_states = () if output_hidden_states else None
300
+ hidden_states = inputs_embeds
301
+
302
+ for idx, encoder_layer in enumerate(self.layers):
303
+ if output_hidden_states:
304
+ encoder_states = encoder_states + (hidden_states,)
305
+ if self.gradient_checkpointing and self.training:
306
+ layer_outputs = torch.utils.checkpoint.checkpoint(encoder_layer, hidden_states)
307
+ else:
308
+ layer_outputs = encoder_layer(
309
+ hidden_states,
310
+ )
311
+ hidden_states = layer_outputs
312
+
313
+ if output_hidden_states:
314
+ encoder_states = encoder_states + (hidden_states,)
315
+
316
+ if not return_dict:
317
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
318
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)
319
+
320
+
321
+ class InternVisionModel(PreTrainedModel):
322
+ main_input_name = "pixel_values"
323
+ config_class = InternVisionConfig
324
+ _no_split_modules = ["InternVisionEncoderLayer"]
325
+
326
+ def __init__(self, config: InternVisionConfig):
327
+ super().__init__(config)
328
+ self.config = config
329
+
330
+ self.embeddings = InternVisionEmbeddings(config)
331
+ self.encoder = InternVisionEncoder(config)
332
+
333
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
334
+ pos_emb = self.embeddings.position_embedding
335
+ _, num_positions, embed_dim = pos_emb.shape
336
+ cls_emb = pos_emb[:, :1, :]
337
+ pos_emb = (
338
+ pos_emb[:, 1:, :]
339
+ .reshape(1, old_size // patch_size, old_size // patch_size, -1)
340
+ .permute(0, 3, 1, 2)
341
+ )
342
+ pos_emb = F.interpolate(
343
+ pos_emb.float(), size=new_size // patch_size, mode="bicubic", align_corners=False
344
+ )
345
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
346
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
347
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
348
+ self.embeddings.image_size = new_size
349
+ logger.info("Resized position embeddings from {} to {}".format(old_size, new_size))
350
+
351
+ def get_input_embeddings(self):
352
+ return self.embeddings
353
+
354
+ def forward(
355
+ self,
356
+ pixel_values: Optional[torch.FloatTensor] = None,
357
+ output_hidden_states: Optional[bool] = None,
358
+ return_dict: Optional[bool] = None,
359
+ pixel_embeds: Optional[torch.FloatTensor] = None,
360
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
361
+ output_hidden_states = (
362
+ output_hidden_states
363
+ if output_hidden_states is not None
364
+ else self.config.output_hidden_states
365
+ )
366
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
367
+
368
+ if pixel_values is None and pixel_embeds is None:
369
+ raise ValueError("You have to specify pixel_values or pixel_embeds")
370
+
371
+ if pixel_embeds is not None:
372
+ hidden_states = pixel_embeds
373
+ else:
374
+ if len(pixel_values.shape) == 4:
375
+ hidden_states = self.embeddings(pixel_values)
376
+ else:
377
+ raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
378
+ encoder_outputs = self.encoder(
379
+ inputs_embeds=hidden_states,
380
+ output_hidden_states=output_hidden_states,
381
+ return_dict=return_dict,
382
+ )
383
+ last_hidden_state = encoder_outputs.last_hidden_state
384
+ pooled_output = last_hidden_state[:, 0, :]
385
+
386
+ if not return_dict:
387
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
388
+
389
+ return BaseModelOutputWithPooling(
390
+ last_hidden_state=last_hidden_state,
391
+ pooler_output=pooled_output,
392
+ hidden_states=encoder_outputs.hidden_states,
393
+ attentions=encoder_outputs.attentions,
394
+ )
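The encoder loop and the `[:, 0, :]` CLS pooling above follow the standard Hugging Face pattern. As a self-contained illustration of that pattern (generic linear layers stand in for the InternViT blocks; nothing here is the actual VITA model), the per-layer hidden-state accumulation and the pooling look like this:

```python
import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])  # stand-ins for encoder layers

def run(inputs_embeds, output_hidden_states=True):
    encoder_states = () if output_hidden_states else None
    hidden_states = inputs_embeds
    for layer in layers:
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)  # input to this layer
        hidden_states = layer(hidden_states)
    if output_hidden_states:
        encoder_states = encoder_states + (hidden_states,)      # final output
    return hidden_states, encoder_states

last, states = run(torch.randn(1, 4, 8))
pooled = last[:, 0, :]   # same CLS-token pooling as InternVisionModel.forward
print(last.shape, len(states), pooled.shape)  # torch.Size([1, 4, 8]) 4 torch.Size([1, 8])
```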
vita/model/multimodal_encoder/siglip/siglip_encoder.py ADDED
@@ -0,0 +1,149 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
4
+
5
+ from vita.util.s2wrapper import forward as multiscale_forward
6
+
7
+
8
+ class SiglipVisionTower(nn.Module):
9
+ def __init__(self, vision_tower, args, delay_load=False):
10
+ super().__init__()
11
+
12
+ self.is_loaded = False
13
+
14
+ self.vision_tower_name = vision_tower
15
+ self.select_layer = -2
16
+
17
+ if not delay_load:
18
+ self.load_model()
19
+ else:
20
+ self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)
21
+
22
+ def load_model(self):
23
+ self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
24
+ self.image_processor.crop_size = self.image_processor.size
25
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
26
+ self.vision_tower.requires_grad_(False)
27
+
28
+ self.is_loaded = True
29
+
30
+ def feature_select(self, image_forward_outs):
31
+ image_features = image_forward_outs.hidden_states[self.select_layer]
32
+
33
+ return image_features
34
+
35
+ @torch.no_grad()
36
+ def forward(self, images):
37
+ if type(images) is list:
38
+ image_features = []
39
+ for image in images:
40
+ image_forward_out = self.vision_tower(
41
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
42
+ output_hidden_states=True,
43
+ )
44
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
45
+ image_features.append(image_feature)
46
+ else:
47
+ image_forward_outs = self.vision_tower(
48
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
49
+ )
50
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
51
+
52
+ return image_features
53
+
54
+ @property
55
+ def dummy_feature(self):
56
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
57
+
58
+ @property
59
+ def dtype(self):
60
+ return self.vision_tower.dtype
61
+
62
+ @property
63
+ def device(self):
64
+ return self.vision_tower.device
65
+
66
+ @property
67
+ def config(self):
68
+ if self.is_loaded:
69
+ return self.vision_tower.config
70
+ else:
71
+ return self.cfg_only
72
+
73
+ @property
74
+ def hidden_size(self):
75
+ return self.config.hidden_size
76
+
77
+ @property
78
+ def num_patches(self):
79
+ return (self.config.image_size // self.config.patch_size) ** 2
80
+
81
+
82
+ class SiglipVisionTowerS2(SiglipVisionTower):
83
+ def __init__(self, vision_tower, args, delay_load=False):
84
+ self.s2_scales = getattr(args, "s2_scales", "384,768,1152")
85
+ self.s2_scales = list(map(int, self.s2_scales.split(",")))
86
+ self.s2_scales.sort()
87
+ self.s2_split_size = self.s2_scales[0]
88
+ self.s2_image_size = self.s2_scales[-1]
89
+
90
+ super().__init__(vision_tower, args, delay_load)
91
+
92
+ self.multiscale_forward = multiscale_forward
93
+
94
+ if not delay_load:
95
+ self.image_processor.size["height"] = self.image_processor.size[
96
+ "width"
97
+ ] = self.s2_image_size
98
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size[
99
+ "width"
100
+ ] = self.s2_image_size
101
+
102
+ def load_model(self):
103
+ self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
104
+ self.image_processor.crop_size = self.image_processor.size
105
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
106
+ self.vision_tower.requires_grad_(False)
107
+
108
+ self.image_processor.size["height"] = self.image_processor.size[
109
+ "width"
110
+ ] = self.s2_image_size
111
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size[
112
+ "width"
113
+ ] = self.s2_image_size
114
+
115
+ self.is_loaded = True
116
+
117
+ @torch.no_grad()
118
+ def forward_feature(self, images):
119
+ image_forward_outs = self.vision_tower(
120
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
121
+ )
122
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
123
+ return image_features
124
+
125
+ @torch.no_grad()
126
+ def forward(self, images):
127
+ if type(images) is list:
128
+ image_features = []
129
+ for image in images:
130
+ image_feature = self.multiscale_forward(
131
+ self.forward_feature,
132
+ image.unsqueeze(0),
133
+ img_sizes=self.s2_scales,
134
+ max_split_size=self.s2_split_size,
135
+ )
136
+ image_features.append(image_feature)
137
+ else:
138
+ image_features = self.multiscale_forward(
139
+ self.forward_feature,
140
+ images,
141
+ img_sizes=self.s2_scales,
142
+ max_split_size=self.s2_split_size,
143
+ )
144
+
145
+ return image_features
146
+
147
+ @property
148
+ def hidden_size(self):
149
+ return self.config.hidden_size * len(self.s2_scales)
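A hedged usage sketch for the tower above. The checkpoint id is an assumption (any SigLIP weights compatible with `SiglipVisionModel` work), and the base class does not read `args`, so an empty namespace is enough here:

```python
from types import SimpleNamespace

import torch

from vita.model.multimodal_encoder.siglip.siglip_encoder import SiglipVisionTower

# Hypothetical checkpoint; substitute the SigLIP weights used in your setup.
tower = SiglipVisionTower("google/siglip-so400m-patch14-384", args=SimpleNamespace())

images = torch.randn(2, 3, 384, 384)      # already-preprocessed pixel values
features = tower(images)                  # penultimate hidden states (select_layer = -2)
print(features.shape, tower.num_patches, tower.hidden_size)
```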
vita/model/multimodal_encoder/whale/adapter.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn.utils.rnn import pad_sequence
4
+
5
+
6
+ class CNNAdapter(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ enc_out_dim: int = 512,
10
+ llm_embed_dim: int = 4096,
11
+ kernel_size: int = 5,
12
+ ):
13
+ super().__init__()
14
+
15
+ self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
16
+ self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
17
+ self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
18
+ self.relu1 = nn.ReLU()
19
+
20
+ self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
21
+ self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 1, 0)
22
+ self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
23
+ self.relu2 = nn.ReLU()
24
+
25
+ self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
26
+
27
+ def forward(self, x, mask_pad):
28
+ """
29
+ x: B, T, enc_out_dim
30
+ mask: (B, T) or (B, 1, T)
31
+ """
32
+ x = x.transpose(1, 2) # B, channels, T
33
+
34
+ # mask batch padding
35
+ if mask_pad.size(2) > 0: # time > 0
36
+ x.masked_fill_(~mask_pad, 0.0)
37
+
38
+ x = self.left_padding1(x)
39
+ x = self.conv1d1(x)
40
+ x = self.bn1(x)
41
+ x = self.relu1(x)
42
+
43
+ x = self.left_padding2(x)
44
+ x = self.conv1d2(x)
45
+ x = self.bn2(x)
46
+ x = self.relu2(x)
47
+
48
+ x = x.transpose(1, 2)
49
+ x = self.project(x)
50
+
51
+ return x, mask_pad
52
+
53
+
54
+ class LinearAdapter(torch.nn.Module):
55
+ def __init__(
56
+ self,
57
+ enc_out_dim: int = 512,
58
+ llm_embed_dim: int = 4096,
59
+ ):
60
+ super().__init__()
61
+
62
+ self.adpter = torch.nn.Linear(enc_out_dim, llm_embed_dim)
63
+
64
+ def forward(self, x, mask_pad):
65
+ return self.adpter(x), mask_pad
66
+
67
+
68
+ class CNNSubsampling(torch.nn.Module):
69
+ def __init__(
70
+ self,
71
+ enc_out_dim: int = 512,
72
+ llm_embed_dim: int = 4096,
73
+ kernel_size: int = 5,
74
+ activation_func: str = "relu",
75
+ norm: str = "batch",
76
+ ):
77
+ super().__init__()
78
+
79
+ #if enc_out_dim * 4 < llm_embed_dim:
80
+ if enc_out_dim * 4 < 0:
81
+ self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
82
+ self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
83
+ self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
84
+ self.relu1 = nn.ReLU()
85
+
86
+ self.left_padding2 = nn.ConstantPad1d((0, kernel_size - 1), 0.0)
87
+ self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 2, 0)
88
+ self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
89
+ self.relu2 = nn.ReLU()
90
+
91
+ self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
92
+ self.cnn_num = 2
93
+ else:
94
+ self.left_padding2 = nn.ConstantPad1d((0, kernel_size - 1), 0.0)
95
+ self.conv1d2 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 2, 0)
96
+ if norm == "batch":
97
+ self.bn2 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
98
+ elif norm == "layer":
99
+ self.bn2 = nn.LayerNorm(2 * enc_out_dim, eps=1e-3)
100
+ if activation_func == "gelu":
101
+ self.relu2 = nn.GELU()
102
+ else:
103
+ self.relu2 = nn.ReLU()
104
+
105
+ self.project = nn.Linear(2 * enc_out_dim, llm_embed_dim)
106
+ self.cnn_num = 1
107
+
108
+ def forward(self, x, mask_pad):
109
+ """
110
+ x: B, T, enc_out_dim
111
+ mask: (B, T) or (B, 1, T)
112
+ """
113
+ x = x.transpose(1, 2) # B, channels, T
114
+
115
+ # mask batch padding
116
+ if mask_pad.size(2) > 0: # time > 0
117
+ x.masked_fill_(~mask_pad, 0.0)
118
+
119
+ if self.cnn_num == 2:
120
+ x = self.left_padding1(x)
121
+ x = self.conv1d1(x)
122
+ x = self.bn1(x)
123
+ x = self.relu1(x)
124
+
125
+ x = self.left_padding2(x)
126
+ x = self.conv1d2(x)
127
+ if isinstance(self.bn2, nn.LayerNorm):
128
+ x = x.transpose(1, 2)
129
+ x = self.bn2(x)
130
+ if isinstance(self.bn2, nn.LayerNorm):
131
+ x = x.transpose(1, 2)
132
+ x = self.relu2(x)
133
+
134
+ x = x.transpose(1, 2)
135
+ x = self.project(x)
136
+
137
+ return x, mask_pad[:, :, 0::2]
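A quick shape check for `CNNSubsampling` (illustrative dimensions, not the VITA training configuration): the single stride-2 convolution halves the time axis, and the padding mask is thinned with the same `[:, :, 0::2]` stride.

```python
import torch

from vita.model.multimodal_encoder.whale.adapter import CNNSubsampling

B, T, D = 2, 100, 512
adapter = CNNSubsampling(enc_out_dim=D, llm_embed_dim=4096, kernel_size=5,
                         activation_func="gelu", norm="layer")

x = torch.randn(B, T, D)                       # encoder output
mask = torch.ones(B, 1, T, dtype=torch.bool)   # (B, 1, T) padding mask, True = valid frame

y, out_mask = adapter(x, mask)
print(y.shape)         # torch.Size([2, 50, 4096])
print(out_mask.shape)  # torch.Size([2, 1, 50])
```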
vita/model/multimodal_encoder/whale/cmvn.py ADDED
@@ -0,0 +1,89 @@
1
+ import numpy as np
2
+ import torch
3
+ import json
4
+ import math
+ import logging  # used by load_cmvn_kaldi
+ import sys  # used by load_cmvn_kaldi
5
+
6
+
7
+ class GlobalCMVN(torch.nn.Module):
8
+ def __init__(self, mean: torch.Tensor, istd: torch.Tensor, norm_var: bool = True):
9
+ """
10
+ Args:
11
+ mean (torch.Tensor): mean stats
12
+ istd (torch.Tensor): inverse std, std which is 1.0 / std
13
+ """
14
+ super().__init__()
15
+ assert mean.shape == istd.shape
16
+ self.norm_var = norm_var
17
+ # The buffer can be accessed from this module using self.mean
18
+ self.register_buffer("mean", mean)
19
+ self.register_buffer("istd", istd)
20
+
21
+ def forward(self, x: torch.Tensor):
22
+ """
23
+ Args:
24
+ x (torch.Tensor): (batch, max_len, feat_dim)
25
+
26
+ Returns:
27
+ (torch.Tensor): normalized feature
28
+ """
29
+ x = x - self.mean
30
+ if self.norm_var:
31
+ x = x * self.istd
32
+ return x
33
+
34
+
35
+ def load_cmvn_json(json_cmvn_file):
36
+ with open(json_cmvn_file) as f:
37
+ cmvn_json = json.load(f)
38
+
39
+ avg = cmvn_json["mean_stat"]
40
+ var = cmvn_json["var_stat"]
41
+ count = cmvn_json["frame_num"]
42
+ for i in range(len(avg)):
43
+ avg[i] /= count
44
+ var[i] = var[i] / count - avg[i] * avg[i]
45
+ if var[i] < 1.0e-20:
46
+ var[i] = 1.0e-20
47
+ var[i] = 1.0 / math.sqrt(var[i])
48
+ cmvn = np.array([avg, var])
49
+ return cmvn
50
+
51
+
52
+ def load_cmvn_kaldi(kaldi_cmvn_file):
53
+ avg = []
54
+ var = []
55
+ with open(kaldi_cmvn_file, "r") as file:
56
+ # kaldi binary file start with '\0B'
57
+ if file.read(2) == "\0B":
58
+ logging.error(
59
+ "kaldi cmvn binary file is not supported, please "
60
+ )
61
+ sys.exit(1)
62
+ file.seek(0)
63
+ arr = file.read().split()
64
+ assert arr[0] == "["
65
+ assert arr[-2] == "0"
66
+ assert arr[-1] == "]"
67
+ feat_dim = int((len(arr) - 2 - 2) / 2)
68
+ for i in range(1, feat_dim + 1):
69
+ avg.append(float(arr[i]))
70
+ count = float(arr[feat_dim + 1])
71
+ for i in range(feat_dim + 2, 2 * feat_dim + 2):
72
+ var.append(float(arr[i]))
73
+
74
+ for i in range(len(avg)):
75
+ avg[i] /= count
76
+ var[i] = var[i] / count - avg[i] * avg[i]
77
+ if var[i] < 1.0e-20:
78
+ var[i] = 1.0e-20
79
+ var[i] = 1.0 / math.sqrt(var[i])
80
+ cmvn = np.array([avg, var])
81
+ return cmvn
82
+
83
+
84
+ def load_cmvn(filename, is_json):
85
+ if is_json:
86
+ file = load_cmvn_json(filename)
87
+ else:
88
+ file = load_cmvn_kaldi(filename)
89
+ return file[0], file[1]
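A minimal sketch of how these helpers combine, mirroring what `init_model.py` below does; `"cmvn.json"` is a placeholder path for a JSON file containing `mean_stat`, `var_stat` and `frame_num`.

```python
import torch

from vita.model.multimodal_encoder.whale.cmvn import GlobalCMVN, load_cmvn

mean, istd = load_cmvn("cmvn.json", is_json=True)   # placeholder path
cmvn = GlobalCMVN(torch.from_numpy(mean).float(), torch.from_numpy(istd).float())

feats = torch.randn(4, 120, mean.shape[0])   # (batch, frames, feat_dim) fbank features
normed = cmvn(feats)                         # (feats - mean) * istd, applied framewise
```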
vita/model/multimodal_encoder/whale/init_model.py ADDED
@@ -0,0 +1,192 @@
1
+ # Copyright (c) 2022 Binbin Zhang ([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, Optional, Tuple
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ import torchaudio
21
+ import torchaudio.compliance.kaldi as kaldi
22
+
23
+ from .adapter import CNNAdapter, CNNSubsampling, LinearAdapter
24
+ from .cmvn import GlobalCMVN, load_cmvn
25
+ from .module.encoder.encoder import whaleEncoder
26
+
27
+
28
+ class audioEncoderProcessor:
29
+ def __init__(
30
+ self,
31
+ dataset_conf: dict = None,
32
+ ):
33
+ self.dataset_conf = dataset_conf
34
+
35
+ def process(self, wav_path):
36
+ try:
37
+ waveform, sample_rate = torchaudio.load(wav_path)
38
+ except Exception as e:
39
+ print(f"cannot open {wav_path}!!!!!!!!!!!!!!!!")
40
+ if sample_rate != self.dataset_conf["resample_conf"]["resample_rate"]:
41
+ waveform = torchaudio.transforms.Resample(
42
+ orig_freq=sample_rate, new_freq=self.dataset_conf["resample_conf"]["resample_rate"]
43
+ )(waveform)
44
+ sample_rate = self.dataset_conf['resample_conf']['resample_rate']
45
+
46
+ waveform = waveform * (1 << 15)
47
+ # Only keep key, feat, label
48
+ mat = kaldi.fbank(
49
+ waveform,
50
+ num_mel_bins=self.dataset_conf["fbank_conf"]["num_mel_bins"],
51
+ frame_length=self.dataset_conf["fbank_conf"]["frame_length"],
52
+ frame_shift=self.dataset_conf["fbank_conf"]["frame_shift"],
53
+ dither=self.dataset_conf["fbank_conf"]["dither"],
54
+ energy_floor=0.0,
55
+ sample_frequency=sample_rate,
56
+ )
57
+ attn_mask = torch.ones(mat.shape[0])
58
+ attn_mask = attn_mask[2::2][2::2][0::2]
59
+
60
+ return mat, attn_mask.shape[0]
61
+
62
+
63
+ class audioEncoder(torch.nn.Module):
64
+ def __init__(
65
+ self,
66
+ encoder: torch.nn.Module,
67
+ llm_path: str,
68
+ freeze_llm: bool = True,
69
+ enc_out_dim: int = 512,
70
+ llm_embed_dim: int = 4096,
71
+ kernel_size: int = 3,
72
+ IGNORE_ID: int = -100,
73
+ adpter_type: str = "cnn",
74
+ add_audio_bos_eos: bool = False,
75
+ task_num: int = 10,
76
+ task_before_audio: bool = False,
77
+ task_type: str = "prompt",
78
+ freeze_encoder: bool = False,
79
+ freeze_adpter: bool = False,
80
+ audio_prompt_finetune: bool = False,
81
+ audio_prompt_num: int = 25,
82
+ activation_func: str = "relu",
83
+ norm: str = "batch",
84
+ chat_template=None,
85
+ ):
86
+ super().__init__()
87
+ self.encoder = encoder
88
+
89
+ self.enc_out_dim = enc_out_dim
90
+ self.llm_embed_dim = llm_embed_dim
91
+ self.IGNORE_ID = IGNORE_ID
92
+ self.add_audio_bos_eos = add_audio_bos_eos
93
+ self.task_before_audio = task_before_audio
94
+ self.task_type = task_type
95
+ self.freeze_encoder = freeze_encoder
96
+ self.freeze_adpter = freeze_adpter
97
+ self.audio_prompt_finetune = audio_prompt_finetune
98
+ self.audio_prompt_num = audio_prompt_num
99
+
100
+ if adpter_type == "cnn":
101
+ self.adpter = CNNAdapter(enc_out_dim, llm_embed_dim, kernel_size)
102
+ elif adpter_type == "linear":
103
+ self.adpter = LinearAdapter(enc_out_dim, llm_embed_dim)
104
+ elif adpter_type == "subsampling":
105
+ self.adpter = CNNSubsampling(
106
+ enc_out_dim, llm_embed_dim, kernel_size, activation_func, norm
107
+ )
108
+
109
+ if self.freeze_encoder:
110
+ self.encoder.eval()
111
+ for (name, param) in self.encoder.named_parameters():
112
+ param.requires_grad = False
113
+ if self.freeze_adpter:
114
+ self.adpter.eval()
115
+ for (name, param) in self.adpter.named_parameters():
116
+ param.requires_grad = False
117
+
118
+ if self.audio_prompt_finetune:
119
+ self.prompt_embeddings = nn.Embedding(audio_prompt_num, llm_embed_dim)
120
+ self.prompt_ids = torch.tensor([i for i in range(audio_prompt_num)]).long()
121
+
122
+ def forward(
123
+ self,
124
+ speech: torch.Tensor,
125
+ speech_lengths: torch.Tensor,
126
+ ) -> Dict[str, Optional[torch.Tensor]]:
127
+
128
+ speech = speech.to(next(self.parameters()).dtype)
129
+
130
+ # 1. Encoder
131
+ encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
132
+ inputs_embeds, encoder_mask = self.adpter(encoder_out, encoder_mask) # B, T, D
133
+ attention_mask = encoder_mask.squeeze(1) # B, T
134
+ assert inputs_embeds.size(1) == attention_mask.size(1)
135
+
136
+ # audio bos/eos
137
+ if self.add_audio_bos_eos:
138
+ inputs_embeds, attention_mask, target = self._add_bos_eos(
139
+ "audio", "/audio", inputs_embeds, attention_mask, target
140
+ )
141
+
142
+ B, _, _ = inputs_embeds.shape
143
+ if self.audio_prompt_finetune:
144
+ prompt_ids = self.prompt_ids.repeat(B, 1).to(inputs_embeds.device)
145
+ prompt_embeds = self.prompt_embeddings(
146
+ prompt_ids.to(inputs_embeds.device)) # B, 5, D
147
+ inputs_embeds = torch.cat((prompt_embeds, inputs_embeds), 1) # B, (T+5), D
148
+
149
+ outputs = {
150
+ "inputs_embeds": inputs_embeds,
151
+ "attention_mask": attention_mask,
152
+ }
153
+
154
+ return outputs
155
+
156
+ def _add_bos_eos(self, bos, eos, inputs_embeds, attention_mask, target=None):
157
+ B = len(inputs_embeds)
158
+ bos_embed = self.task_embeddings(
159
+ torch.full([B, 1], self.task_ids[bos]).to(inputs_embeds.device)
160
+ ) # B, 1, D
161
+ eos_embed = self.task_embeddings(
162
+ torch.full([B, 1], self.task_ids[eos]).to(inputs_embeds.device)
163
+ ) # B, 1, D
164
+ bos_eos_target = torch.full([B, 2], self.IGNORE_ID).to(inputs_embeds.device) # B, 2
165
+ bos_eos_mask = torch.full([B, 1], True).to(inputs_embeds.device) # B, 1
166
+
167
+ inputs_embeds = torch.cat((bos_embed, inputs_embeds), 1) # B, (1+T), D
168
+ inputs_embeds = torch.cat((inputs_embeds, eos_embed), 1) # B, (1+T+1), D
169
+ attention_mask = torch.cat((bos_eos_mask, attention_mask), 1) # B, (1+T)
170
+ attention_mask = torch.cat((attention_mask, bos_eos_mask), 1) # B, (1+T+1)
171
+ if target is not None:
172
+ target = torch.cat((target, bos_eos_target), 1) # B, (T+2), D
173
+
174
+ return inputs_embeds, attention_mask, target
175
+
176
+
177
+ def init_model(configs):
178
+ if configs["cmvn_file"] is not None:
179
+ mean, istd = load_cmvn(configs["cmvn_file"], configs["is_json_cmvn"])
180
+ global_cmvn = GlobalCMVN(torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
181
+ else:
182
+ global_cmvn = None
183
+
184
+ input_dim = configs["input_dim"]
185
+
186
+ encoder = whaleEncoder(input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"])
187
+ model = audioEncoder(encoder=encoder, **configs["model_conf"])
188
+ processor = audioEncoderProcessor(dataset_conf=configs["dataset_conf"])
189
+
190
+ model.audio_processor = processor
191
+
192
+ return model
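A hedged sketch of driving the fbank front-end on its own; the `dataset_conf` values and the wav path are illustrative, not the shipped VITA configuration.

```python
from vita.model.multimodal_encoder.whale.init_model import audioEncoderProcessor

dataset_conf = {
    "resample_conf": {"resample_rate": 16000},
    "fbank_conf": {"num_mel_bins": 80, "frame_length": 25, "frame_shift": 10, "dither": 0.0},
}
processor = audioEncoderProcessor(dataset_conf=dataset_conf)

# "sample.wav" is a placeholder; process() returns the fbank matrix and the frame count
# implied by the [2::2][2::2][0::2] mask slicing (roughly one token per 8 fbank frames).
mat, num_frames = processor.process("sample.wav")
print(mat.shape, num_frames)
```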
vita/model/multimodal_encoder/whale/module/component/mamba.py ADDED
@@ -0,0 +1,131 @@
1
+ """Encoder self-attention layer definition."""
2
+
3
+ import math
4
+ import pdb
5
+ from functools import partial
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, strtobool
13
+
14
+ try:
15
+ from mamba_ssm.modules.mamba_simple import Mamba, Block
16
+ from mamba_ssm.models.mixer_seq_simple import _init_weights
17
+ from mamba_ssm.ops.triton.layernorm import RMSNorm
18
+ except ImportError:
19
+ print("Please install mamba_ssm to use MambaSSM component.")
20
+
21
+
22
+ class MambaBlock(nn.Module):
23
+ def __init__(self, in_channels, n_layer=1, d_state=16, d_conv=4, expand=4, bidirectional=False):
24
+ super(MambaBlock, self).__init__()
25
+ self.forward_blocks = nn.ModuleList([])
26
+ self.forward_norm_f = RMSNorm(in_channels, eps=1e-5)
27
+ for i in range(n_layer):
28
+ self.forward_blocks.append(
29
+ Block(
30
+ in_channels,
31
+ mixer_cls=partial(
32
+ Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
33
+ ),
34
+ norm_cls=partial(RMSNorm, eps=1e-5),
35
+ fused_add_norm=True,
36
+ residual_in_fp32=True,
37
+ )
38
+ )
39
+ if bidirectional:
40
+ self.backward_blocks = nn.ModuleList([])
41
+ for i in range(n_layer):
42
+ self.backward_blocks.append(
43
+ Block(
44
+ in_channels,
45
+ mixer_cls=partial(
46
+ Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
47
+ ),
48
+ norm_cls=partial(RMSNorm, eps=1e-5),
49
+ fused_add_norm=True,
50
+ residual_in_fp32=True,
51
+ )
52
+ )
53
+ self.backward_norm_f = RMSNorm(in_channels, eps=1e-5)
54
+ else:
55
+ self.backward_blocks = None
56
+
57
+ self.apply(partial(_init_weights, n_layer=n_layer))
58
+
59
+ def forward(self, input):
60
+ for_residual = None
61
+ forward_f = input.clone()
62
+ for block in self.forward_blocks:
63
+ forward_f, for_residual = block(forward_f, for_residual, inference_params=None)
64
+ residual = (forward_f + for_residual) if for_residual is not None else forward_f
65
+ residual = self.forward_norm_f(residual)
66
+
67
+ if self.backward_blocks is not None:
68
+ back_residual = None
69
+ backward_f = torch.flip(input, [1])
70
+ for block in self.backward_blocks:
71
+ backward_f, back_residual = block(backward_f, back_residual, inference_params=None)
72
+ back_residual = (
73
+ (backward_f + back_residual) if back_residual is not None else backward_f
74
+ )
75
+
76
+ back_residual = torch.flip(back_residual, [1])
77
+ back_residual = self.backward_norm_f(back_residual)
78
+ residual = torch.cat([residual, back_residual], -1)
79
+
80
+ return residual
81
+
82
+
83
+ class MambaSSM(torch.nn.Module):
84
+ @staticmethod
85
+ def add_arguments(group):
86
+ """Add TDNN common arguments."""
87
+ group.add_argument(
88
+ "--mamba-num-layers", default=4, type=int, help="Output dim of MambaSSM."
89
+ )
90
+ group.add_argument(
91
+ "--mamba-input-dim", default=256, type=int, help="Input dim of MambaSSM."
92
+ )
93
+ group.add_argument(
94
+ "--mamba-output-dim", default=256, type=int, help="Output dim of MambaSSM."
95
+ )
96
+ group.add_argument("--mamba-d-state", default=16, type=int, help="d-state of MambaSSM.")
97
+ group.add_argument("--mamba-d-conv", default=4, type=int, help="d-conv of MambaSSM.")
98
+ group.add_argument("--mamba-expand", default=4, type=int, help="expand of MambaSSM.")
99
+ return group
100
+
101
+ def __init__(self, args):
102
+ """Construct an Encoder object."""
103
+ super(MambaSSM, self).__init__()
104
+ self.mamb_num_layers = args.mamba_num_layers
105
+ self.mamba_input_dim = args.mamba_input_dim
106
+ self.mamba_output_dim = args.mamba_output_dim
107
+ self.mamba_d_state = args.mamba_d_state
108
+ self.mamba_d_conv = args.mamba_d_conv
109
+ self.mamba_expand = args.mamba_expand
110
+
111
+ self.mamba = MambaBlock(
112
+ self.mamba_input_dim,
113
+ self.mamb_num_layers,
114
+ self.mamba_d_state,
115
+ self.mamba_d_conv,
116
+ self.mamba_expand,
117
+ )
118
+
119
+ @torch.jit.unused
120
+ def forward(self, xs, ilens=None, masks=None):
121
+ """Embed positions in tensor.
122
+
123
+ :param torch.Tensor xs: input tensor
124
+ :param torch.Tensor masks: input mask
125
+ :return: position embedded tensor and mask
126
+ :rtype Tuple[torch.Tensor, torch.Tensor]:
127
+ """
128
+
129
+ xs_out = self.mamba(xs)
130
+
131
+ return xs_out.to(xs.dtype), ilens, masks
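A sketch of building the SSM component directly. It needs the optional `mamba_ssm` package and a CUDA device, and the dimensions are illustrative; `MambaSSM` only reads attribute-style arguments, so a `SimpleNamespace` is enough.

```python
from types import SimpleNamespace

import torch

from vita.model.multimodal_encoder.whale.module.component.mamba import MambaSSM

args = SimpleNamespace(mamba_num_layers=4, mamba_input_dim=256, mamba_output_dim=256,
                       mamba_d_state=16, mamba_d_conv=4, mamba_expand=4)
ssm = MambaSSM(args).cuda()

xs = torch.randn(2, 50, 256, device="cuda")
ys, ilens, masks = ssm(xs)   # (2, 50, 256); the width doubles only if MambaBlock is built bidirectional
print(ys.shape, ilens, masks)
```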
vita/model/multimodal_encoder/whale/module/component/subsampling.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch
2
+ from typing import Tuple, Union
3
+
4
+
5
+ class BaseSubsampling(torch.nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.subsampling_rate = 1
9
+ self.right_context = 0
10
+
11
+ def position_encoding(self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
12
+ return self.pos_enc.position_encoding(offset, size)
13
+
14
+
15
+ class Conv2dSubsampling4(BaseSubsampling):
16
+ """Convolutional 2D subsampling (to 1/4 length).
17
+
18
+ Args:
19
+ idim (int): Input dimension.
20
+ odim (int): Output dimension.
21
+ dropout_rate (float): Dropout rate.
22
+
23
+ """
24
+
25
+ def __init__(self, idim: int, odim: int, dropout_rate: float):
26
+ """Construct an Conv2dSubsampling4 object."""
27
+ super().__init__()
28
+ self.conv = torch.nn.Sequential(
29
+ torch.nn.Conv2d(1, odim, 3, 2),
30
+ torch.nn.ReLU(),
31
+ torch.nn.Conv2d(odim, odim, 3, 2),
32
+ torch.nn.ReLU(),
33
+ )
34
+ self.out = torch.nn.Sequential(torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
35
+ self.right_context = 6
36
+ self.subsampling_rate = 4
37
+
38
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
39
+ x = x.unsqueeze(1) # (b, c=1, t, f)
40
+ x = self.conv(x)
41
+ b, c, t, f = x.size()
42
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
43
+ return x, x_mask[:, :, 2::2][:, :, 2::2]
44
+
45
+
46
+ class Subsampling(torch.nn.Module):
47
+ @staticmethod
48
+ def add_arguments(group):
49
+ """Add Subsampling common arguments."""
50
+ group.add_argument("--subsampling-rate", default=4, type=int)
51
+ group.add_argument("--subsampling-input-dim", default=256, type=int)
52
+ group.add_argument("--subsampling-output-dim", default=256, type=int)
53
+ group.add_argument("--subsampling-dropout-rate", default=0.1, type=float)
54
+
55
+ return group
56
+
57
+ def __init__(self, args):
58
+ super().__init__()
59
+ self.subsampling_rate = args.subsampling_rate
60
+ self.subsampling_input_dim = args.subsampling_input_dim
61
+ self.subsampling_output_dim = args.subsampling_output_dim
62
+ self.subsampling_dropout_rate = args.subsampling_dropout_rate
63
+
64
+ if self.subsampling_rate == 4:
65
+ self.core = Conv2dSubsampling4(
66
+ self.subsampling_input_dim,
67
+ self.subsampling_output_dim,
68
+ self.subsampling_dropout_rate,
69
+ )
70
+
71
+ def forward(self, xs, ilens, masks):
72
+ xs, masks = self.core(xs, masks)
73
+ ilens = masks.squeeze(1).sum(1)
74
+ return xs, ilens, masks
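The 1/4-length arithmetic above can be verified directly: two stride-2, kernel-3 convolutions give T' = ((T - 1) // 2 - 1) // 2, and the `[:, :, 2::2][:, :, 2::2]` slicing thins the mask to the same length (illustrative sizes below).

```python
import torch

from vita.model.multimodal_encoder.whale.module.component.subsampling import Conv2dSubsampling4

idim, odim = 80, 256
sub = Conv2dSubsampling4(idim, odim, dropout_rate=0.1)

B, T = 2, 100
x = torch.randn(B, T, idim)
mask = torch.ones(B, 1, T, dtype=torch.bool)

y, out_mask = sub(x, mask)
print(y.shape, out_mask.shape)   # torch.Size([2, 24, 256]) torch.Size([2, 1, 24])
```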
vita/model/multimodal_encoder/whale/module/component/transformer.py ADDED
@@ -0,0 +1,428 @@
1
+ """Encoder self-attention layer definition."""
2
+
3
+ import math
4
+ import pdb
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from vita.model.multimodal_encoder.whale.module.layer.attention import (
12
+ Conv1dLinear,
13
+ MultiHeadedAttention,
14
+ MultiLayeredConv1d,
15
+ PositionalEncoding,
16
+ PositionwiseFeedForward,
17
+ RelPositionalEncoding,
18
+ )
19
+
20
+ # from vita.model.multimodal_encoder.whale.module.component.utils import *
21
+ from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, add_optional_chunk_mask, strtobool
22
+
23
+
24
+ def repeat(N, fn):
25
+ """Repeat module N times.
26
+
27
+ :param int N: repeat time
28
+ :param function fn: function to generate module
29
+ :return: repeated modules
30
+ :rtype: MultiSequential
31
+ """
32
+ return MultiSequential(*[fn(n) for n in range(N)])
33
+
34
+
35
+ class MultiSequential(torch.nn.Sequential):
36
+ """Multi-input multi-output torch.nn.Sequential."""
37
+
38
+ def forward(self, x, masks, pos_emb):
39
+
40
+ """Repeat."""
41
+ for m in self:
42
+ x, masks, pos_emb = m(x, masks, pos_emb)
43
+ return x, masks, pos_emb
44
+
45
+ @torch.jit.export
46
+ def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
47
+ # type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
48
+ """Repeat."""
49
+ for m in self:
50
+ x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
51
+ x, pos_emb, buffer, buffer_index, buffer_out
52
+ )
53
+ return x, pos_emb, buffer, buffer_index, buffer_out
54
+
55
+ @torch.jit.export
56
+ def infer_hidden(self, x, pos_emb, buffer, buffer_index, buffer_out, hidden_out):
57
+ # type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
58
+ """Repeat."""
59
+ for m in self:
60
+ x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
61
+ x, pos_emb, buffer, buffer_index, buffer_out
62
+ )
63
+ hidden_out.append(x)
64
+ return x, pos_emb, buffer, buffer_index, buffer_out, hidden_out
65
+
66
+
67
+ class TransformerLayer(nn.Module):
68
+ """Transformer layer module.
69
+
70
+ :param int size: input dim
71
+ :param self_attn: self attention module
72
+ :param feed_forward: feed forward module
73
+ :param float dropout_rate: dropout rate
74
+ :param bool normalize_before: whether to use layer_norm before the first block
75
+ :param bool concat_after: whether to concat attention layer's input and output
76
+ if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
77
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
78
+
79
+ """
80
+
81
+ def __init__(
82
+ self, size, self_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False
83
+ ):
84
+ """Construct an TransformerLayer object."""
85
+ super(TransformerLayer, self).__init__()
86
+ self.self_attn = self_attn
87
+ self.feed_forward = feed_forward
88
+ self.norm1 = torch.nn.LayerNorm(size)
89
+ self.norm2 = torch.nn.LayerNorm(size)
90
+ self.dropout = nn.Dropout(dropout_rate)
91
+ self.size = size
92
+ self.normalize_before = normalize_before
93
+ self.concat_after = concat_after
94
+ if self.concat_after:
95
+ self.concat_linear = nn.Linear(size + size, size)
96
+ else:
97
+ self.concat_linear = nn.Identity()
98
+
99
+ @torch.jit.unused
100
+ def forward(self, x, mask, pos_emb):
101
+ """Compute encoded features.
102
+
103
+ :param torch.Tensor x: encoded source features (batch, max_time_in, size)
104
+ :param torch.Tensor mask: mask for x (batch, max_time_in)
105
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
106
+ """
107
+ residual = x
108
+ if self.normalize_before:
109
+ x = self.norm1(x)
110
+ if self.concat_after:
111
+ x_concat = torch.cat((x, self.self_attn(x, x, x, mask, pos_emb)), dim=-1)
112
+ x = residual + self.concat_linear(x_concat)
113
+ else:
114
+ x = residual + self.dropout(self.self_attn(x, x, x, mask, pos_emb))
115
+ if not self.normalize_before:
116
+ x = self.norm1(x)
117
+
118
+ residual = x
119
+ if self.normalize_before:
120
+ x = self.norm2(x)
121
+ x = residual + self.dropout(self.feed_forward(x))
122
+ if not self.normalize_before:
123
+ x = self.norm2(x)
124
+
125
+ return x, mask, pos_emb
126
+
127
+ @torch.jit.export
128
+ def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
129
+ # type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
130
+ residual = x.clone()
131
+ if self.normalize_before:
132
+ x = self.norm1(x)
133
+ if self.concat_after:
134
+ x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
135
+ x, x, x, pos_emb, buffer, buffer_index, buffer_out
136
+ )
137
+ x_concat = torch.cat((x, x_att), dim=-1)
138
+ x = residual + self.concat_linear(x_concat)
139
+ else:
140
+ x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
141
+ x, x, x, pos_emb, buffer, buffer_index, buffer_out
142
+ )
143
+ x = residual + x_att
144
+ if not self.normalize_before:
145
+ x = self.norm1(x)
146
+
147
+ residual = x.clone()
148
+ if self.normalize_before:
149
+ x = self.norm2(x)
150
+ x_feed, buffer, buffer_index, buffer_out = self.feed_forward.infer(
151
+ x, buffer, buffer_index, buffer_out
152
+ )
153
+ x = residual + x_feed
154
+ if not self.normalize_before:
155
+ x = self.norm2(x)
156
+
157
+ return x, pos_emb, buffer, buffer_index, buffer_out
158
+
159
+
160
+ class Transformer(torch.nn.Module):
161
+ @staticmethod
162
+ def add_arguments(group):
163
+ """Add TDNN common arguments."""
164
+ group.add_argument(
165
+ "--transformer-input-dim", default=256, type=int, help="Input dim of Transformer."
166
+ )
167
+ group.add_argument(
168
+ "--transformer-output-dim", default=4, type=int, help="Output dim of Transformer."
169
+ )
170
+ group.add_argument(
171
+ "--transformer-attention-dim", default=256, type=int, help="Dimention of attention."
172
+ )
173
+ group.add_argument(
174
+ "--transformer-attention-heads",
175
+ default=4,
176
+ type=int,
177
+ help="The number of heads of multi head attention.",
178
+ )
179
+ group.add_argument(
180
+ "--transformer-linear-units",
181
+ default=1024,
182
+ type=int,
183
+ help="The number of units of position-wise feed forward.",
184
+ )
185
+ group.add_argument(
186
+ "--transformer-num-blocks", default=6, type=int, help="The number of attention blocks."
187
+ )
188
+ group.add_argument(
189
+ "--transformer-dropout-rate",
190
+ default=0.1,
191
+ type=float,
192
+ help="Dropout rate in Transformer.",
193
+ )
194
+ group.add_argument(
195
+ "--transformer-attention-dropout-rate",
196
+ default=0.0,
197
+ type=float,
198
+ help="Dropout rate in attention.",
199
+ )
200
+ group.add_argument(
201
+ "--transformer-positional-dropout-rate",
202
+ default=0.1,
203
+ type=float,
204
+ help="Dropout rate after adding positional encoding.",
205
+ )
206
+ group.add_argument(
207
+ "--transformer-input-layer", default="linear", type=str, help="Type of input layer"
208
+ )
209
+ group.add_argument("--transformer-pos-enc-class", default="abs-enc", type=str, help="")
210
+ group.add_argument(
211
+ "--transformer-normalize-before",
212
+ default=True,
213
+ type=strtobool,
214
+ help="Whether to use layer-norm before the first block.",
215
+ )
216
+ group.add_argument(
217
+ "--transformer-concat-after",
218
+ default=False,
219
+ type=strtobool,
220
+ help="Whether to concat attention layer's input and output.",
221
+ )
222
+ group.add_argument(
223
+ "--transformer-positionwise-layer-type",
224
+ default="linear",
225
+ type=str,
226
+ help="Linear of conv1d.",
227
+ )
228
+ group.add_argument(
229
+ "--transformer-positionwise-conv-kernel_size",
230
+ default=1,
231
+ type=int,
232
+ help="Kernel size of positionwise conv1d layer.",
233
+ )
234
+ group.add_argument("--transformer-chunk_size", default=-1, type=int, help="")
235
+ group.add_argument("--transformer-left_chunks", default=-1, type=int, help="")
236
+ group.add_argument("--transformer-dynamic-chunks", default=True, type=strtobool, help="")
237
+ return group
238
+
239
+ def __init__(
240
+ self,
241
+ args,
242
+ input_dim=None,
243
+ output_dim=None,
244
+ attention_dim=None,
245
+ attention_heads=None,
246
+ linear_units=None,
247
+ num_blocks=None,
248
+ dropout_rate=None,
249
+ positional_dropout_rate=None,
250
+ attention_dropout_rate=None,
251
+ input_layer=None,
252
+ pos_enc_class=None,
253
+ normalize_before=None,
254
+ concat_after=None,
255
+ positionwise_layer_type=None,
256
+ positionwise_conv_kernel_size=None,
257
+ chunk_size=None,
258
+ left_chunks=None,
259
+ ):
260
+ """Construct an Encoder object."""
261
+ super(Transformer, self).__init__()
262
+ if args is None:
263
+ self.input_dim = input_dim
264
+ self.output_dim = output_dim
265
+ self.attention_dim = attention_dim
266
+ self.attention_heads = attention_heads
267
+ self.linear_units = linear_units
268
+ self.num_blocks = num_blocks
269
+ self.dropout_rate = dropout_rate
270
+ self.positional_dropout_rate = positional_dropout_rate
271
+ self.attention_dropout_rate = attention_dropout_rate
272
+ self.input_layer = input_layer
273
+ self.pos_enc_class = pos_enc_class
274
+ self.normalize_before = normalize_before
275
+ self.concat_after = concat_after
276
+ self.positionwise_layer_type = positionwise_layer_type
277
+ self.positionwise_conv_kernel_size = positionwise_conv_kernel_size
278
+ self.chunk_size = chunk_size
279
+ self.left_chunks = left_chunks
+ # args is None here; assume dynamic chunking on, matching the argparse default above
+ self.transformer_dynamic_chunks = True
280
+ else:
281
+ self.input_dim = args.transformer_input_dim
282
+ self.output_dim = args.transformer_output_dim
283
+ self.attention_dim = args.transformer_attention_dim
284
+ self.attention_heads = args.transformer_attention_heads
285
+ self.linear_units = args.transformer_linear_units
286
+ self.num_blocks = args.transformer_num_blocks
287
+ self.dropout_rate = args.transformer_dropout_rate
288
+ self.positional_dropout_rate = args.transformer_positional_dropout_rate
289
+ self.attention_dropout_rate = args.transformer_attention_dropout_rate
290
+ self.input_layer = args.transformer_input_layer
291
+ self.pos_enc_class = args.transformer_pos_enc_class
292
+ self.normalize_before = args.transformer_normalize_before
293
+ self.concat_after = args.transformer_concat_after
294
+ self.positionwise_layer_type = args.transformer_positionwise_layer_type
295
+ self.positionwise_conv_kernel_size = args.transformer_positionwise_conv_kernel_size
296
+ self.chunk_size = args.transformer_chunk_size
297
+ self.left_chunks = args.transformer_left_chunks
298
+ self.transformer_dynamic_chunks = args.transformer_dynamic_chunks
299
+
300
+ if self.pos_enc_class == "abs-enc":
301
+ pos_enc_args = (self.attention_dim, self.positional_dropout_rate)
302
+ pos_enc_class = PositionalEncoding
303
+ elif self.pos_enc_class == "rel-enc":
304
+ pos_enc_args = (
305
+ self.attention_dim,
306
+ self.positional_dropout_rate,
307
+ self.chunk_size,
308
+ self.left_chunks,
309
+ )
310
+ pos_enc_class = RelPositionalEncoding
311
+
312
+ if self.input_layer == "linear":
313
+ self.embed = torch.nn.Sequential(
314
+ torch.nn.Linear(self.input_dim, self.attention_dim),
315
+ torch.nn.LayerNorm(self.attention_dim),
316
+ torch.nn.Dropout(self.dropout_rate),
317
+ torch.nn.ReLU(),
318
+ )
319
+ elif self.input_layer == "embed":
320
+ self.embed = torch.nn.Sequential(
321
+ torch.nn.Embedding(self.input_dim, self.attention_dim, padding_idx=IGNORE_ID)
322
+ )
323
+ elif self.input_layer == "none":
324
+ self.embed = torch.nn.Sequential(torch.nn.Identity())
325
+ else:
326
+ raise ValueError("unknown input_layer: " + self.input_layer)
327
+ self.pe = pos_enc_class(*pos_enc_args)
328
+ self.embed_layer_num = len(self.embed)
329
+
330
+ if self.positionwise_layer_type == "linear":
331
+ positionwise_layer = PositionwiseFeedForward
332
+ positionwise_layer_args = (self.attention_dim, self.linear_units, self.dropout_rate)
333
+ elif self.positionwise_layer_type == "conv1d":
334
+ positionwise_layer = MultiLayeredConv1d
335
+ positionwise_layer_args = (
336
+ self.attention_dim,
337
+ self.linear_units,
338
+ self.positionwise_conv_kernel_size,
339
+ self.dropout_rate,
340
+ )
341
+ elif self.positionwise_layer_type == "conv1d-linear":
342
+ positionwise_layer = Conv1dLinear
343
+ positionwise_layer_args = (
344
+ self.attention_dim,
345
+ self.linear_units,
346
+ self.positionwise_conv_kernel_size,
347
+ self.dropout_rate,
348
+ )
349
+ else:
350
+ raise NotImplementedError("Support only linear or conv1d.")
351
+
352
+ self.encoders = repeat(
353
+ self.num_blocks,
354
+ lambda lnum: TransformerLayer(
355
+ self.attention_dim,
356
+ MultiHeadedAttention(
357
+ self.attention_heads,
358
+ self.attention_dim,
359
+ self.attention_dropout_rate,
360
+ self.chunk_size,
361
+ self.left_chunks,
362
+ self.pos_enc_class,
363
+ ),
364
+ positionwise_layer(*positionwise_layer_args),
365
+ self.dropout_rate,
366
+ self.normalize_before,
367
+ self.concat_after,
368
+ ),
369
+ )
370
+ if self.normalize_before:
371
+ self.after_norm = torch.nn.LayerNorm(self.attention_dim)
372
+
373
+ @torch.jit.unused
374
+ def forward(self, xs, ilens=None, masks=None):
375
+ """Embed positions in tensor.
376
+
377
+ :param torch.Tensor xs: input tensor
378
+ :param torch.Tensor masks: input mask
379
+ :return: position embedded tensor and mask
380
+ :rtype Tuple[torch.Tensor, torch.Tensor]:
381
+ """
382
+
383
+ if self.transformer_dynamic_chunks == True: # and self.training:
384
+ chunk_masks = add_optional_chunk_mask(xs, masks, True, True, 0, 0, -1)
385
+ else:
386
+ chunk_masks = add_optional_chunk_mask(
387
+ xs, masks, False, False, self.chunk_size, self.chunk_size, self.left_chunks
388
+ ).to(xs.device)
389
+ xs = self.embed(xs)
390
+ xs, pos_emb = self.pe(xs)
391
+ xs, chunk_masks, pos_emb = self.encoders(xs, chunk_masks, pos_emb)
392
+ if self.normalize_before:
393
+ xs = self.after_norm(xs)
394
+ return xs, ilens, masks
395
+
396
+ @torch.jit.export
397
+ def infer(self, xs, buffer, buffer_index, buffer_out):
398
+ xs = self.embed(xs)
399
+
400
+ # pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
401
+ # xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
402
+ # buffer_out.append(pe_index.reshape(-1).to(torch.float32))
403
+ # buffer_index = buffer_index + 1
404
+ xs, pos_emb, _ = self.pe.infer(xs, 0)
405
+ xs, pos_emb, buffer, buffer_index, buffer_out = self.encoders.infer(
406
+ xs, pos_emb, buffer, buffer_index, buffer_out
407
+ )
408
+
409
+ if self.normalize_before:
410
+ xs = self.after_norm(xs)
411
+ return xs, buffer, buffer_index, buffer_out
412
+
413
+ @torch.jit.export
414
+ def infer_hidden(self, xs, buffer, buffer_index, buffer_out, hidden_out):
415
+ xs = self.embed(xs)
416
+
417
+ # pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
418
+ # xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
419
+ # buffer_out.append(pe_index.reshape(-1).to(torch.float32))
420
+ # buffer_index = buffer_index + 1
421
+ xs, pos_emb, _ = self.pe.infer(xs, 0)
422
+ xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out = self.encoders.infer_hidden(
423
+ xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out
424
+ )
425
+
426
+ if self.normalize_before:
427
+ xs = self.after_norm(xs)
428
+ return xs, buffer, buffer_index, buffer_out, hidden_out
vita/model/multimodal_encoder/whale/module/encoder/encoder.py ADDED
@@ -0,0 +1,171 @@
1
+ import argparse
2
+ import logging
3
+ import sys
4
+ import time
5
+ from typing import Dict, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import six
9
+ import torch
10
+
11
+ from vita.model.multimodal_encoder.whale.module.component.mamba import MambaSSM
12
+ from vita.model.multimodal_encoder.whale.module.component.subsampling import Subsampling
13
+ from vita.model.multimodal_encoder.whale.module.component.transformer import Transformer
14
+ from vita.model.multimodal_encoder.whale.utils import make_pad_mask
15
+
16
+
17
+ def add_encoder_args(group):
18
+ """Add Encoder common arguments."""
19
+ group.add_argument(
20
+ "--encoder-layer-config",
21
+ type=str,
22
+ default="tdnn-dtc",
23
+ help="Layer config of encoder. Format layername-layername-..., default(conv1d-fsmn-rnn)",
24
+ )
25
+ group.add_argument(
26
+ "--encoder-input-dim",
27
+ type=int,
28
+ default=256,
29
+ help="Input dim of encoder. Must equal to the input dim of the first Component (default=40)",
30
+ )
31
+ group.add_argument(
32
+ "--encoder-output-dim",
33
+ type=int,
34
+ default=256,
35
+ help="Output dim of encoder. Must enqual to the output dim of the last Component ! (default=256)",
36
+ )
37
+ # Add args of all kinds of components.
38
+ # If you add a new component, DO NOT forget to add args to add_component_args func.
39
+ group = Transformer.add_arguments(group)
40
+ group = Subsampling.add_arguments(group)
41
+ group = MambaSSM.add_arguments(group)
42
+ return group
43
+
44
+
45
+ def assign_args_from_dict(args, dict, prefix_key=None):
46
+ if prefix_key is not None:
47
+ dict = dict[prefix_key]
48
+ for k, v in dict.items():
49
+ k_args = k.replace("-", "_")
50
+ if hasattr(args, k_args):
51
+ setattr(args, k_args, dict[k])
52
+ return args
53
+
54
+
55
+ class whaleEncoder(torch.nn.Module):
56
+ def __init__(self, input_dim, overview_conf=None, para_conf=None, global_cmvn=None):
57
+ super(whaleEncoder, self).__init__()
58
+
59
+ parser = argparse.ArgumentParser()
60
+ add_encoder_args(parser)
61
+ args, _ = parser.parse_known_args()
62
+
63
+ assign_args_from_dict(args, overview_conf)
64
+ # assign_args_from_dict(args, para_conf)
65
+
66
+ self.config = args.encoder_layer_config.split("-")
67
+ encoder_input_dim = args.encoder_input_dim
68
+ encoder_output_dim = args.encoder_output_dim
69
+ prev_output_dim = encoder_input_dim
70
+ prev_component_name = "encoder"
71
+ self.enc = torch.nn.ModuleList([])
72
+ for name in self.config:
73
+ assign_args_from_dict(args, para_conf[name])
74
+ if len(name.split("_")) == 2:
75
+ name = name.split("_")[0]
76
+ elif len(name.split("_")) == 1:
77
+ name = name
78
+ else:
79
+ logging.error("WRONG CONFIG! {} is not valid".format("encoder", name))
80
+ sys.exit()
81
+
82
+ if name == "transformer":
83
+ self.enc.append(Transformer(args))
84
+ elif name == "subsampling":
85
+ self.enc.append(Subsampling(args))
86
+ elif name == "mamba":
87
+ self.enc.append(MambaSSM(args))
88
+ else:
89
+ print("{} is not supported now!".format(name))
90
+ return NotImplemented
91
+ component_input_dim = getattr(args, name + "_input_dim")
92
+ if component_input_dim != prev_output_dim:
93
+ # This is the first layer
94
+ logging.error(
95
+ "WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-input-dim ({})".format(
96
+ prev_component_name, prev_output_dim, name, component_input_dim
97
+ )
98
+ )
99
+ sys.exit()
100
+ prev_output_dim = getattr(args, name + "_output_dim")
101
+ prev_component_name = name
102
+
103
+ self.global_cmvn = global_cmvn
104
+ if prev_output_dim != encoder_output_dim:
105
+ logging.error(
106
+ "WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-output-dim ({}, the last component)".format(
107
+ "encoder", encoder_output_dim, name, prev_output_dim
108
+ )
109
+ )
110
+ sys.exit()
111
+
112
+ self._output_size = encoder_output_dim
113
+
114
+ num_params = sum(p.numel() for p in self.parameters())
115
+ print("the number of whale encoder params: {}M".format(num_params / 1024 / 1024))
116
+
117
+ def output_size(self) -> int:
118
+ return self._output_size
119
+
120
+ @torch.jit.unused
121
+ def forward(self, xs, ilens, decoding_chunk_size=None, num_decoding_left_chunks=None):
122
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[List[int]], Optional[Tensor]]
123
+ """Encoder forward
124
+
125
+ :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
126
+ :param torch.Tensor ilens: batch of lengths of input sequences (B)
127
+ :return: batch of hidden state sequences (B, Tmax, eprojs)
128
+ :rtype: torch.Tensor
129
+ """
130
+
131
+ if decoding_chunk_size is not None and num_decoding_left_chunks is not None:
132
+ for layer in self.enc:
133
+ if hasattr(layer, "chunk_size"):
134
+ layer.chunk_size = decoding_chunk_size
135
+ if hasattr(layer, "left_chunks"):
136
+ layer.left_chunks = num_decoding_left_chunks
137
+ if hasattr(layer, "transformer_dynamic_chunks"):
138
+ layer.transformer_dynamic_chunks = False
139
+
140
+ assert (len(xs.shape)) == 3
141
+ T = xs.size(1)
142
+ masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T)
143
+ if self.global_cmvn is not None:
144
+ xs = self.global_cmvn(xs)
145
+ for module in self.enc:
146
+ xs, ilens, masks = module(xs, ilens, masks)
147
+ return xs, masks
148
+
149
+ @torch.jit.export
150
+ def infer(self, xs_pad, buffer, buffer_index, buffer_out):
151
+ if self.global_cmvn is not None:
152
+ xs_pad = self.global_cmvn(xs_pad)
153
+ for module in self.enc:
154
+ xs_pad, buffer, buffer_index, buffer_out = module.infer(
155
+ xs_pad, buffer, buffer_index, buffer_out
156
+ )
157
+ return xs_pad, buffer, buffer_index, buffer_out
158
+
159
+ @torch.jit.export
160
+ def infer_hidden(self, xs_pad, buffer, buffer_index, buffer_out, hidden_out):
161
+ if self.global_cmvn is not None:
162
+ xs_pad = self.global_cmvn(xs_pad)
163
+ for module in self.enc:
164
+ xs_pad, buffer, buffer_index, buffer_out, hidden_out = module.infer_hidden(
165
+ xs_pad, buffer, buffer_index, buffer_out, hidden_out
166
+ )
167
+ return xs_pad, buffer, buffer_index, buffer_out, hidden_out
168
+
169
+ @torch.jit.ignore(drop=True)
170
+ def get_extra_loss(self) -> Dict[str, torch.Tensor]:
171
+ return None
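A hedged end-to-end sketch of wiring `whaleEncoder` from dictionaries. The `overview_conf`/`para_conf` values are illustrative (a 4x subsampling front-end plus a small transformer), chosen only so the dimension checks in `__init__` pass; they are not the shipped VITA audio-encoder configuration.

```python
import torch

from vita.model.multimodal_encoder.whale.module.encoder.encoder import whaleEncoder

overview_conf = {
    "encoder-layer-config": "subsampling-transformer",
    "encoder-input-dim": 80,
    "encoder-output-dim": 256,
}
para_conf = {
    "subsampling": {"subsampling-rate": 4, "subsampling-input-dim": 80,
                    "subsampling-output-dim": 256, "subsampling-dropout-rate": 0.1},
    "transformer": {"transformer-input-dim": 256, "transformer-output-dim": 256,
                    "transformer-attention-dim": 256, "transformer-num-blocks": 2,
                    "transformer-input-layer": "none"},
}
enc = whaleEncoder(80, overview_conf=overview_conf, para_conf=para_conf)

xs = torch.randn(2, 200, 80)       # (batch, frames, fbank_dim)
ilens = torch.tensor([200, 160])
out, mask = enc(xs, ilens)         # out: (2, 49, 256) after the 4x subsampling
print(out.shape, mask.shape)
```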
vita/model/multimodal_encoder/whale/module/layer/attention.py ADDED
@@ -0,0 +1,571 @@
1
+ import math
2
+ import pdb
3
+
4
+ import numpy
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class PositionalEncoding(torch.nn.Module):
10
+ """Positional encoding.
11
+ :param int d_model: embedding dim
12
+ :param float dropout_rate: dropout rate
13
+ :param int max_len: maximum input length
14
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
15
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
16
+ """
17
+
18
+ def __init__(
19
+ self, d_model: int, dropout_rate: float, max_len: int = 1500, reverse: bool = False
20
+ ):
21
+ """Construct an PositionalEncoding object."""
22
+ super().__init__()
23
+ self.d_model = d_model
24
+ self.xscale = math.sqrt(self.d_model)
25
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
26
+ self.max_len = max_len
27
+
28
+ self.pe = torch.zeros(self.max_len, self.d_model)
29
+ position = torch.arange(0, self.max_len, dtype=torch.float32).unsqueeze(1)
30
+ div_term = torch.exp(
31
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
32
+ * -(math.log(10000.0) / self.d_model)
33
+ )
34
+ self.pe[:, 0::2] = torch.sin(position * div_term)
35
+ self.pe[:, 1::2] = torch.cos(position * div_term)
36
+ self.pe = self.pe.unsqueeze(0)
37
+
38
+ def forward(self, x: torch.Tensor, offset: int = 0):
39
+ """Add positional encoding.
40
+ Args:
41
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
42
+ offset (int): position offset
43
+ Returns:
44
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
45
+ torch.Tensor: for compatibility to RelPositionalEncoding
46
+ """
47
+ assert offset + x.size(1) < self.max_len
48
+ self.pe = self.pe.to(x.device)
49
+ pos_emb = self.pe[:, offset : offset + x.size(1)]
50
+ x = x * self.xscale + pos_emb
51
+ return self.dropout(x), self.dropout(pos_emb)
52
+
53
+ def position_encoding(self, offset: int, size: int):
54
+ """For getting encoding in a streaming fashion
55
+ Attention!!!!!
56
+ we apply dropout only once at the whole utterance level in a
+ non-streaming way, but will call this function several times with
58
+ increasing input size in a streaming scenario, so the dropout will
59
+ be applied several times.
60
+ Args:
61
+ offset (int): start offset
62
+ size (int): required size of position encoding
63
+ Returns:
64
+ torch.Tensor: Corresponding encoding
65
+ """
66
+ assert offset + size < self.max_len
67
+ return self.dropout(self.pe[:, offset : offset + size])
68
+
69
+
70
+ class RelPositionalEncoding(PositionalEncoding):
71
+ """Relative positional encoding module.
72
+ See : Appendix B in https://arxiv.org/abs/1901.02860
73
+ Args:
74
+ d_model (int): Embedding dimension.
75
+ dropout_rate (float): Dropout rate.
76
+ max_len (int): Maximum input length.
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ d_model: int,
82
+ dropout_rate: float,
83
+ chunk_size: int,
84
+ left_chunks: int,
85
+ max_len: int = 5000,
86
+ ):
87
+ """Initialize class."""
88
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
89
+ self.chunk_size = chunk_size
90
+ self.left_chunks = left_chunks
91
+ self.full_chunk_size = (self.left_chunks + 1) * self.chunk_size
92
+
93
+ self.div_term = torch.exp(
94
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
95
+ * -(math.log(10000.0) / self.d_model)
96
+ )
97
+ self.max_len = self.chunk_size * (max_len // self.chunk_size) - self.full_chunk_size
98
+
99
+ @torch.jit.export
100
+ def forward(self, x: torch.Tensor, offset: int = 0):
101
+ """Compute positional encoding.
102
+ Args:
103
+ x (torch.Tensor): Input tensor (batch, time, `*`).
104
+ Returns:
105
+ torch.Tensor: Encoded tensor (batch, time, `*`).
106
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
107
+ """
108
+ self.pe = self.pe.to(x.device)
109
+ x = x * self.xscale
110
+ pos_emb = self.pe[:, offset : offset + x.size(1)]
111
+ return self.dropout(x), self.dropout(pos_emb)
112
+
113
+ @torch.jit.export
114
+ def infer(self, xs, pe_index):
115
+ # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
116
+ pe_index = pe_index % self.max_len
117
+ xs = xs * self.xscale
118
+
119
+ pe = torch.zeros(self.full_chunk_size, self.d_model)
120
+ position = torch.arange(
121
+ pe_index, pe_index + self.full_chunk_size, dtype=torch.float32
122
+ ).unsqueeze(1)
123
+ pe[:, 0::2] = torch.sin(position * self.div_term)
124
+ pe[:, 1::2] = torch.cos(position * self.div_term)
125
+ pos_emb = pe.unsqueeze(0)
126
+
127
+ pe_index = pe_index + self.chunk_size
128
+ return xs, pos_emb, pe_index
129
+
130
+
131
+ class PositionwiseFeedForward(torch.nn.Module):
132
+ """Positionwise feed forward layer.
133
+ :param int idim: input dimension
134
+ :param int hidden_units: number of hidden units
135
+ :param float dropout_rate: dropout rate
136
+ """
137
+
138
+ def __init__(self, idim, hidden_units, dropout_rate):
139
+ """Construct an PositionwiseFeedForward object."""
140
+ super(PositionwiseFeedForward, self).__init__()
141
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
142
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
143
+ self.dropout = torch.nn.Dropout(dropout_rate)
144
+
145
+ def forward(self, x):
146
+ """Forward funciton."""
147
+ return self.w_2(self.dropout(torch.relu(self.w_1(x))))
148
+
149
+ @torch.jit.export
150
+ def infer(self, xs, buffer, buffer_index, buffer_out):
151
+ # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
152
+ return self.w_2(torch.relu(self.w_1(xs))), buffer, buffer_index, buffer_out
153
+
154
+
155
+ class MultiLayeredConv1d(torch.nn.Module):
156
+ """Multi-layered conv1d for Transformer block.
157
+
158
+ This is a module of multi-layered conv1d designed
159
+ to replace positionwise feed-forward network
160
+ in Transformer block, which is introduced in
161
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
162
+
163
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
164
+ https://arxiv.org/pdf/1905.09263.pdf
165
+
166
+ """
167
+
168
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
169
+ """Initialize MultiLayeredConv1d module.
170
+
171
+ Args:
172
+ in_chans (int): Number of input channels.
173
+ hidden_chans (int): Number of hidden channels.
174
+ kernel_size (int): Kernel size of conv1d.
175
+ dropout_rate (float): Dropout rate.
176
+
177
+ """
178
+ super(MultiLayeredConv1d, self).__init__()
179
+ self.w_1 = torch.nn.Conv1d(
180
+ in_chans,
181
+ hidden_chans,
182
+ kernel_size,
183
+ stride=1,
184
+ padding=(kernel_size - 1) // 2,
185
+ )
186
+ self.w_2 = torch.nn.Conv1d(
187
+ hidden_chans,
188
+ in_chans,
189
+ kernel_size,
190
+ stride=1,
191
+ padding=(kernel_size - 1) // 2,
192
+ )
193
+ self.dropout = torch.nn.Dropout(dropout_rate)
194
+
195
+ @torch.jit.unused
196
+ def forward(self, x):
197
+ """Calculate forward propagation.
198
+
199
+ Args:
200
+ x (Tensor): Batch of input tensors (B, ..., in_chans).
201
+
202
+ Returns:
203
+ Tensor: Batch of output tensors (B, ..., in_chans).
204
+
205
+ """
206
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
207
+ return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
208
+
209
+
210
+ class Conv1dLinear(torch.nn.Module):
211
+ """Conv1D + Linear for Transformer block.
212
+
213
+ A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
214
+
215
+ """
216
+
217
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
218
+ """Initialize Conv1dLinear module.
219
+
220
+ Args:
221
+ in_chans (int): Number of input channels.
222
+ hidden_chans (int): Number of hidden channels.
223
+ kernel_size (int): Kernel size of conv1d.
224
+ dropout_rate (float): Dropout rate.
225
+
226
+ """
227
+ super(Conv1dLinear, self).__init__()
228
+ self.lorder = kernel_size - 1
229
+ self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
230
+ self.w_1 = torch.nn.Sequential(
231
+ torch.nn.Conv1d(in_chans, in_chans, kernel_size, stride=1, padding=0, groups=in_chans),
232
+ torch.nn.Conv1d(in_chans, hidden_chans, 1, padding=0),
233
+ )
234
+ self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
235
+ self.dropout = torch.nn.Dropout(dropout_rate)
236
+ self.in_chans = in_chans
237
+
238
+ # cnn_buffer = 1, in_chans, self.lorder
239
+ self.buffer_size = 1 * self.in_chans * self.lorder
240
+
241
+ @torch.jit.unused
242
+ def forward(self, x):
243
+ """Calculate forward propagation.
244
+
245
+ Args:
246
+ x (Tensor): Batch of input tensors (B, ..., in_chans).
247
+
248
+ Returns:
249
+ Tensor: Batch of output tensors (B, ..., in_chans).
250
+
251
+ """
252
+ x = torch.relu(self.w_1(self.left_padding(x.transpose(-1, 1)))).transpose(-1, 1)
253
+ return self.w_2(self.dropout(x))
254
+
255
+ @torch.jit.export
256
+ def infer(self, x, buffer, buffer_index, buffer_out):
257
+ # type: (Tensor, Tensor, Tensor, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor]]
258
+ x = x.transpose(-1, 1)
259
+
260
+ cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
261
+ [1, self.in_chans, self.lorder]
262
+ )
263
+ x = torch.cat([cnn_buffer, x], dim=2)
264
+ buffer_out.append(x[:, :, -self.lorder :].reshape(-1))
265
+ buffer_index = buffer_index + self.buffer_size
266
+
267
+ x = self.w_1(x)
268
+ x = torch.relu(x).transpose(-1, 1)
269
+ x = self.w_2(x)
270
+ return x, buffer, buffer_index, buffer_out
271
+
272
+
273
+ class MultiHeadedAttention(nn.Module):
274
+ """Multi-Head Attention layer.
275
+
276
+ :param int n_head: the number of heads
277
+ :param int n_feat: the number of features
278
+ :param float dropout_rate: dropout rate
279
+
280
+ """
281
+
282
+ def __init__(self, n_head, n_feat, dropout_rate, chunk_size, left_chunks, pos_enc_class):
283
+ """Construct an MultiHeadedAttention object."""
284
+ super(MultiHeadedAttention, self).__init__()
285
+ assert n_feat % n_head == 0
286
+ # We assume d_v always equals d_k
287
+ self.d_k = n_feat // n_head
288
+ self.h = n_head
289
+ self.linear_q = nn.Linear(n_feat, n_feat)
290
+ self.linear_k = nn.Linear(n_feat, n_feat)
291
+ self.linear_v = nn.Linear(n_feat, n_feat)
292
+ self.linear_out = nn.Linear(n_feat, n_feat)
293
+ self.dropout = nn.Dropout(p=dropout_rate)
294
+ # self.min_value = float(numpy.finfo(torch.tensor(0, dtype=torch.float16).numpy().dtype).min)
295
+ self.min_value = float(torch.finfo(torch.float16).min)
296
+ # chunk par
297
+ if chunk_size > 0 and left_chunks > 0: # for streaming mode
298
+ self.buffersize = chunk_size * (left_chunks)
299
+ self.left_chunk_size = chunk_size * left_chunks
300
+ else: # for non-streaming mode
301
+ self.buffersize = 1
302
+ self.left_chunk_size = 1
303
+ self.chunk_size = chunk_size
304
+
305
+ # encoding setup
306
+ if pos_enc_class == "rel-enc":
307
+ self.rel_enc = True
308
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
309
+ # these two learnable bias are used in matrix c and matrix d
310
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
311
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
312
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
313
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
314
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
315
+ else:
316
+ self.rel_enc = False
317
+ self.linear_pos = nn.Identity()
318
+ self.pos_bias_u = torch.tensor([0])
319
+ self.pos_bias_v = torch.tensor([0])
320
+
321
+ # buffer
322
+ # key_buffer = 1, self.h, self.buffersize, self.d_k
323
+ self.key_buffer_size = 1 * self.h * self.buffersize * self.d_k
324
+ # value_buffer = 1, self.h, self.buffersize, self.d_k
325
+ self.value_buffer_size = 1 * self.h * self.buffersize * self.d_k
326
+ if self.chunk_size > 0:
327
+ # buffer_mask_size = 1, self.h, self.chunk_size, self.buffersize
328
+ self.buffer_mask_size = 1 * self.h * self.chunk_size * self.buffersize
329
+ # self.buffer_mask = torch.ones([1, self.h, self.chunk_size, self.buffersize], dtype=torch.bool)
330
+ else:
331
+ self.buffer_mask = torch.ones([1, self.h, 1, 1], dtype=torch.bool)
332
+
333
+ @torch.jit.unused
334
+ def rel_shift(self, x, zero_triu: bool = False):
335
+ """Compute relative positinal encoding.
336
+ Args:
337
+ x (torch.Tensor): Input tensor (batch, time, size).
338
+ zero_triu (bool): If true, return the lower triangular part of
339
+ the matrix.
340
+ Returns:
341
+ torch.Tensor: Output tensor.
342
+ """
343
+
344
+ zero_pad = torch.zeros(
345
+ (x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
346
+ )
347
+ x_padded = torch.cat([zero_pad, x], dim=-1)
348
+
349
+ x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
350
+ x = x_padded[:, :, 1:].view_as(x)
351
+
352
+ if zero_triu:
353
+ ones = torch.ones((x.size(2), x.size(3)))
354
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
355
+ return x
356
+
357
+ @torch.jit.export
358
+ def forward(self, query, key, value, mask=None, pos_emb=torch.tensor(1.0)):
359
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], Tensor) -> Tensor
360
+ """Compute 'Scaled Dot Product Attention'.
361
+
362
+ :param torch.Tensor query: (batch, time1, size)
363
+ :param torch.Tensor key: (batch, time2, size)
364
+ :param torch.Tensor value: (batch, time2, size)
365
+ :param torch.Tensor mask: (batch, time1, time2)
366
+ :param torch.nn.Dropout dropout:
367
+ :return torch.Tensor: attended and transformed `value` (batch, time1, d_model)
368
+ weighted by the query dot key attention (batch, head, time1, time2)
369
+ """
370
+ n_batch = query.size(0)
371
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
372
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
373
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
374
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
375
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
376
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
377
+
378
+ if self.rel_enc:
379
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
380
+ n_batch_pos = pos_emb.size(0)
381
+ p = self.linear_pos(pos_emb.to(query.dtype)).view(n_batch_pos, -1, self.h, self.d_k)
382
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
383
+ # (batch, head, time1, d_k)
384
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
385
+ # (batch, head, time1, d_k)
386
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
387
+ # compute attention score
388
+ # first compute matrix a and matrix c
389
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
390
+ # (batch, head, time1, time2)
391
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
392
+ # compute matrix b and matrix d
393
+ # (batch, head, time1, time2)
394
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
395
+ # Remove rel_shift since it is useless in speech recognition,
396
+ # and it requires special attention for streaming.
397
+ # matrix_bd = self.rel_shift(matrix_bd)
398
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
399
+ else:
400
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
401
+ self.d_k
402
+ ) # (batch, head, time1, time2)
403
+
404
+ if mask is not None:
405
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
406
+ scores = scores.masked_fill(mask, self.min_value)
407
+ attn = torch.softmax(scores, dim=-1).masked_fill(
408
+ mask, 0.0
409
+ ) # (batch, head, time1, time2)
410
+ else:
411
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
412
+
413
+ p_attn = self.dropout(attn)
414
+
415
+ x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
416
+ x = (
417
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
418
+ ) # (batch, time1, d_model)
419
+ return self.linear_out(x) # (batch, time1, d_model)
420
+
421
+ @torch.jit.export
422
+ def infer(self, query, key, value, pos_emb, buffer, buffer_index, buffer_out):
423
+ # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor]]
424
+ n_batch = query.size(0)
425
+
426
+ q = (
427
+ self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
428
+ ) # (batch, head, len_q, d_k)
429
+ k = (
430
+ self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
431
+ ) # (batch, head, len_k, d_k)
432
+ v = (
433
+ self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
434
+ ) # (batch, head, len_v, d_k)
435
+
436
+ key_value_buffer = buffer[
437
+ buffer_index : buffer_index + self.key_buffer_size + self.value_buffer_size
438
+ ].reshape([1, self.h, self.buffersize * 2, self.d_k])
439
+ key_buffer = torch.cat([key_value_buffer[:, :, : self.buffersize, :], k], dim=2)
440
+ value_buffer = torch.cat([key_value_buffer[:, :, self.buffersize :, :], v], dim=2)
441
+ buffer_out.append(
442
+ torch.cat(
443
+ [key_buffer[:, :, self.chunk_size :, :], value_buffer[:, :, self.chunk_size :, :]],
444
+ dim=2,
445
+ ).reshape(-1)
446
+ )
447
+ buffer_index = buffer_index + self.key_buffer_size + self.value_buffer_size
448
+
449
+ if self.rel_enc:
450
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
451
+ n_batch_pos = pos_emb.size(0)
452
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
453
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
454
+ # (batch, head, time1, d_k)
455
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
456
+ # (batch, head, time1, d_k)
457
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
458
+ # compute attention score
459
+ # first compute matrix a and matrix c
460
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
461
+ # (batch, head, time1, time2)
462
+ matrix_ac = torch.matmul(q_with_bias_u, key_buffer.transpose(-2, -1))
463
+ # compute matrix b and matrix d
464
+ # (batch, head, time1, time2)
465
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
466
+ # Remove rel_shift since it is useless in speech recognition,
467
+ # and it requires special attention for streaming.
468
+ # matrix_bd = self.rel_shift(matrix_bd)
469
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
470
+ else:
471
+ scores = torch.matmul(q, key_buffer.transpose(-2, -1)) / math.sqrt(
472
+ self.d_k
473
+ ) # (batch, head, len_q, buffersize)
474
+
475
+ attn = torch.softmax(scores, dim=-1)
476
+
477
+ x = torch.matmul(attn, value_buffer) # (batch, head, len_q, d_k)
478
+ x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
479
+ return self.linear_out(x), buffer, buffer_index, buffer_out # (batch, time1, d_model)
480
+
481
+ @torch.jit.export
482
+ def infer_mask(self, query, key, value, mask, buffer, buffer_index, buffer_out, is_static):
483
+ n_batch = query.size(0)
484
+
485
+ q = (
486
+ self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
487
+ ) # (batch, head, len_q, d_k)
488
+ k = (
489
+ self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
490
+ ) # (batch, head, len_k, d_k)
491
+ v = (
492
+ self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
493
+ ) # (batch, head, len_v, d_k)
494
+
495
+ if is_static:
496
+ key_buffer = k
497
+ value_buffer = v
498
+ else:
499
+ key_value_buffer = buffer[
500
+ buffer_index : buffer_index + self.key_buffer_size + self.value_buffer_size
501
+ ].reshape([1, self.h, self.buffersize * 2, self.d_k])
502
+ key_buffer = torch.cat([key_value_buffer[:, :, : self.buffersize, :], k], dim=2)
503
+ value_buffer = torch.cat([key_value_buffer[:, :, self.buffersize :, :], v], dim=2)
504
+ buffer_out.append(
505
+ torch.cat(
506
+ [
507
+ key_buffer[:, :, self.chunk_size :, :],
508
+ value_buffer[:, :, self.chunk_size :, :],
509
+ ],
510
+ dim=2,
511
+ ).reshape(-1)
512
+ )
513
+ buffer_index = buffer_index + self.key_buffer_size + self.value_buffer_size
514
+
515
+ scores = torch.matmul(q, key_buffer.transpose(-2, -1)) / math.sqrt(
516
+ self.d_k
517
+ ) # (batch, head, len_q, buffersize)
518
+
519
+ if mask is not None:
520
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
521
+ scores = scores.masked_fill(mask, self.min_value)
522
+ attn = torch.softmax(scores, dim=-1).masked_fill(
523
+ mask, 0.0
524
+ ) # (batch, head, time1, time2)
525
+ else:
526
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
527
+
528
+ x = torch.matmul(attn, value_buffer) # (batch, head, len_q, d_k)
529
+ x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
530
+ return self.linear_out(x), buffer_index, buffer_out # (batch, time1, d_model)
531
+
532
+
533
+ class SoftAttention(nn.Module):
534
+ def __init__(self, in_dim, hidden_dim):
535
+ super(SoftAttention, self).__init__()
536
+ self.q = torch.nn.Parameter(torch.rand([hidden_dim]), requires_grad=True)
537
+ self.wb = nn.Linear(in_dim, hidden_dim)
538
+ self.min_value = float(numpy.finfo(torch.tensor(0, dtype=torch.float32).numpy().dtype).min)
539
+ # buffer
540
+ self.window_size = 50
541
+ self.buffer_in = torch.zeros([1, self.window_size, in_dim], dtype=torch.float32)
542
+ self.buffer = torch.zeros([1, self.window_size], dtype=torch.float32)
543
+ self.buffer[:, :] = float(
544
+ numpy.finfo(torch.tensor(0, dtype=torch.float32).numpy().dtype).min
545
+ )
546
+
547
+ @torch.jit.unused
548
+ def forward(self, x, mask=None):
549
+ hidden = torch.tanh(self.wb(x)) # B T D
550
+ hidden = torch.einsum("btd,d->bt", hidden, self.q)
551
+ score = torch.softmax(hidden, dim=-1) # B T
552
+ if mask is not None:
553
+ score = score.masked_fill(mask, 0.0)
554
+ output = torch.einsum("bt,btd->bd", score, x)
555
+ return output
556
+
557
+ @torch.jit.export
558
+ def infer(self, x):
559
+ # type: (Tensor) -> Tensor
560
+ hidden = torch.tanh(self.wb(x)) # B T D
561
+ hidden = torch.einsum("btd,d->bt", hidden, self.q)
562
+ size = hidden.shape[1]
563
+ output = torch.zeros([size, x.shape[-1]])
564
+ for i in range(size):
565
+ self.buffer = torch.cat([self.buffer, hidden[:, i : i + 1]], dim=-1)
566
+ self.buffer = self.buffer[:, 1:]
567
+ score = torch.softmax(self.buffer, dim=-1) # B T
568
+ self.buffer_in = torch.cat([self.buffer_in, x[:, i : i + 1, :]], dim=1)
569
+ self.buffer_in = self.buffer_in[:, 1:]
570
+ output[i : i + 1] = torch.einsum("bt,btd->bd", score, self.buffer_in)
571
+ return output
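Note on the streaming path: every `infer` method above follows the same bookkeeping pattern — a flat `buffer` tensor is sliced at `buffer_index`, reshaped into that layer's cache, concatenated in front of the incoming chunk, and the refreshed cache is appended to `buffer_out` while `buffer_index` advances. A minimal self-contained sketch of the rolling key/value cache used by `MultiHeadedAttention.infer` (head count, dimensions and chunk sizes below are illustrative, not the model's actual configuration):

import torch

h, d_k = 4, 64                                 # heads and per-head dim (illustrative)
chunk_size, left_chunks = 8, 2
buffersize = chunk_size * left_chunks          # cached frames per attention layer

# [keys | values] packed together, mirroring the layout read back in infer()
kv_cache = torch.zeros(1, h, 2 * buffersize, d_k)

def step(new_k, new_v, kv_cache):
    # prepend the cached frames to the incoming chunk
    k = torch.cat([kv_cache[:, :, :buffersize], new_k], dim=2)
    v = torch.cat([kv_cache[:, :, buffersize:], new_v], dim=2)
    # drop the oldest chunk_size frames and write the rest back
    new_cache = torch.cat([k[:, :, chunk_size:], v[:, :, chunk_size:]], dim=2)
    return k, v, new_cache

k, v, kv_cache = step(torch.randn(1, h, chunk_size, d_k),
                      torch.randn(1, h, chunk_size, d_k), kv_cache)
print(k.shape)   # torch.Size([1, 4, 24, 64]): buffersize + chunk_size visible frames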
vita/model/multimodal_encoder/whale/module/layer/conv1d.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class Conv1dLayer(nn.Module):
7
+ def __init__(
8
+ self,
9
+ input_dim,
10
+ output_dim,
11
+ kernel_size,
12
+ stride,
13
+ causal_conv,
14
+ dilation,
15
+ dropout_rate,
16
+ residual=True,
17
+ ):
18
+ super(Conv1dLayer, self).__init__()
19
+ self.input_dim = input_dim
20
+ self.output_dim = output_dim
21
+ self.kernel_size = kernel_size
22
+ self.stride = stride
23
+ self.dilation = dilation
24
+ self.causal_conv = causal_conv
25
+ if causal_conv:
26
+ self.lorder = (kernel_size - 1) * self.dilation
27
+ self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
28
+ else:
29
+ assert (kernel_size - 1) % 2 == 0
30
+ self.lorder = ((kernel_size - 1) // 2) * self.dilation
31
+ self.left_padding = nn.ConstantPad1d((self.lorder, self.lorder), 0.0)
32
+ self.conv1d = nn.Conv1d(
33
+ self.input_dim, self.output_dim, self.kernel_size, self.stride, 0, self.dilation
34
+ )
35
+ self.bn = nn.BatchNorm1d(self.output_dim, eps=1e-3, momentum=0.99)
36
+ self.relu = nn.ReLU()
37
+ self.dropout = nn.Dropout(p=dropout_rate)
38
+ self.residual = residual
39
+ if self.input_dim != self.output_dim:
40
+ self.residual = False
41
+
42
+ # buffer = 1, self.input_dim, self.lorder
43
+ self.lorder = (kernel_size - 1) * self.dilation - (self.stride - 1)
44
+ self.buffer_size = 1 * self.input_dim * self.lorder
45
+ self.x_data_chache_size = self.lorder
46
+ self.x_data_buffer_size = self.input_dim * self.x_data_chache_size
47
+
48
+ @torch.jit.unused
49
+ def forward(self, x):
50
+ x_data = x
51
+ x = self.left_padding(x)
52
+ x = self.conv1d(x)
53
+ x = self.bn(x)
54
+ if self.stride == 1 and self.residual:
55
+ x = self.relu(x + x_data)
56
+ else:
57
+ x = self.relu(x)
58
+ x = self.dropout(x)
59
+ return x
60
+
61
+ @torch.jit.export
62
+ def infer(self, x, buffer, buffer_index, buffer_out):
63
+ # type: (Tensor, Tensor, Tensor, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor]]
64
+ x_data = x.clone()
65
+
66
+ cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
67
+ [1, self.input_dim, self.lorder]
68
+ )
69
+ x = torch.cat([cnn_buffer, x], dim=2)
70
+ buffer_out.append(x[:, :, -self.lorder :].reshape(-1))
71
+ buffer_index = buffer_index + self.buffer_size
72
+
73
+ x = self.conv1d(x)
74
+ x = self.bn(x)
75
+
76
+ if self.stride == 1 and self.residual:
77
+ x_data_cnn_buffer = buffer[
78
+ buffer_index : buffer_index + self.x_data_buffer_size
79
+ ].reshape([1, self.input_dim, self.x_data_chache_size])
80
+ x_data = torch.cat([x_data_cnn_buffer, x_data], dim=2)
81
+ buffer_out.append(x_data[:, :, -self.x_data_chache_size :].reshape(-1))
82
+ buffer_index = buffer_index + self.x_data_buffer_size
83
+ x_data = x_data[:, :, : -self.x_data_chache_size]
84
+ x = self.relu(x + x_data)
85
+ else:
86
+ x = self.relu(x)
87
+
88
+ return x, buffer, buffer_index, buffer_out
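The causal branch above buys streaming compatibility purely through padding: with `lorder = (kernel_size - 1) * dilation` zeros prepended on the left, a stride-1 convolution sees only past context and keeps the sequence length unchanged. A small self-contained sketch of that padding arithmetic (channel count and lengths are made up for illustration):

import torch
import torch.nn as nn

kernel_size, dilation = 3, 2
lorder = (kernel_size - 1) * dilation            # 4 frames of left context
pad = nn.ConstantPad1d((lorder, 0), 0.0)         # causal: pad the left side only
conv = nn.Conv1d(8, 8, kernel_size, stride=1, padding=0, dilation=dilation)

x = torch.randn(1, 8, 20)                        # (batch, channels, time)
y = conv(pad(x))
print(y.shape)                                   # torch.Size([1, 8, 20]) -- no lookahead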
vita/model/multimodal_encoder/whale/module/layer/dtcblock.py ADDED
@@ -0,0 +1,95 @@
1
+ import math
2
+ import pdb
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class DTCBlock(nn.Module):
11
+ def __init__(
12
+ self, input_dim, output_dim, kernel_size, stride, causal_conv, dilation, dropout_rate
13
+ ):
14
+ super(DTCBlock, self).__init__()
15
+ self.input_dim = input_dim
16
+ self.output_dim = output_dim
17
+ self.kernel_size = kernel_size
18
+ self.stride = stride
19
+ self.dilation = dilation
20
+ if causal_conv:
21
+ self.padding = 0
22
+ self.lorder = (kernel_size - 1) * self.dilation
23
+ self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
24
+ else:
25
+ assert (kernel_size - 1) % 2 == 0
26
+ self.padding = ((kernel_size - 1) // 2) * self.dilation
27
+ self.lorder = 0
28
+ self.causal_conv = causal_conv
29
+ self.depthwise_conv = nn.Conv1d(
30
+ self.input_dim,
31
+ self.input_dim,
32
+ self.kernel_size,
33
+ self.stride,
34
+ self.padding,
35
+ self.dilation,
36
+ groups=self.input_dim,
37
+ )
38
+ self.point_conv_1 = nn.Conv1d(self.input_dim, self.input_dim, 1, 1, self.padding)
39
+ self.point_conv_2 = nn.Conv1d(self.input_dim, self.input_dim, 1, 1, self.padding)
40
+ self.bn_1 = nn.BatchNorm1d(self.input_dim)
41
+ self.bn_2 = nn.BatchNorm1d(self.input_dim)
42
+ self.bn_3 = nn.BatchNorm1d(self.input_dim)
43
+ self.dropout = nn.Dropout(p=dropout_rate)
44
+
45
+ # buffer = 1, self.input_dim, self.lorder
46
+ self.lorder = (kernel_size - 1) * self.dilation - (self.stride - 1)
47
+ self.buffer_size = 1 * self.input_dim * self.lorder
48
+
49
+ @torch.jit.unused
50
+ def forward(self, x):
51
+ x_in = x
52
+ x_data = x_in.transpose(1, 2)
53
+ if self.causal_conv:
54
+ x_data_pad = self.left_padding(x_data)
55
+ else:
56
+ x_data_pad = x_data
57
+ x_depth = self.depthwise_conv(x_data_pad)
58
+ x_bn_1 = self.bn_1(x_depth)
59
+ x_point_1 = self.point_conv_1(x_bn_1)
60
+ x_bn_2 = self.bn_2(x_point_1)
61
+ x_relu_2 = torch.relu(x_bn_2)
62
+ x_point_2 = self.point_conv_2(x_relu_2)
63
+ x_bn_3 = self.bn_3(x_point_2)
64
+ x_bn_3 = x_bn_3.transpose(1, 2)
65
+ if self.stride == 1:
66
+ x_relu_3 = torch.relu(x_bn_3 + x_in)
67
+ else:
68
+ x_relu_3 = torch.relu(x_bn_3)
69
+ x_drop = self.dropout(x_relu_3)
70
+ return x_drop
71
+
72
+ @torch.jit.export
73
+ def infer(self, x, buffer, buffer_index, buffer_out):
74
+ # type: (Tensor, Tensor, Tensor, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor]]
75
+ x_in = x
76
+ x = x_in.transpose(1, 2)
77
+ cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
78
+ [1, self.input_dim, self.lorder]
79
+ )
80
+ x = torch.cat([cnn_buffer, x], dim=2)
81
+ buffer_out.append(x[:, :, -self.lorder :].reshape(-1))
82
+ buffer_index = buffer_index + self.buffer_size
83
+ x = self.depthwise_conv(x)
84
+ x = self.bn_1(x)
85
+ x = self.point_conv_1(x)
86
+ x = self.bn_2(x)
87
+ x = torch.relu(x)
88
+ x = self.point_conv_2(x)
89
+ x = self.bn_3(x)
90
+ x = x.transpose(1, 2)
91
+ if self.stride == 1:
92
+ x = torch.relu(x + x_in)
93
+ else:
94
+ x = torch.relu(x)
95
+ return x, buffer, buffer_index, buffer_out
vita/model/multimodal_encoder/whale/module/layer/fsmn.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class FsmnLayer(nn.Module):
7
+ def __init__(
8
+ self,
9
+ input_dim,
10
+ out_dim,
11
+ hidden_dim,
12
+ left_frame=1,
13
+ right_frame=1,
14
+ left_dilation=1,
15
+ right_dilation=1,
16
+ ):
17
+ super(FsmnLayer, self).__init__()
18
+ self.input_dim = input_dim
19
+ self.out_dim = out_dim
20
+ self.hidden_dim = hidden_dim
21
+ self.left_frame = left_frame
22
+ self.right_frame = right_frame
23
+ self.left_dilation = left_dilation
24
+ self.right_dilation = right_dilation
25
+ self.conv_in = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
26
+ if left_frame > 0:
27
+ self.pad_left = nn.ConstantPad1d([left_dilation * left_frame, 0], 0.0)
28
+ self.conv_left = nn.Conv1d(
29
+ hidden_dim,
30
+ hidden_dim,
31
+ kernel_size=left_frame + 1,
32
+ dilation=left_dilation,
33
+ bias=False,
34
+ groups=hidden_dim,
35
+ )
36
+ if right_frame > 0:
37
+ self.pad_right = nn.ConstantPad1d([-right_dilation, right_dilation * right_frame], 0.0)
38
+ self.conv_right = nn.Conv1d(
39
+ hidden_dim,
40
+ hidden_dim,
41
+ kernel_size=right_frame,
42
+ dilation=right_dilation,
43
+ bias=False,
44
+ groups=hidden_dim,
45
+ )
46
+ self.conv_out = nn.Conv1d(hidden_dim, out_dim, kernel_size=1)
47
+
48
+ # cache = 1, self.hidden_dim, left_frame * left_dilation + right_frame * right_dilation
49
+ self.cache_size = left_frame * left_dilation + right_frame * right_dilation
50
+ self.buffer_size = self.hidden_dim * self.cache_size
51
+ self.p_in_raw_chache_size = self.right_frame * self.right_dilation
52
+ self.p_in_raw_buffer_size = self.hidden_dim * self.p_in_raw_chache_size
53
+ self.hidden_chache_size = self.right_frame * self.right_dilation
54
+ self.hidden_buffer_size = self.hidden_dim * self.hidden_chache_size
55
+
56
+ @torch.jit.unused
57
+ def forward(self, x, hidden=None):
58
+ x_data = x.transpose(1, 2)
59
+ p_in = self.conv_in(x_data)
60
+ if self.left_frame > 0:
61
+ p_left = self.pad_left(p_in)
62
+ p_left = self.conv_left(p_left)
63
+ else:
64
+ p_left = 0
65
+ if self.right_frame > 0:
66
+ p_right = self.pad_right(p_in)
67
+ p_right = self.conv_right(p_right)
68
+ else:
69
+ p_right = 0
70
+ p_out = p_in + p_right + p_left
71
+ if hidden is not None:
72
+ p_out = hidden + p_out
73
+ out = F.relu(self.conv_out(p_out))
74
+ out = out.transpose(1, 2)
75
+ return out, p_out
76
+
77
+ @torch.jit.export
78
+ def infer(self, x, buffer, buffer_index, buffer_out, hidden=None):
79
+ # type: (Tensor, Tensor, Tensor, List[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor], Tensor]
80
+ p_in_raw = self.conv_in(x)
81
+
82
+ cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
83
+ [1, self.hidden_dim, self.cache_size]
84
+ )
85
+ p_in = torch.cat([cnn_buffer, p_in_raw], dim=2)
86
+ # buffer[buffer_index: buffer_index + self.buffer_size] = p_in[:, :, -self.cache_size:].reshape(-1)
87
+ buffer_out.append(p_in[:, :, -self.cache_size :].reshape(-1))
88
+ buffer_index = buffer_index + self.buffer_size
89
+
90
+ if self.left_frame > 0:
91
+ if self.right_frame > 0:
92
+ p_left = p_in[:, :, : -self.right_frame * self.right_dilation]
93
+ else:
94
+ p_left = p_in[:, :]
95
+ p_left_out = self.conv_left(p_left)
96
+ else:
97
+ p_left_out = torch.tensor([0])
98
+ if self.right_frame > 0:
99
+ p_right = p_in[:, :, self.left_frame * self.left_dilation + 1 :]
100
+ p_right_out = self.conv_right(p_right)
101
+ else:
102
+ p_right_out = torch.tensor([0])
103
+
104
+ if self.right_frame > 0:
105
+ p_in_raw_cnn_buffer = buffer[
106
+ buffer_index : buffer_index + self.p_in_raw_buffer_size
107
+ ].reshape([1, self.hidden_dim, self.p_in_raw_chache_size])
108
+ p_in_raw = torch.cat([p_in_raw_cnn_buffer, p_in_raw], dim=2)
109
+ # buffer[buffer_index: buffer_index + self.p_in_raw_buffer_size] = p_in_raw[:, :, -self.p_in_raw_chache_size:].reshape(-1)
110
+ buffer_out.append(p_in_raw[:, :, -self.p_in_raw_chache_size :].reshape(-1))
111
+ buffer_index = buffer_index + self.p_in_raw_buffer_size
112
+ p_in_raw = p_in_raw[:, :, : -self.p_in_raw_chache_size]
113
+ p_out = p_in_raw + p_left_out + p_right_out
114
+
115
+ if hidden is not None:
116
+ if self.right_frame > 0:
117
+ hidden_cnn_buffer = buffer[
118
+ buffer_index : buffer_index + self.hidden_buffer_size
119
+ ].reshape([1, self.hidden_dim, self.hidden_chache_size])
120
+ hidden = torch.cat([hidden_cnn_buffer, hidden], dim=2)
121
+ # buffer[buffer_index: buffer_index + self.hidden_buffer_size] = hidden[:, :, -self.hidden_chache_size:].reshape(-1)
122
+ buffer_out.append(hidden[:, :, -self.hidden_chache_size :].reshape(-1))
123
+ buffer_index = buffer_index + self.hidden_buffer_size
124
+ hidden = hidden[:, :, : -self.hidden_chache_size]
125
+ p_out = hidden + p_out
126
+
127
+ out = F.relu(self.conv_out(p_out))
128
+
129
+ return out, buffer, buffer_index, buffer_out, p_out
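As a quick sanity check of the FSMN memory block above, the sketch below runs one layer on a random feature sequence; the second return value is the memory (`p_out`) that is passed as `hidden` into the next layer. It assumes the file is importable under the dotted path matching its location, which depends on the package layout, and the sizes are illustrative:

import torch
from vita.model.multimodal_encoder.whale.module.layer.fsmn import FsmnLayer

layer = FsmnLayer(input_dim=256, out_dim=256, hidden_dim=128,
                  left_frame=3, right_frame=1)
x = torch.randn(2, 50, 256)            # (batch, time, feature)
out, memory = layer(x)                 # memory is fed as `hidden` to the next layer
print(out.shape, memory.shape)         # torch.Size([2, 50, 256]) torch.Size([2, 128, 50])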
vita/model/multimodal_encoder/whale/utils.py ADDED
@@ -0,0 +1,146 @@
1
+ import argparse
2
+ import importlib
3
+ import json
4
+ import os
5
+ from distutils.util import strtobool as dist_strtobool
6
+
7
+ import torch
8
+ import yaml
9
+
10
+ IGNORE_ID = -1
11
+
12
+
13
+ def assign_args_from_yaml(args, yaml_path, prefix_key=None):
14
+ with open(yaml_path) as f:
15
+ ydict = yaml.load(f, Loader=yaml.FullLoader)
16
+ if prefix_key is not None:
17
+ ydict = ydict[prefix_key]
18
+ for k, v in ydict.items():
19
+ k_args = k.replace("-", "_")
20
+ if hasattr(args, k_args):
21
+ setattr(args, k_args, ydict[k])
22
+ return args
23
+
24
+
25
+ def get_model_conf(model_path):
26
+ model_conf = os.path.dirname(model_path) + "/model.json"
27
+ with open(model_conf, "rb") as f:
28
+ print("reading a config file from " + model_conf)
29
+ confs = json.load(f)
30
+ # for asr, tts, mt
31
+ idim, odim, args = confs
32
+ return argparse.Namespace(**args)
33
+
34
+
35
+ def strtobool(x):
36
+ return bool(dist_strtobool(x))
37
+
38
+
39
+ def dynamic_import(import_path, alias=dict()):
40
+ """dynamic import module and class
41
+
42
+ :param str import_path: syntax 'module_name:class_name'
43
+ e.g., 'espnet.transform.add_deltas:AddDeltas'
44
+ :param dict alias: shortcut for registered class
45
+ :return: imported class
46
+ """
47
+ if import_path not in alias and ":" not in import_path:
48
+ raise ValueError(
49
+ "import_path should be one of {} or "
50
+ 'include ":", e.g. "espnet.transform.add_deltas:AddDeltas" : '
51
+ "{}".format(set(alias), import_path)
52
+ )
53
+ if ":" not in import_path:
54
+ import_path = alias[import_path]
55
+
56
+ module_name, objname = import_path.split(":")
57
+ m = importlib.import_module(module_name)
58
+ return getattr(m, objname)
59
+
60
+
61
+ def set_deterministic_pytorch(args):
62
+ # seed setting
63
+ torch.manual_seed(args.seed)
64
+
65
+ torch.backends.cudnn.deterministic = False
66
+ torch.backends.cudnn.benchmark = False
67
+
68
+
69
+ def pad_list(xs, pad_value):
70
+ n_batch = len(xs)
71
+ max_len = max(x.size(0) for x in xs)
72
+ pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
73
+ for i in range(n_batch):
74
+ pad[i, : xs[i].size(0)] = xs[i]
75
+ return pad
76
+
77
+
78
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
79
+ batch_size = lengths.size(0)
80
+ max_len = max_len if max_len > 0 else lengths.max().item()
81
+ seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
82
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
83
+ seq_length_expand = lengths.unsqueeze(-1)
84
+ mask = seq_range_expand >= seq_length_expand
85
+ return mask
86
+
87
+
88
+ def subsequent_chunk_mask(
89
+ size: int,
90
+ ck_size: int,
91
+ num_l_cks: int = -1,
92
+ device: torch.device = torch.device("cpu"),
93
+ ) -> torch.Tensor:
94
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
95
+ for i in range(size):
96
+ if num_l_cks < 0:
97
+ start = 0
98
+ else:
99
+ start = max((i // ck_size - num_l_cks) * ck_size, 0)
100
+ ending = min((i // ck_size + 1) * ck_size, size)
101
+ ret[i, start:ending] = True
102
+ return ret
103
+
104
+
105
+ def add_optional_chunk_mask(
106
+ xs: torch.Tensor,
107
+ masks: torch.Tensor,
108
+ use_dynamic_chunk: bool,
109
+ use_dynamic_left_chunk: bool,
110
+ decoding_chunk_size: int,
111
+ static_chunk_size: int,
112
+ num_decoding_left_chunks: int,
113
+ ):
114
+ if use_dynamic_chunk:
115
+ max_len = xs.size(1)
116
+ if decoding_chunk_size < 0:
117
+ chunk_size = max_len
118
+ num_l_cks = -1
119
+ elif decoding_chunk_size > 0:
120
+ chunk_size = decoding_chunk_size
121
+ num_l_cks = num_decoding_left_chunks
122
+ else:
123
+ chunk_size = torch.randint(1, max_len, (1,)).item()
124
+ num_l_cks = -1
125
+ if chunk_size > max_len // 2:
126
+ chunk_size = max_len
127
+ else:
128
+ chunk_size = chunk_size % 25 + 1
129
+ if use_dynamic_left_chunk:
130
+ max_left_chunks = (max_len - 1) // chunk_size
131
+ num_l_cks = torch.randint(0, max_left_chunks, (1,)).item()
132
+ ck_masks = subsequent_chunk_mask(
133
+ xs.size(1), chunk_size, num_l_cks, xs.device
134
+ ) # (L, L)
135
+ ck_masks = ck_masks.unsqueeze(0) # (1, L, L)
136
+ ck_masks = masks & ck_masks # (B, L, L)
137
+ elif static_chunk_size > 0:
138
+ num_l_cks = num_decoding_left_chunks
139
+ ck_masks = subsequent_chunk_mask(
140
+ xs.size(1), static_chunk_size, num_l_cks, xs.device
141
+ ) # (L, L)
142
+ ck_masks = ck_masks.unsqueeze(0) # (1, L, L)
143
+ ck_masks = masks & ck_masks # (B, L, L)
144
+ else:
145
+ ck_masks = masks
146
+ return ck_masks
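The masking helpers above define the streaming attention pattern used by the audio encoder: `make_pad_mask` flags padded frames, and `subsequent_chunk_mask` lets each frame attend to its own chunk plus a limited number of left chunks. A short usage sketch, assuming the file is importable as `vita.model.multimodal_encoder.whale.utils`:

import torch
from vita.model.multimodal_encoder.whale.utils import make_pad_mask, subsequent_chunk_mask

print(make_pad_mask(torch.tensor([4, 2])).int())
# tensor([[0, 0, 0, 0],
#         [0, 0, 1, 1]])        <- True (1) marks padded positions

print(subsequent_chunk_mask(6, ck_size=2, num_l_cks=1).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1]])  <- own chunk of 2 frames plus one left chunk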
vita/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,185 @@
1
+ import math
2
+ import re
3
+ from functools import partial
4
+
5
+ from torch import nn
6
+
7
+ from timm.layers.norm_act import LayerNormAct2d
8
+ from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig
9
+ from torchvision.ops.misc import SqueezeExcitation as SELayer
10
+
11
+
12
+ class IdentityMap(nn.Module):
13
+ def __init__(self):
14
+ super().__init__()
15
+
16
+ def forward(self, x, *args, **kwargs):
17
+ return x
18
+
19
+ @property
20
+ def config(self):
21
+ return {"mm_projector_type": "identity"}
22
+
23
+
24
+ class Minigpt(nn.Module):
25
+ def __init__(self, config=None):
26
+ super(Minigpt, self).__init__()
27
+ # c*4 is the input size, and c is the output size for the linear layer
28
+ inc, ouc = config.mm_hidden_size, config.hidden_size
29
+ self.linear = nn.Linear(inc * 4, ouc)
30
+
31
+ def forward(self, x):
32
+ # x is the input tensor with shape [b, num_tokens, c]
33
+ b, num_tokens, c = x.shape
34
+
35
+ # Check if num_tokens is divisible by 4
36
+ if num_tokens % 4 != 0:
37
+ raise ValueError("num_tokens must be divisible by 4")
38
+
39
+ # Reshape x to [b, num_tokens/4, c*4]
40
+ x = x.view(b, num_tokens // 4, c * 4)
41
+
42
+ # Apply the linear transformation
43
+ x = self.linear(x)
44
+ return x
45
+
46
+
47
+ class Vanilla(nn.Module):
48
+ def __init__(self, config=None):
49
+ super(Vanilla, self).__init__()
50
+ # c*4 is the input size, and c is the output size for the linear layer
51
+ inc, ouc = config.mm_hidden_size, config.hidden_size
52
+ self.linear = nn.Linear(inc * 4, ouc)
53
+
54
+ def forward(self, x):
55
+ b, num_tokens, c = x.shape
56
+
57
+ # Check if num_tokens is divisible by 4
58
+ if num_tokens % 4 != 0:
59
+ raise ValueError("num_tokens must be divisible by 4")
60
+
61
+ # First, reshape to [b, num_tokens//4, 4, c]
62
+ x = x.view(b, num_tokens // 4, 4, c)
63
+
64
+ # Then, permute to interleave the tokens
65
+ x = x.permute(0, 1, 3, 2).contiguous()
66
+
67
+ # Finally, reshape to [b, num_tokens//4, c*4] to interleave features of 4 tokens
68
+ x = x.view(b, num_tokens // 4, c * 4)
69
+
70
+ # Apply the linear transformation
71
+ x = self.linear(x)
72
+ return x
73
+
74
+
75
+ class LDPBlock(nn.Module):
76
+ # Lightweight Downsample Projector Block
77
+
78
+ def __init__(self, config=None):
79
+ super().__init__()
80
+
81
+ inc, ouc = config.mm_hidden_size, config.hidden_size
82
+ layer_norm = partial(LayerNormAct2d, act_layer=None)
83
+ se_layer = partial(SELayer, scale_activation=nn.Hardsigmoid)
84
+ self.mlp = nn.Sequential(nn.Identity(), nn.Linear(inc, ouc), nn.GELU(), nn.Linear(ouc, ouc))
85
+ self.mb_block = nn.Sequential(
86
+ nn.Identity(),
87
+ InvertedResidual(
88
+ InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 1, 1, 1), layer_norm, se_layer
89
+ ),
90
+ InvertedResidual(
91
+ InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 2, 1, 1), layer_norm, se_layer
92
+ ),
93
+ )
94
+
95
+ def forward(self, x):
96
+ b, num_tokens, c = x.shape
97
+ h = int(math.sqrt(num_tokens))
98
+ x = self.mlp(x)
99
+ x = x.permute(0, 2, 1).reshape(b, -1, h, h)
100
+ x = self.mb_block(x)
101
+ x = x.flatten(2).permute(0, 2, 1)
102
+ return x
103
+
104
+
105
+ class LDPNetProjector(nn.Module):
106
+ def __init__(self, config=None):
107
+ super().__init__()
108
+ self.model = LDPBlock(config)
109
+
110
+ def forward(self, x):
111
+ return self.model(x)
112
+
113
+
114
+ class SPP(nn.Module):
115
+ def __init__(self, config=None, projector_type="v1"):
116
+ super().__init__()
117
+
118
+ self.projector_type = projector_type
119
+
120
+ inc, ouc = config.mm_hidden_size, config.hidden_size
121
+ self.linear_0 = nn.Linear(inc, inc)
122
+
123
+ self.linear_1 = nn.Linear(inc, ouc)
124
+
125
+ self.pooling = nn.AvgPool2d(kernel_size=2)
126
+
127
+ self.linear_2 = nn.Linear(ouc, ouc)
128
+
129
+ def forward(self, x):
130
+ b, num_tokens, c = x.shape
131
+ h = int(math.sqrt(num_tokens))
132
+ if "v1" in self.projector_type:
133
+ x = self.linear_1(x)
134
+ x = x.permute(0, 2, 1).reshape(b, -1, h, h)
135
+ x = self.pooling(x)
136
+ x = x.flatten(2).permute(0, 2, 1)
137
+ x = self.linear_2(x)
138
+ elif "v2" in self.projector_type:
139
+ x = self.linear_1(x)
140
+ x = self.linear_2(x)
141
+ x = x.permute(0, 2, 1).reshape(b, -1, h, h)
142
+ x = self.pooling(x)
143
+ x = x.flatten(2).permute(0, 2, 1)
144
+ elif "v3" in self.projector_type:
145
+ x = self.linear_0(x)
146
+ x = x.permute(0, 2, 1).reshape(b, -1, h, h)
147
+ x = self.pooling(x)
148
+ x = x.flatten(2).permute(0, 2, 1)
149
+ x = self.linear_1(x)
150
+ x = self.linear_2(x)
151
+ return x
152
+
153
+
154
+ def build_vision_projector(config, delay_load=False, **kwargs):
155
+ projector_type = getattr(config, "mm_projector_type", "mlp2x_gelu")
156
+
157
+ if projector_type == "linear":
158
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
159
+
160
+ elif projector_type.startswith("mlp"):
161
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
162
+ if mlp_gelu_match:
163
+ mlp_depth = int(mlp_gelu_match.group(1))
164
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
165
+ for _ in range(1, mlp_depth):
166
+ modules.append(nn.GELU())
167
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
168
+ return nn.Sequential(*modules)
169
+
170
+ elif projector_type.startswith("spp"):
171
+ return SPP(config, projector_type)
172
+
173
+ elif projector_type == "ldp":
174
+ return LDPNetProjector(config)
175
+
176
+ elif projector_type == "vanilla":
177
+ return Vanilla(config)
178
+
179
+ elif projector_type == "minigpt":
180
+ return Minigpt(config)
181
+
182
+ elif projector_type == "identity":
183
+ return IdentityMap()
184
+
185
+ raise ValueError(f"Unknown projector type: {projector_type}")
vita/model/vita_arch.py ADDED
@@ -0,0 +1,639 @@
1
+ import math
2
+ from abc import ABC, abstractmethod
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from vita.constants import AUDIO_TOKEN_INDEX, IGNORE_INDEX, IMAGE_TOKEN_INDEX
8
+
9
+ from .multimodal_encoder.builder import build_audio_encoder, build_vision_tower
10
+ from .multimodal_projector.builder import build_vision_projector
11
+ import numpy as np
12
+
13
+ class VITAMetaModel:
14
+ def __init__(self, config):
15
+ super(VITAMetaModel, self).__init__(config)
16
+
17
+ if hasattr(config, "mm_vision_tower"):
18
+ self.vision_tower = build_vision_tower(
19
+ config, delay_load=False#not getattr(config, "continuous_training", False)
20
+ )
21
+ if getattr(config, "continuous_training", False):
22
+ config.continuous_training = False
23
+ self.mm_projector = build_vision_projector(config)
24
+
25
+ if hasattr(config, "mm_audio_encoder"):
26
+ self.audio_encoder = build_audio_encoder(config)
27
+
28
+ def get_vision_tower(self):
29
+ vision_tower = getattr(self, "vision_tower", None)
30
+ if type(vision_tower) is list:
31
+ vision_tower = vision_tower[0]
32
+ return vision_tower
33
+
34
+ def get_audio_encoder(self):
35
+ audio_encoder = getattr(self, "audio_encoder", None)
36
+ return audio_encoder
37
+
38
+ def initialize_vision_modules(self, model_args):
39
+ vision_tower = model_args.vision_tower
40
+
41
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
42
+
43
+ self.config.mm_vision_tower = vision_tower
44
+
45
+ if self.get_vision_tower() is None:
46
+ vision_tower = build_vision_tower(model_args)
47
+ self.vision_tower = vision_tower
48
+ else:
49
+ vision_tower = self.vision_tower
50
+ #vision_tower.load_model()
51
+
52
+ self.config.use_mm_proj = True
53
+ self.config.mm_projector_type = getattr(model_args, "mm_projector_type")
54
+ self.config.mm_hidden_size = vision_tower.hidden_size
55
+
56
+ if getattr(self, "mm_projector", None) is None:
57
+ self.mm_projector = build_vision_projector(self.config)
58
+ else:
59
+ # In case it is frozen by LoRA
60
+ for p in self.mm_projector.parameters():
61
+ p.requires_grad = True
62
+
63
+ if pretrain_mm_mlp_adapter is not None:
64
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
65
+
66
+ def get_w(weights, keyword):
67
+ return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
68
+
69
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
70
+
71
+ def initialize_audio_modules(self, model_args):
72
+ audio_encoder = model_args.audio_encoder
73
+
74
+ pretrain_audio_mlp_adapter = model_args.pretrain_audio_mlp_adapter
75
+
76
+ setattr(self.config, "mm_audio_encoder", audio_encoder)
77
+
78
+ audio_encoder = build_audio_encoder(self.config)
79
+ self.audio_encoder = audio_encoder
80
+
81
+ load_audio_ckpt_from_mllm = True
82
+ if load_audio_ckpt_from_mllm:
83
+ from safetensors.torch import load_file
84
+ import os
85
+ audio_weights = {}
86
+ for file_name in os.listdir(model_args.model_name_or_path):
87
+ if file_name.endswith('safetensors'):
88
+ audio_weights.update(
89
+ {k[20:]: v for k, v in load_file(os.path.join(model_args.model_name_or_path, file_name)).items() if
90
+ k.startswith('model.audio_encoder.')})
91
+ self.audio_encoder.load_state_dict(audio_weights, strict=True)
92
+
93
+ #load_audio_ckpt = True
94
+ #if self.get_audio_encoder() is None or load_audio_ckpt or model_args.audio_prompt_finetune:
95
+ # audio_encoder = build_audio_encoder(self.config)
96
+ # self.audio_encoder = audio_encoder
97
+
98
+ #load_audio_prompt_weight = False #True
99
+ #if load_audio_prompt_weight:
100
+ # from safetensors.torch import load_file
101
+ # import os
102
+ # audio_weights = {}
103
+ # for file_name in os.listdir(model_args.model_name_or_path):
104
+ # if file_name.endswith('safetensors'):
105
+ # audio_weights.update(
106
+ # {k[38:]: v for k, v in load_file(os.path.join(model_args.model_name_or_path, file_name)).items() if
107
+ # k.startswith('model.audio_encoder.prompt_embeddings')})
108
+ # self.audio_encoder.prompt_embeddings.load_state_dict(audio_weights, strict=True)
109
+
110
+ #checkpoint = torch.load(model_args.audio_encoder + "/final.pt", map_location="cpu")
111
+ #model_dict = self.audio_encoder.state_dict()
112
+ #for key in model_dict.keys():
113
+ # if key in checkpoint.keys():
114
+ # if model_dict[key].shape == checkpoint[key].shape:
115
+ # model_dict[key] = checkpoint[key]
116
+ # else:
117
+ # print(
118
+ # "Key {} has different shape, {} VS {}".format(
119
+ # key, model_dict[key].shape, checkpoint[key].shape
120
+ # )
121
+ # )
122
+ # else:
123
+ # print("Key {} has not in resume model".format(key))
124
+ #self.audio_encoder.load_state_dict(model_dict)
125
+
126
+ if pretrain_audio_mlp_adapter is not None:
127
+ audio_projector_weights = torch.load(pretrain_audio_mlp_adapter, map_location="cpu")
128
+
129
+ def get_w(weights, keyword):
130
+ return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
131
+
132
+ self.audio_encoder.adpter.load_state_dict(get_w(audio_projector_weights, "audio_encoder.adpter"))
133
+
134
+
135
+ class VITAMetaForCausalLM(ABC):
136
+ @abstractmethod
137
+ def get_model(self):
138
+ pass
139
+
140
+ def get_vision_tower(self):
141
+ return self.get_model().get_vision_tower()
142
+
143
+ def get_audio_encoder(self):
144
+ return self.get_model().get_audio_encoder()
145
+
146
+ def pool_feats(self, x, out_size):
147
+ ndim = x.ndim
148
+ if ndim == 2:
149
+ x = x.unsqueeze(0)
150
+ b, num_tokens, c = x.shape
151
+ h = int(math.sqrt(num_tokens))
152
+ x = x.permute(0, 2, 1).reshape(b, -1, h, h)
153
+ x = F.interpolate(x, size=out_size, mode='bilinear', align_corners=False)
154
+ num_tokens = x.shape[2] * x.shape[3] # Recalculate the number of tokens after pooling
155
+ x = x.reshape(b, c, num_tokens).permute(0, 2, 1)
156
+ if ndim == 2:
157
+ x = x.squeeze(0)
158
+ return x
159
+
160
+ def encode_images(self, images):
161
+ image_features = self.get_model().get_vision_tower()(images)
162
+ #image_features = self.pool_feats(image_features)
163
+ image_features = self.get_model().mm_projector(image_features)
164
+ return image_features
165
+
166
+ def encode_images_frameCat(self, images):
167
+ image_features = self.get_model().get_vision_tower()(images)
168
+ assert len(image_features) % 5 == 0
169
+
170
+ concatenated_features = []
171
+ for i in range(0, len(image_features), 5):
172
+ tensors_to_concat = [image_features[j] for j in range(i, i + 5)]
173
+ concatenated_tensor = torch.cat(tensors_to_concat, dim=-1)
174
+ concatenated_features.append(concatenated_tensor)
175
+ concatenated_features = torch.stack(concatenated_features)
176
+ image_features = concatenated_features
177
+
178
+ image_features = self.get_model().mm_projector(image_features)
179
+ return image_features
180
+
181
+ def slow_fast_pooling0(self, temp_img_feats):
182
+ num_frame = len(temp_img_feats)
183
+ if num_frame <= 30:
184
+ slow_token_num = max([e for e in [256, 225, 196, 169] if e <= 5200/num_frame])
185
+ fast_token_num = slow_token_num
186
+ elif num_frame <= 45:
187
+ slow_token_num = 169
188
+ fast_token_num = 81
189
+ elif num_frame <= 64:
190
+ slow_token_num = 169
191
+ fast_token_num = 49
192
+ else:
193
+ raise ValueError("The number of frames is too large!")
194
+
195
+ if num_frame <= 30:
196
+ num_slow = num_frame
197
+ else:
198
+ num_slow = int((5200 - fast_token_num * num_frame) / (slow_token_num - fast_token_num))
199
+ num_fast = num_frame - num_slow
200
+ slow_index = list(np.linspace(0, num_frame, num=num_slow, dtype=int))
201
+
202
+ new_img_feats = []
203
+ for i, feat in enumerate(temp_img_feats):
204
+ if i in slow_index:
205
+ sqrt_len = int(math.sqrt(slow_token_num))
206
+ else:
207
+ sqrt_len = int(math.sqrt(fast_token_num))
208
+ if sqrt_len != 16:
209
+ feat = self.pool_feats(feat, out_size=(sqrt_len, sqrt_len))
210
+ new_img_feats.append(feat)
211
+
212
+ return new_img_feats
213
+
214
+ def slow_fast_pooling1(self, temp_img_feats):
215
+ num_frame = len(temp_img_feats)
216
+ if num_frame <= 28:
217
+ slow_token_num = max([e for e in [256, 225, 196, 169, 144] if e <= 4096/num_frame])
218
+ fast_token_num = slow_token_num
219
+ elif num_frame <= 40:
220
+ slow_token_num = 144
221
+ fast_token_num = 81
222
+ elif num_frame <= 64:
223
+ slow_token_num = 144
224
+ fast_token_num = 49
225
+ else:
226
+ raise ValueError("The number of frames is too large!")
227
+
228
+ if num_frame <= 28:
229
+ num_slow = num_frame
230
+ else:
231
+ num_slow = int((4096 - fast_token_num * num_frame) / (slow_token_num - fast_token_num))
232
+ num_fast = num_frame - num_slow
233
+ slow_index = list(np.linspace(0, num_frame, num=num_slow, dtype=int))
234
+
235
+ new_img_feats = []
236
+ for i, feat in enumerate(temp_img_feats):
237
+ if i in slow_index:
238
+ sqrt_len = int(math.sqrt(slow_token_num))
239
+ else:
240
+ sqrt_len = int(math.sqrt(fast_token_num))
241
+ if sqrt_len != 16:
242
+ feat = self.pool_feats(feat, out_size=(sqrt_len, sqrt_len))
243
+ new_img_feats.append(feat)
244
+
245
+ return new_img_feats
246
+
247
+ def slow_fast_pooling(self, temp_img_feats):
248
+ num_frame = len(temp_img_feats)
249
+ slow_token_num = 144
250
+ fast_token_num = 49
251
+
252
+ slow_index = list(range(0, num_frame, 4))
253
+
254
+ new_img_feats = []
255
+ for i, feat in enumerate(temp_img_feats):
256
+ if i in slow_index:
257
+ sqrt_len = int(math.sqrt(slow_token_num))
258
+ else:
259
+ sqrt_len = int(math.sqrt(fast_token_num))
260
+ if sqrt_len != 16:
261
+ feat = self.pool_feats(feat, out_size=(sqrt_len, sqrt_len))
262
+ new_img_feats.append(feat)
263
+
264
+ return new_img_feats
265
+
266
+ def slow_fast_pooling3(self, temp_img_feats):
267
+ num_frame = len(temp_img_feats)
268
+ slow_token_num = 144
269
+ fast_token_num = 36
270
+
271
+ slow_index = list(range(0, num_frame, 16))
272
+
273
+ new_img_feats = []
274
+ for i, feat in enumerate(temp_img_feats):
275
+ if i in slow_index:
276
+ sqrt_len = int(math.sqrt(slow_token_num))
277
+ else:
278
+ sqrt_len = int(math.sqrt(fast_token_num))
279
+ if sqrt_len != 16:
280
+ feat = self.pool_feats(feat, out_size=(sqrt_len, sqrt_len))
281
+ new_img_feats.append(feat)
282
+
283
+ return new_img_feats
284
+
285
+ def slow_fast(self, image_features, sf_masks):
286
+ new_image_features = []
287
+ temp_img_feats = [] # initialize temp_img_feats outside the loop
288
+ for i, img_feat in enumerate(image_features):
289
+ if i == 0 or sf_masks[i] != sf_masks[i-1]:
290
+ if temp_img_feats: # if temp_img_feats is not empty, append it to new_image_features
291
+ if sf_masks[i-1] > 0:
292
+ temp_img_feats = self.slow_fast_pooling(temp_img_feats)
293
+ new_image_features.append(temp_img_feats)
294
+ temp_img_feats = [img_feat] # re-initialize temp_img_feats
295
+ else:
296
+ temp_img_feats.append(img_feat)
297
+ if temp_img_feats: # handle the last sub-list
298
+ if sf_masks[-1] > 0:
299
+ temp_img_feats = self.slow_fast_pooling(temp_img_feats)
300
+ new_image_features.append(temp_img_feats)
301
+
302
+ output_features = []
303
+ for e in new_image_features:
304
+ output_features += e
305
+
306
+ return output_features
307
+
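The schedule used by `slow_fast_pooling` keeps every 4th frame at 12x12 = 144 tokens and pools the remaining frames to 7x7 = 49, so the visual token budget grows roughly linearly with frame count. A back-of-the-envelope check (not repo code) for a 32-frame clip:

num_frame = 32
slow = len(range(0, num_frame, 4))   # 8 frames kept at 144 tokens
fast = num_frame - slow              # 24 frames pooled to 49 tokens
print(slow * 144 + fast * 49)        # 2328 visual tokens after slow/fast pooling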
308
+ def prepare_inputs_labels_for_multimodal(
309
+ self, input_ids, position_ids, attention_mask, past_key_values, labels, images, audios, sf_masks, shared_v_pid_stride=None
310
+ ):
311
+ vision_tower = self.get_vision_tower()
312
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
313
+ if (
314
+ past_key_values is not None
315
+ and vision_tower is not None
316
+ and images is not None
317
+ and input_ids.shape[1] == 1
318
+ ):
319
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
320
+ attention_mask = torch.cat(
321
+ (
322
+ attention_mask,
323
+ torch.ones(
324
+ (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
325
+ dtype=attention_mask.dtype,
326
+ device=attention_mask.device,
327
+ ),
328
+ ),
329
+ dim=1,
330
+ )
331
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
332
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
333
+
334
+ if type(images) is list or images.ndim == 5:
335
+ concat_images = torch.cat([image for image in images], dim=0)
336
+ image_features = self.encode_images(concat_images)
337
+ split_sizes = [image.shape[0] for image in images]
338
+ image_features = torch.split(image_features, split_sizes, dim=0)
339
+ image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
340
+ else:
341
+ image_features = self.encode_images(images).to(self.device)
342
+
343
+ image_features = [e for e in image_features]
344
+ if sf_masks is not None:
345
+ assert len(image_features) == len(sf_masks)
346
+ image_features = self.slow_fast(image_features, sf_masks)
347
+
348
+ audio_encoder = self.get_audio_encoder()
349
+ if audios is not None:
350
+ audio_features = audio_encoder(audios["audios"], audios["lengths"])
351
+ state_labels = audios.get("state_labels", None)
352
+ lengths_for_llm = audios["lengths_for_llm"]
353
+ if state_labels is not None:
354
+ assert len(audio_features["inputs_embeds"]) == len(state_labels) == len(lengths_for_llm)
355
+ else:
356
+ audio_features, state_labels, lengths_for_llm = None, None, None
357
+
358
+ # Let's just add dummy tensors if they do not exist,
359
+ # it is a headache to deal with None all the time.
360
+ # But it is not ideal, and if you have a better idea,
361
+ # please open an issue / submit a PR, thanks.
362
+ _labels = labels
363
+ _position_ids = position_ids
364
+ _attention_mask = attention_mask
365
+ if attention_mask is None:
366
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
367
+ else:
368
+ attention_mask = attention_mask.bool()
369
+ if position_ids is None:
370
+ position_ids = torch.arange(
371
+ 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
372
+ )
373
+ if labels is None:
374
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
375
+
376
+ # remove the padding using attention_mask -- TODO: double check
377
+ input_ids = [
378
+ cur_input_ids[cur_attention_mask]
379
+ for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
380
+ ]
381
+ labels = [
382
+ cur_labels[cur_attention_mask]
383
+ for cur_labels, cur_attention_mask in zip(labels, attention_mask)
384
+ ]
385
+
386
+ new_input_embeds = []
387
+ new_labels = []
388
+ v_start_end = []
389
+ cur_image_idx = 0
390
+ cur_audio_idx = 0
391
+ assert (
392
+ sum([(cur == IMAGE_TOKEN_INDEX).sum() for cur in input_ids])
393
+ + sum([(IMAGE_TOKEN_INDEX not in cur) for cur in input_ids])
394
+ == len(image_features)
395
+ ), input_ids
396
+ assert (
397
+ sum([(cur == AUDIO_TOKEN_INDEX).sum() for cur in input_ids])
398
+ + sum([(AUDIO_TOKEN_INDEX not in cur) for cur in input_ids])
399
+ == audio_features["inputs_embeds"].shape[0]
400
+ ), input_ids
401
+
402
+ for batch_idx, cur_input_ids in enumerate(input_ids):
403
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
404
+ num_audio_frames = (cur_input_ids == AUDIO_TOKEN_INDEX).sum()
405
+ if num_images == 0 and num_audio_frames == 0:
406
+ cur_image_features = image_features[cur_image_idx]
407
+ cur_audio_features = audio_features["inputs_embeds"][cur_audio_idx]
408
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
409
+ cur_input_embeds = torch.cat(
410
+ [cur_input_embeds_1, cur_image_features[0:0], cur_audio_features[0:0]], dim=0
411
+ )
412
+ new_input_embeds.append(cur_input_embeds)
413
+ new_labels.append(labels[batch_idx])
414
+ cur_image_idx += 1
415
+ cur_audio_idx += 1
416
+ continue
417
+
418
+ image_audio_token_indices = (
419
+ [-1]
420
+ + torch.where(
421
+ (cur_input_ids == IMAGE_TOKEN_INDEX) | (cur_input_ids == AUDIO_TOKEN_INDEX)
422
+ )[0].tolist()
423
+ + [cur_input_ids.shape[0]]
424
+ )
425
+ cur_input_ids_noim_noau = []
426
+ cur_labels = labels[batch_idx]
427
+ cur_labels_noim_noau = []
428
+ for i in range(len(image_audio_token_indices) - 1):
429
+ cur_input_ids_noim_noau.append(
430
+ cur_input_ids[
431
+ image_audio_token_indices[i] + 1 : image_audio_token_indices[i + 1]
432
+ ]
433
+ )
434
+ cur_labels_noim_noau.append(
435
+ cur_labels[image_audio_token_indices[i] + 1 : image_audio_token_indices[i + 1]]
436
+ )
437
+
438
+ split_sizes = [x.shape[0] for x in cur_labels_noim_noau]
439
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim_noau))
440
+ cur_input_embeds_no_im_no_au = torch.split(cur_input_embeds, split_sizes, dim=0)
441
+ cur_new_input_embeds = []
442
+ cur_new_labels = []
443
+ cur_v_start_end = []
444
+ for i in range(num_images + num_audio_frames + 1):
445
+ cur_new_input_embeds.append(cur_input_embeds_no_im_no_au[i])
446
+ cur_new_labels.append(cur_labels_noim_noau[i])
447
+ if i < num_images + num_audio_frames:
448
+ if cur_input_ids[image_audio_token_indices[i + 1]] == IMAGE_TOKEN_INDEX:
449
+ cur_image_features = image_features[cur_image_idx]
450
+ cur_image_idx += 1
451
+ cur_new_input_embeds.append(cur_image_features)
452
+ cur_new_labels.append(
453
+ torch.full(
454
+ (cur_image_features.shape[0],),
455
+ IGNORE_INDEX,
456
+ device=cur_labels.device,
457
+ dtype=cur_labels.dtype,
458
+ )
459
+ )
460
+ if shared_v_pid_stride:
461
+ start = sum([x.shape[0] for x in cur_new_labels[:-1]])
462
+ end = start + cur_new_labels[-1].shape[0]
463
+ cur_v_start_end.append((start, end))
464
+ elif cur_input_ids[image_audio_token_indices[i + 1]] == AUDIO_TOKEN_INDEX:
465
+ cur_lengths_for_llm = lengths_for_llm[cur_audio_idx]
466
+ cur_audio_features = audio_features["inputs_embeds"][cur_audio_idx]
467
+ if getattr(self.config, "audio_prompt_num", None):
468
+ cur_lengths_for_llm = cur_lengths_for_llm + self.config.audio_prompt_num
469
+ cur_audio_features = cur_audio_features[:cur_lengths_for_llm]
470
+ if state_labels is not None:
471
+ cur_state_label = state_labels[cur_audio_idx]
472
+ cur_audio_idx += 1
473
+ cur_new_input_embeds.append(cur_audio_features)
474
+ cur_new_labels.append(
475
+ torch.full(
476
+ (cur_audio_features.shape[0],),
477
+ IGNORE_INDEX,
478
+ device=cur_labels.device,
479
+ dtype=cur_labels.dtype,
480
+ )
481
+ )
482
+ if state_labels is not None:
483
+ cur_new_labels[-1][-1] = cur_state_label
484
+ else:
485
+ raise ValueError
486
+
487
+ if num_images != 0 and num_audio_frames == 0:
488
+ cur_audio_features = audio_features["inputs_embeds"][cur_audio_idx]
489
+ cur_audio_idx += 1
490
+ cur_new_input_embeds.append(cur_audio_features[0:0])
491
+ elif num_images == 0 and num_audio_frames != 0:
492
+ cur_image_features = image_features[cur_image_idx]
493
+ cur_image_idx += 1
494
+ cur_new_input_embeds.append(cur_image_features[0:0])
495
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
496
+ cur_new_labels = torch.cat(cur_new_labels)
497
+
498
+ new_input_embeds.append(cur_new_input_embeds)
499
+ new_labels.append(cur_new_labels)
500
+
501
+ if shared_v_pid_stride:
502
+ cur_v_start_end = merge_consecutive_tuples(cur_v_start_end)
503
+ v_start_end.append(cur_v_start_end)
504
+
505
+ assert cur_image_idx == len(image_features)
506
+ assert cur_audio_idx == audio_features["inputs_embeds"].shape[0]
507
+ if state_labels is not None:
508
+ assert cur_audio_idx == len(state_labels)
509
+ if state_labels is not None:
510
+ assert (
511
+ sum([(cur == AUDIO_TOKEN_INDEX).sum() for cur in input_ids])
512
+ == sum([(cur == -101).sum() for cur in new_labels]) + sum([(cur == -102).sum() for cur in new_labels])
513
+ ), (input_ids, sum([(cur == AUDIO_TOKEN_INDEX).sum() for cur in input_ids]), sum([(cur == -101).sum() for cur in new_labels]), sum([(cur == -102).sum() for cur in new_labels]), new_labels.shape)
514
+
515
+ # Truncate sequences to max length as image embeddings can make the sequence longer
516
+ tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
517
+ if tokenizer_model_max_length is not None:
518
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
519
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
520
+
521
+ # Combine them
522
+ max_len = max(x.shape[0] for x in new_input_embeds)
523
+ batch_size = len(new_input_embeds)
524
+
525
+ new_input_embeds_padded = []
526
+ new_labels_padded = torch.full(
527
+ (batch_size, max_len),
528
+ IGNORE_INDEX,
529
+ dtype=new_labels[0].dtype,
530
+ device=new_labels[0].device,
531
+ )
532
+ attention_mask = torch.zeros(
533
+ (batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
534
+ )
535
+ position_ids = torch.zeros(
536
+ (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
537
+ )
538
+
539
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
540
+ cur_len = cur_new_embed.shape[0]
541
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
542
+ new_input_embeds_padded.append(
543
+ torch.cat(
544
+ (
545
+ torch.zeros(
546
+ (max_len - cur_len, cur_new_embed.shape[1]),
547
+ dtype=cur_new_embed.dtype,
548
+ device=cur_new_embed.device,
549
+ ),
550
+ cur_new_embed,
551
+ ),
552
+ dim=0,
553
+ )
554
+ )
555
+ if cur_len > 0:
556
+ new_labels_padded[i, -cur_len:] = cur_new_labels
557
+ attention_mask[i, -cur_len:] = True
558
+ position_ids[i, -cur_len:] = torch.arange(
559
+ 0, cur_len, dtype=position_ids.dtype, device=position_ids.device
560
+ )
561
+ else:
562
+ new_input_embeds_padded.append(
563
+ torch.cat(
564
+ (
565
+ cur_new_embed,
566
+ torch.zeros(
567
+ (max_len - cur_len, cur_new_embed.shape[1]),
568
+ dtype=cur_new_embed.dtype,
569
+ device=cur_new_embed.device,
570
+ ),
571
+ ),
572
+ dim=0,
573
+ )
574
+ )
575
+ if cur_len > 0:
576
+ new_labels_padded[i, :cur_len] = cur_new_labels
577
+ attention_mask[i, :cur_len] = True
578
+ if shared_v_pid_stride is None:
579
+ position_ids[i, :cur_len] = torch.arange(
580
+ 0, cur_len, dtype=position_ids.dtype, device=position_ids.device
581
+ )
582
+ else:
583
+ cur_v_start_end = v_start_end[i]
584
+ cur_shared_position_ids = make_shared_position_ids(cur_v_start_end, cur_len, shared_v_pid_stride)
585
+ position_ids[i, :cur_len] = cur_shared_position_ids
586
+
587
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
588
+
589
+ if _labels is None:
590
+ new_labels = None
591
+ else:
592
+ new_labels = new_labels_padded
593
+
594
+ if _attention_mask is None:
595
+ attention_mask = None
596
+ else:
597
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
598
+
599
+ if _position_ids is None and shared_v_pid_stride is None:
600
+ position_ids = None
601
+
602
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
603
+
604
+
605
+ def merge_consecutive_tuples(tuples_list):
606
+ if not tuples_list:
607
+ return []
608
+
609
+ # sort the tuples by their start index first
610
+ sorted_tuples = sorted(tuples_list, key=lambda x: x[0])
611
+
612
+ # initialize the merged list with the first tuple
613
+ merged_tuples = [sorted_tuples[0]]
614
+
615
+ for current_start, current_end in sorted_tuples[1:]:
616
+ last_merged_start, last_merged_end = merged_tuples[-1]
617
+ if current_start <= last_merged_end:  # the current tuple starts at or before the end of the last merged tuple
618
+ # merge the two tuples
619
+ new_start, new_end = merged_tuples[-1][0], max(last_merged_end, current_end)
620
+ merged_tuples[-1] = (new_start, new_end)
621
+ else:
622
+ # not contiguous with the previous tuple, append it as a new entry
623
+ merged_tuples.append((current_start, current_end))
624
+
625
+ return merged_tuples
626
+
627
+
628
+ def make_shared_position_ids(cur_v_start_end, cur_len, shared_v_pid_stride):
629
+ position_ids = torch.tensor([1.0] * cur_len)
630
+
631
+ for start, end in cur_v_start_end:
632
+ position_ids[start:end] = 1/shared_v_pid_stride
633
+ v_mod = (end - start) % shared_v_pid_stride
634
+ if v_mod != 0:
635
+ position_ids[end-v_mod:end] = 1 / v_mod
636
+ position_ids = position_ids.cumsum(dim=0)
637
+ position_ids = torch.ceil(position_ids).long() - 1
638
+
639
+ return position_ids
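The two helpers above implement interval merging and stride-shared position ids for vision tokens. A minimal illustration of their expected behaviour (editorial sketch, not part of the commit; the spans and lengths below are hypothetical):

import torch

# merge_consecutive_tuples merges overlapping (start, end) spans:
#   [(2, 5), (4, 8), (10, 12)]  ->  [(2, 8), (10, 12)]

# make_shared_position_ids assigns an increment of 1/stride inside each visual
# span and ceils the running sum, so `shared_v_pid_stride` consecutive vision
# tokens share one position id. With cur_len=6, a visual span (1, 5), stride=2:
stride = 2
pos = torch.tensor([1.0] + [1 / stride] * 4 + [1.0])
pos = torch.ceil(pos.cumsum(dim=0)).long() - 1
print(pos.tolist())  # [0, 1, 1, 2, 2, 3] -- each pair of vision tokens shares one position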
vita/model/vita_tts/adapter.py ADDED
@@ -0,0 +1,157 @@
1
+ import random
2
+ import torch
3
+ import copy
4
+ import re
5
+
6
+ from torch import nn
7
+ from torch.nn.utils.rnn import pad_sequence
8
+ import torch.nn.functional as F
9
+
10
+ class CNNAdapter(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ enc_out_dim: int = 512,
14
+ llm_embed_dim: int = 4096,
15
+ kernel_size: int = 5,
16
+ ):
17
+ super().__init__()
18
+
19
+ self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
20
+ self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
21
+
22
+ self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
23
+ self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 1, 0)
24
+
25
+ self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
26
+ self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
27
+
28
+ self.relu1 = nn.ReLU()
29
+ self.relu2 = nn.ReLU()
30
+
31
+ self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
32
+
33
+ def forward(self, x, mask_pad):
34
+ """
35
+ x: B, T, enc_out_dim
36
+ mask_pad: (B, 1, T), True for valid frames
37
+ """
38
+ x = x.transpose(1, 2) # B, channels, T
39
+
40
+ # mask batch padding
41
+ if mask_pad.size(2) > 0: # time > 0
42
+ x.masked_fill_(~mask_pad, 0.0)
43
+
44
+ x = self.left_padding1(x)
45
+ x = self.conv1d1(x)
46
+ x = self.bn1(x)
47
+ x = self.relu1(x)
48
+
49
+ x = self.left_padding2(x)
50
+ x = self.conv1d2(x)
51
+ x = self.bn2(x)
52
+ x = self.relu2(x)
53
+
54
+ x = x.transpose(1, 2)
55
+ x = self.project(x)
56
+
57
+ return x, mask_pad
58
+
59
+ class LinearAdapter(torch.nn.Module):
60
+ def __init__(
61
+ self,
62
+ enc_out_dim: int = 512,
63
+ llm_embed_dim: int = 4096,
64
+ ):
65
+ super().__init__()
66
+
67
+ self.adpter = torch.nn.Linear(enc_out_dim, llm_embed_dim)
68
+
69
+ def forward(self, x, mask_pad):
70
+ return self.adpter(x), mask_pad
71
+
72
+ class CNNSubsampling(torch.nn.Module):
73
+ def __init__(
74
+ self,
75
+ enc_out_dim: int = 512,
76
+ llm_embed_dim: int = 4096,
77
+ kernel_size: int = 5,
78
+ activation_func: str = 'relu',
79
+ norm: str = 'batch',
80
+ ):
81
+ super().__init__()
82
+
83
+ self.kernel_size = kernel_size
84
+ if enc_out_dim * 4 < llm_embed_dim:
85
+ self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
86
+ self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
87
+ self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
88
+ self.relu1 = nn.ReLU()
89
+
90
+ self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
91
+ self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 2, 0)
92
+ self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
93
+ self.relu2 = nn.ReLU()
94
+
95
+ self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
96
+ self.cnn_num = 2
97
+ else:
98
+ self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
99
+ self.conv1d2 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 2, 0)
100
+ if norm == 'batch':
101
+ self.bn2 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
102
+ elif norm == 'layer':
103
+ self.bn2 = nn.LayerNorm(2 * enc_out_dim, eps=1e-3)
104
+ if activation_func == 'gelu':
105
+ self.relu2 = nn.GELU()
106
+ else:
107
+ self.relu2 = nn.ReLU()
108
+
109
+ self.project = nn.Linear(2 * enc_out_dim, llm_embed_dim)
110
+ self.cnn_num = 1
111
+
112
+ def forward(self, x, mask_pad, cache=None, return_cache=False):
113
+ """
114
+ x: B, T, enc_out_dim
115
+ mask_pad: (B, 1, T), True for valid frames
116
+ """
117
+ x = x.transpose(1, 2) # B, channels, T
118
+
119
+ # mask batch padding
120
+ if mask_pad.size(2) > 0: # time > 0
121
+ x.masked_fill_(~mask_pad, 0.0)
122
+
123
+ if self.cnn_num == 2:
124
+ if cache is None:
125
+ x = self.left_padding1(x)
126
+ else:
127
+ x = torch.cat((cache[1], x), dim=2)
128
+ if cache is not None:
129
+ cache[1] = x[:, :, 1-self.kernel_size:]
130
+ else:
131
+ cache = [None, x[:, :, 1-self.kernel_size:]]
132
+ x = self.conv1d1(x)
133
+ x = self.bn1(x)
134
+ x = self.relu1(x)
135
+
136
+ if cache is None or cache[0] is None:
137
+ x = self.left_padding2(x)
138
+ else:
139
+ x = torch.cat((cache[0], x), dim=2)
140
+ if cache is not None:
141
+ cache[0] = x[:, :, 1-self.kernel_size:]
142
+ else:
143
+ cache = [x[:, :, 1-self.kernel_size:]]
144
+ x = self.conv1d2(x)
145
+ if isinstance(self.bn2, nn.LayerNorm):
146
+ x = x.transpose(1, 2)
147
+ x = self.bn2(x)
148
+ if isinstance(self.bn2, nn.LayerNorm):
149
+ x = x.transpose(1, 2)
150
+ x = self.relu2(x)
151
+
152
+ x = x.transpose(1, 2)
153
+ x = self.project(x)
154
+
155
+ if return_cache:
156
+ return x, mask_pad[:, :, 0::2], cache
157
+ return x, mask_pad[:, :, 0::2]
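A minimal usage sketch of the subsampling adapter defined above (editorial illustration; shapes assume the default 512-dim encoder output and a 4096-dim LLM embedding):

import torch
from vita.model.vita_tts.adapter import CNNSubsampling

adapter = CNNSubsampling(enc_out_dim=512, llm_embed_dim=4096, kernel_size=5).eval()
x = torch.randn(2, 10, 512)                    # (B, T, enc_out_dim) encoder frames
mask = torch.ones(2, 1, 10, dtype=torch.bool)  # (B, 1, T); True marks valid frames
with torch.no_grad():
    y, y_mask = adapter(x, mask)
print(y.shape, y_mask.shape)                   # (2, 5, 4096) and (2, 1, 5): time is halved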
vita/model/vita_tts/audioLLM.py ADDED
@@ -0,0 +1,433 @@
1
+ import random
2
+ import torch
3
+ import copy
4
+ import re
5
+
6
+ from torch import nn
7
+ from torch.nn.utils.rnn import pad_sequence
8
+ import torch.nn.functional as F
9
+
10
+ from collections import defaultdict
11
+ from typing import Dict, List, Optional, Tuple
12
+
13
+ from transformers import AutoModelForCausalLM
14
+ from transformers import AutoTokenizer
15
+
16
+ from vita.model.vita_tts.adapter import *
17
+
18
+ IGNORE_ID = -1
19
+
20
+ class AudioLLM(torch.nn.Module):
21
+ def __init__(
22
+ self,
23
+ encoder: torch.nn.Module,
24
+ llm_path: str,
25
+ freeze_llm: bool = True,
26
+ enc_out_dim: int = 512,
27
+ llm_embed_dim: int = 4096,
28
+ kernel_size: int = 3,
29
+ IGNORE_ID: int = -100,
30
+ adpter_type: str = 'cnn',
31
+ add_audio_bos_eos: bool = False,
32
+ task_num: int = 10,
33
+ add_ctc_prompt_ratio: float = 0.0,
34
+ lang_dict: dict = None,
35
+ ctc: torch.nn.Module = None,
36
+ tokenize_ctc_char: bool = False,
37
+ task_before_audio: bool = False,
38
+ hyp_before_task: bool = False,
39
+ prompt_finetune: bool = False,
40
+ add_prompt_before: bool = False,
41
+ prompt_num: int = 5,
42
+ prefix_finetune: bool = False,
43
+ prefix_num: int = 5,
44
+ llm_head_num: int = 32,
45
+ num_key_value_heads: int = None,
46
+ task_type: str = 'prompt',
47
+ freeze_encoder: bool = False,
48
+ freeze_adpter: bool = False,
49
+ activation_func: str = 'relu',
50
+ norm: str = 'batch',
51
+ use_lora: bool = False,
52
+ clone_encoder: torch.nn.Module = None,
53
+ chat_template: str = None,
54
+ predict_usr_state: int = 0,
55
+ chunk_size: int = -1,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.encoder = encoder
60
+ self.llm_decoder = AutoModelForCausalLM.from_pretrained(llm_path,
61
+ torch_dtype="auto",
62
+ trust_remote_code=True)
63
+ self.tokenizer = AutoTokenizer.from_pretrained(llm_path,
64
+ trust_remote_code=True)
65
+ self.freeze_llm = freeze_llm
66
+ self.enc_out_dim = enc_out_dim
67
+ self.llm_embed_dim = llm_embed_dim
68
+ self.IGNORE_ID = IGNORE_ID
69
+ self.add_audio_bos_eos = add_audio_bos_eos
70
+ self.add_ctc_prompt_ratio = add_ctc_prompt_ratio
71
+ self.lang_dict = lang_dict
72
+ self.tokenize_ctc_char = tokenize_ctc_char
73
+ self.task_before_audio = task_before_audio
74
+ self.hyp_before_task = hyp_before_task
75
+ self.prompt_finetune = prompt_finetune
76
+ self.add_prompt_before = add_prompt_before
77
+ self.prompt_num = prompt_num
78
+ self.prefix_finetune = prefix_finetune
79
+ self.prefix_num = prefix_num
80
+ self.llm_head_num = llm_head_num
81
+ if num_key_value_heads is None:
82
+ self.num_key_value_heads = llm_head_num
83
+ else:
84
+ self.num_key_value_heads = num_key_value_heads
85
+ self.kv_cache_dim = llm_embed_dim // self.llm_head_num * self.num_key_value_heads
86
+ self.task_type = task_type
87
+ self.freeze_encoder = freeze_encoder
88
+ self.freeze_adpter = freeze_adpter
89
+ self.predict_usr_state = predict_usr_state
90
+ self.chunk_size = chunk_size
91
+
92
+ if not hasattr(self.tokenizer, "eod_id"):
93
+ self.tokenizer.eod_id = self.tokenizer.eos_token_id
94
+ if not hasattr(self.llm_decoder, "transformer"):
95
+ self.llm_decoder.transformer = self.llm_decoder.model
96
+ self.llm_decoder.transformer.h = self.llm_decoder.transformer.layers
97
+ if not hasattr(self.llm_decoder.transformer, "wte"):
98
+ self.llm_decoder.transformer.wte = \
99
+ self.llm_decoder.transformer.embed_tokens
100
+
101
+ # for chat mode
102
+ if chat_template is not None:
103
+ self.tokenizer.eod_id = self.tokenizer('<|im_end|>'
104
+ )['input_ids'][0]
105
+ self.chat_template = {}
106
+ chat_template = chat_template.split('<audio>')
107
+ chat_prefix = chat_template[0].split('<|im_end|>')
108
+ chat_role = chat_prefix[0] + '<|im_end|>'
109
+ self.chat_template['role'] = self.tokenizer(
110
+ [chat_role], return_tensors="pt")['input_ids']
111
+ self.chat_template['prefix'] = self.tokenizer(
112
+ [chat_prefix[1]], return_tensors="pt")['input_ids']
113
+ self.chat_template['suffix'] = self.tokenizer(
114
+ [chat_template[1]], return_tensors="pt")['input_ids']
115
+ else:
116
+ self.chat_template = None
117
+
118
+ # for CTC prompt
119
+ if self.add_ctc_prompt_ratio > 0.0:
120
+ assert lang_dict is not None
121
+ assert ctc is not None
122
+ self.ctc = ctc.eval()
123
+ if clone_encoder is None:
124
+ self.clone_encoder = copy.deepcopy(encoder)
125
+ else:
126
+ self.clone_encoder = clone_encoder
127
+ self.clone_encoder.eval()
128
+ for (name, param) in self.clone_encoder.named_parameters():
129
+ param.requires_grad = False
130
+ for (name, param) in self.ctc.named_parameters():
131
+ param.requires_grad = False
132
+ else:
133
+ self.clone_encoder = None
134
+
135
+ if self.freeze_llm:
136
+ self.llm_decoder.eval()
137
+ for (name, param) in self.llm_decoder.named_parameters():
138
+ param.requires_grad = False
139
+
140
+ if use_lora:
141
+ config = LoraConfig(
142
+ r=lora_r,
143
+ lora_alpha=lora_alpha,
144
+ target_modules=UNET_TARGET_MODULES,
145
+ lora_dropout=args.lora_dropout,
146
+ bias=args.lora_bias,
147
+ )
148
+
149
+ if adpter_type == 'cnn':
150
+ self.adpter = CNNAdapter(enc_out_dim, llm_embed_dim, kernel_size)
151
+ elif adpter_type == 'linear':
152
+ self.adpter = LinearAdapter(enc_out_dim, llm_embed_dim)
153
+ elif adpter_type == 'subsampling':
154
+ self.adpter = CNNSubsampling(enc_out_dim, llm_embed_dim,
155
+ kernel_size, activation_func, norm)
156
+
157
+ self.task_embeddings = torch.nn.Embedding(task_num, llm_embed_dim)
158
+ if task_type == 'prefix':
159
+ self.prefix_embeddings = nn.ModuleList(
160
+ [
161
+ torch.nn.ModuleList(
162
+ [nn.Embedding(task_num, self.kv_cache_dim),
163
+ nn.Embedding(task_num, self.kv_cache_dim)]
164
+ )
165
+ for i in range(len(self.llm_decoder.transformer.h))
166
+ ]
167
+ )
168
+
169
+ if self.prompt_finetune or self.prefix_finetune:
170
+ if self.prompt_finetune:
171
+ self.prompt_embeddings = nn.Embedding(prompt_num, llm_embed_dim)
172
+ self.prompt_ids = torch.Tensor([i for i in range(prompt_num)]).long()
173
+ if self.prefix_finetune:
174
+ self.prefix_embeddings = nn.ModuleList(
175
+ [
176
+ torch.nn.ModuleList(
177
+ [nn.Embedding(prefix_num, self.kv_cache_dim),
178
+ nn.Embedding(prefix_num, self.kv_cache_dim)]
179
+ )
180
+ for i in range(len(self.llm_decoder.transformer.h))
181
+ ]
182
+ )
183
+ self.prefix_ids = torch.Tensor([i for i in range(prefix_num)]).long()
184
+
185
+ if self.freeze_encoder:
186
+ self.encoder.eval()
187
+ for (name, param) in self.encoder.named_parameters():
188
+ param.requires_grad = False
189
+ if self.freeze_adpter:
190
+ self.adpter.eval()
191
+ for (name, param) in self.adpter.named_parameters():
192
+ param.requires_grad = False
193
+
194
+ if self.predict_usr_state:
195
+ self.predictor_head = torch.nn.Linear(llm_embed_dim, predict_usr_state)
196
+ else:
197
+ self.predictor_head = None
198
+
199
+ # define task ids
200
+ self.task_ids = {
201
+ "sot": 0,
202
+ "transcribe": 1,
203
+ "translate": 2,
204
+ "zh": 3,
205
+ "en": 4,
206
+ "audio": 5,
207
+ "/audio": 6,
208
+ "hyps": 7,
209
+ "/hyps": 8,
210
+ }
211
+
212
+ def set_system_role(
213
+ self,
214
+ extra_inputs: Optional[dict] = None,
215
+ ):
216
+ # Ensure 'past_key_values' does not exist in extra_inputs, raise an exception if it does
217
+ assert extra_inputs.get('past_key_values', None) is None, "past key values already exist!!!"
218
+
219
+ # If 'role' key is present in extra_inputs, use that role as the chat prefix
220
+ if extra_inputs.get('role', None) is not None:
221
+ chat_prefix = self.tokenizer([extra_inputs['role']],
222
+ return_tensors="pt")['input_ids'].to('cuda') # Convert role to tokens and move to CUDA device
223
+ else:
224
+ # If no 'role' is provided, use the default chat template and remove the last token (<|im_end|>)
225
+ chat_prefix = self.chat_template['role'][:, :-1].to('cuda')
226
+
227
+ # Use the LLM decoder's word embedding layer to convert the chat prefix into embeddings
228
+ inputs_embeds = self.llm_decoder.transformer.wte(chat_prefix)
229
+
230
+ # Create an attention mask with the same shape as the chat prefix, all values set to True
231
+ attention_mask = torch.full(chat_prefix.shape,
232
+ True).to(inputs_embeds.device)
233
+
234
+ # Prepare the input dictionary containing embeddings and attention mask
235
+ inputs = {
236
+ 'inputs_embeds': inputs_embeds.half(), # Convert embeddings to half precision floats
237
+ 'attention_mask': attention_mask,
238
+ }
239
+
240
+ # Call the _generate_one_step method to generate one step output, including past_key_values, etc.
241
+ _, past_key_values, stat, _ = self._generate_one_step(
242
+ copy.deepcopy(inputs), "sl")
243
+
244
+ # Return the generated past_key_values
245
+ return past_key_values
246
+
247
+ def recognize(
248
+ self,
249
+ speech: torch.Tensor,
250
+ speech_lengths: torch.Tensor,
251
+ extra_inputs: Optional[dict] = None,
252
+ ):
253
+ assert extra_inputs.get('past_key_values', None) is not None, "must set system role first!!!"
254
+
255
+ buffer = extra_inputs.get('encoder_cache', None)
256
+ cnn_cache = extra_inputs.get('adapter_cache', None)
257
+ pe_index = extra_inputs.get('pe_index', 0)
258
+ if extra_inputs['stat'] == 'sl' or extra_inputs['stat'] == 'cl':
259
+ # Encoder
260
+
261
+ if buffer is None:
262
+ buffer = [None] * self.encoder.enc[1].num_blocks
263
+
264
+ encoder_out, buffer, _, _, pe_index = self.encoder.infer(speech, buffer,
265
+ 0, None, pe_index)
266
+
267
+ encoder_mask = torch.full(encoder_out.shape[:2], True).unsqueeze(1
268
+ ).to(encoder_out.device)
269
+
270
+ # adapter
271
+ inputs_embeds, encoder_mask, cnn_cache = self.adpter(encoder_out, encoder_mask,
272
+ cache=cnn_cache, return_cache=True) # 1, T, D
273
+
274
+ attention_mask = encoder_mask.squeeze(1) # 1, T
275
+
276
+ # prompt
277
+ if extra_inputs['stat'] == 'sl':
278
+ if self.prompt_finetune:
279
+ prompt_ids = self.prompt_ids.repeat(1, 1).to(inputs_embeds.device)
280
+ prompt_embeds = self.prompt_embeddings(
281
+ prompt_ids.to(inputs_embeds.device)) # B, 5, D
282
+ prompt_mask = torch.full(prompt_ids.shape,
283
+ True).to(inputs_embeds.device) # B, 5
284
+
285
+ if self.add_prompt_before:
286
+ inputs_embeds = torch.cat((prompt_embeds, inputs_embeds), 1) # B, (T+5), D
287
+ attention_mask = torch.cat((prompt_mask, attention_mask), 1) # B, (T+5)
288
+
289
+ # chat mode
290
+ if self.chat_template is not None:
291
+ if extra_inputs['stat'] == 'sl':
292
+ chat_prefix = self.chat_template['prefix'].to(
293
+ inputs_embeds.device)
294
+ chat_prefix = torch.cat((torch.tensor([[self.tokenizer.eod_id]]
295
+ ).to(inputs_embeds.device), chat_prefix), 1)
296
+ chat_prefix_embeds = self.llm_decoder.transformer.wte(chat_prefix)
297
+ chat_prefix_mask = torch.full(chat_prefix.shape,
298
+ True).to(inputs_embeds.device)
299
+ inputs_embeds = torch.cat((chat_prefix_embeds, inputs_embeds), 1)
300
+ attention_mask = torch.cat((chat_prefix_mask, attention_mask), 1)
301
+ if extra_inputs['stat'] == 'ss':
302
+ chat_suffix = self.chat_template['suffix'].to('cuda')
303
+ chat_suffix_embeds = self.llm_decoder.transformer.wte(chat_suffix)
304
+ chat_suffix_mask = torch.full(chat_suffix.shape, True).to('cuda')
305
+ inputs_embeds = chat_suffix_embeds
306
+ attention_mask = chat_suffix_mask
307
+
308
+ if extra_inputs['stat'] != 'cs':
309
+ inputs = {
310
+ 'inputs_embeds': inputs_embeds.half(),
311
+ 'attention_mask': attention_mask,
312
+ }
313
+ else:
314
+ attention_mask = torch.full([1, 1], True).to('cuda')
315
+ inputs = {
316
+ 'input_ids': extra_inputs['last_id'],
317
+ 'attention_mask': attention_mask
318
+ }
319
+
320
+ # add kv cache
321
+ inputs['past_key_values'] = extra_inputs['past_key_values']
322
+ past_mask = torch.full([1, inputs['past_key_values'][0][0].size(2)],
323
+ True).to('cuda')
324
+ attention_mask = torch.cat((past_mask, attention_mask), 1)
325
+ inputs['attention_mask'] = attention_mask
326
+
327
+ top_p = extra_inputs.get('top_p', 1.0)
328
+ top_k = extra_inputs.get('top_k', 0)
329
+ temperature = extra_inputs.get('temperature', 1.0)
330
+
331
+ last_id, past_key_values, stat, hidden_state = self._generate_one_step(copy.deepcopy(inputs),
332
+ extra_inputs['stat'],
333
+ top_p=top_p,
334
+ top_k=top_k,
335
+ temperature=temperature)
336
+
337
+ return last_id, stat, past_key_values, cnn_cache, buffer, pe_index, hidden_state
338
+
339
+ def _post_decode(self, output, temperature=1.0, top_k=0, top_p=0.0):
340
+ """
341
+ Decoding function, based on the posterior probability output,
342
+ uses top_k, top_p, and temperature parameters for sampling.
343
+
344
+ Parameters:
345
+ - output: torch.Tensor, shaped as (1, 1, D), the next-token logits output by the model (softmax is applied inside).
346
+ - top_k: int, indicates selecting the top k tokens with the highest probability for sampling.
347
+ If 0, no top_k filtering is performed.
348
+ - top_p: float, indicates selecting tokens with cumulative probability not exceeding p for sampling.
349
+ If 0.0, no top_p filtering is performed.
350
+ - temperature: float, represents the sampling temperature parameter.
351
+ The higher the value, the more random the sampling;
352
+ the lower the value, the more deterministic the sampling.
353
+
354
+ Returns:
355
+ - Selected token index.
356
+ """
357
+ output = output.squeeze(0).squeeze(0)
358
+
359
+ # temperature
360
+ if temperature != 1.0:
361
+ output = output / temperature
362
+
363
+ probs = torch.nn.functional.softmax(output, dim=-1)
364
+
365
+ # top_k
366
+ if top_k > 0:
367
+ top_k_probs, top_k_indices = torch.topk(probs, top_k)
368
+ probs = torch.zeros_like(probs).scatter_(0, top_k_indices, top_k_probs)
369
+ probs = probs / probs.sum()
370
+
371
+ # top_p
372
+ if top_p > 0.0:
373
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True)
374
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
375
+ sorted_indices_to_remove = cumulative_probs > top_p
376
+ if sorted_indices_to_remove[0]:
377
+ sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
378
+ sorted_indices_to_remove[0] = 0
379
+
380
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
381
+ probs[indices_to_remove] = 0
382
+ probs = probs / probs.sum()
383
+
384
+ token_index = torch.multinomial(probs, 1)
385
+ return token_index.unsqueeze(0)
386
+
387
+ def _generate_one_step(
388
+ self,
389
+ inputs,
390
+ stat,
391
+ top_p: float = 1.0,
392
+ top_k: int = 0,
393
+ temperature: float = 1.0,
394
+ ):
395
+ """
396
+ Generates the model's next output based on the current input and state.
397
+
398
+ Parameters:
399
+ - inputs: The input tensor containing the model's input data.
400
+ - stat: The current state information used to control the generation process.
401
+ - top_p: The threshold for controlling top-p sampling.
402
+ - top_k: The threshold for controlling top-k sampling.
403
+ - temperature: Controls the randomness of sampling.
404
+
405
+ Returns:
406
+ - last_id: The index of the last generated token.
407
+ - past_key_values: The model's historical key-value pairs, used for cross-step memory.
408
+ - stat: The updated state information.
409
+ - hidden_state: The model's hidden state, used to maintain cross-step contextual information.
410
+ """
411
+ outputs = self.llm_decoder.model(**inputs)
412
+ if stat == 'sl' or stat == 'cl':
413
+ state_logits = self.predictor_head(
414
+ outputs['last_hidden_state'])[0, :]
415
+ prob = F.softmax(state_logits[:, :-1], dim=-1)
416
+ state_prob = prob[-1].clone()
417
+ state_1 = state_prob[1]
418
+ state_2 = state_prob[2]
419
+ print("State 1 prob: {:.4f}, State 2 prob: {:.4f}".format(state_1.item(), state_2.item()))
420
+ if state_2 > 0.5:
421
+ return None, outputs['past_key_values'], 'el', None
422
+ if state_1 > 0.5:
423
+ return None, outputs['past_key_values'], 'ss', None
424
+ return None, outputs['past_key_values'], 'cl', None
425
+
426
+ last_logit = self.llm_decoder.lm_head(outputs['last_hidden_state'][:, -1:, :])
427
+ last_id = self._post_decode(last_logit, temperature=temperature, top_k=top_k, top_p=top_p)
428
+ return_tts_state = outputs['last_hidden_state'][:, -1:, :]
429
+
430
+ if last_id[0][0] == self.tokenizer.eod_id:
431
+ return None, outputs['past_key_values'], 'sl', return_tts_state
432
+ else:
433
+ return last_id, outputs['past_key_values'], 'cs', return_tts_state
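For reference, a standalone sketch of the temperature / top-k / top-p sampling performed by `_post_decode` above (editorial; the top-p shift below is the conventional variant that always keeps the most likely token, otherwise the logic mirrors the method):

import torch
import torch.nn.functional as F

def sample_token(logits, temperature=1.0, top_k=0, top_p=0.0):
    if temperature != 1.0:
        logits = logits / temperature              # sharpen or flatten the distribution
    probs = F.softmax(logits, dim=-1)
    if top_k > 0:                                  # keep only the k most likely tokens
        top_probs, top_idx = torch.topk(probs, top_k)
        probs = torch.zeros_like(probs).scatter_(0, top_idx, top_probs)
        probs = probs / probs.sum()
    if top_p > 0.0:                                # nucleus (top-p) filtering
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        remove = torch.cumsum(sorted_probs, dim=-1) > top_p
        remove[1:] = remove[:-1].clone()
        remove[0] = False                          # always keep the top token
        probs[sorted_idx[remove]] = 0.0
        probs = probs / probs.sum()
    return torch.multinomial(probs, 1)

token = sample_token(torch.randn(32000), temperature=0.8, top_k=20, top_p=0.9)  # fake 32k vocab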
vita/model/vita_tts/decoder/decoder.py ADDED
@@ -0,0 +1,367 @@
1
+ import torch
2
+ import math
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from typing import Dict, List, Tuple, Optional, Union
7
+ from transformers import LlamaConfig
8
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding
9
+ from transformers.cache_utils import DynamicCache
10
+
11
+ from vita.model.vita_tts.encoder.encoder import add_encoder_args
12
+ from vita.model.vita_tts.masks import *
13
+
14
+ IGNORE_ID = -1
15
+
16
+ class CrossEntropyLoss(torch.nn.Module):
17
+ def __init__(self, ignore_index=-1):
18
+ super(CrossEntropyLoss, self).__init__()
19
+ self.criterion = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=ignore_index)
20
+
21
+ def forward(self, logits, target, target_subsampling_factor=1):
22
+ """
23
+ logits: B*T1*D
24
+ target: B*T2
25
+ """
26
+ logits = logits[:, :target.shape[1], :]
27
+ logits = logits.transpose(1, 2)
28
+ target = target.to(torch.long)
29
+ loss = self.criterion(logits, target)
30
+ return loss
31
+
32
+ class LLM2TTSCodecAR(torch.nn.Module):
33
+ """E2E module.
34
+
35
+ Args:
36
+ idim (int): dimension of inputs
37
+ odim (int): dimension of outputs
38
+ args (namespace): argument Namespace containing options
39
+
40
+ """
41
+
42
+ @staticmethod
43
+ def add_arguments(parser):
44
+ """Extend arguments for transducer."""
45
+ group = parser.add_argument_group("TDNN model setting")
46
+
47
+ group.add_argument('--encoder-pre-norm-type',
48
+ default='ln', type=str, help="Type of input norm.")
49
+ group.add_argument('--encoder-drop-rate', default=0.0,
50
+ type=float, help="Dropout rate for output.")
51
+ group.add_argument('--encoder-criterion', default='cross-entropy',
52
+ type=str, help="Criterion for output")
53
+ group.add_argument('--encoder-upsample-rate', default=1, type=int)
54
+ group.add_argument('--kv-cache-prefix-finetune', default=0, type=int)
55
+
56
+ group = add_encoder_args(group)
57
+
58
+ return parser
59
+
60
+ def __init__(self, idim, odim, args):
61
+ """Initialize transducer modules.
62
+
63
+ Args:
64
+ idim (int): dimension of inputs
65
+ odim (int): dimension of outputs
66
+ args (Namespace): argument Namespace containing options
67
+
68
+ """
69
+ super(LLM2TTSCodecAR, self).__init__()
70
+ self.idim = args.idim
71
+ self.odim = args.odim
72
+ self.encoder_pre_norm_type = args.encoder_pre_norm_type
73
+ self.encoder_drop_rate = args.encoder_drop_rate
74
+ self.encoder_criterion = args.encoder_criterion
75
+ self.encoder_upsample_rate = args.encoder_upsample_rate
76
+ self.reporter = None
77
+
78
+ self.vocab_size = self.odim
79
+ config = LlamaConfig(vocab_size=self.vocab_size + 4, hidden_size=args.transformer_attention_dim,
80
+ intermediate_size=args.transformer_linear_units,
81
+ num_hidden_layers=args.transformer_num_blocks,
82
+ num_attention_heads=args.transformer_attention_heads, max_position_embeddings=2048,
83
+ bos_token_id=self.vocab_size + 1,
84
+ eos_token_id=self.vocab_size + 2, pad_token_id=self.vocab_size + 3,
85
+ attention_dropout=args.transformer_dropout_rate)
86
+
87
+ self.embedding = nn.Embedding(self.vocab_size + 4, self.idim, padding_idx=self.vocab_size + 3)
88
+ self.init_pre_nn(config)
89
+
90
+ self.layers = nn.ModuleList(
91
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
92
+ )
93
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
94
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
95
+
96
+ self.dropout = nn.Dropout(p=self.encoder_drop_rate)
97
+ self.out_fnn = nn.Linear(args.encoder_output_dim, self.vocab_size + 4)
98
+
99
+ self.kv_cache_prefix_finetune = args.kv_cache_prefix_finetune
100
+ if self.kv_cache_prefix_finetune:
101
+ self.init_kv_cache_prefix(config)
102
+ self.embedding.eval()
103
+ self.layers.eval()
104
+ self.norm.eval()
105
+ self.rotary_emb.eval()
106
+ self.out_fnn.eval()
107
+ for (name, param) in self.embedding.named_parameters():
108
+ param.requires_grad = False
109
+ for (name, param) in self.layers.named_parameters():
110
+ param.requires_grad = False
111
+ for (name, param) in self.norm.named_parameters():
112
+ param.requires_grad = False
113
+ for (name, param) in self.rotary_emb.named_parameters():
114
+ param.requires_grad = False
115
+ for (name, param) in self.out_fnn.named_parameters():
116
+ param.requires_grad = False
117
+
118
+ if self.encoder_criterion == 'ce':
119
+ self.criterion = CrossEntropyLoss(ignore_index=self.vocab_size + 3)
120
+
121
+ def init_kv_cache_prefix(self, config):
122
+ self.layers_prefix = nn.ModuleList(
123
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
124
+ )
125
+ self.rotary_emb_prefix = LlamaRotaryEmbedding(config=config)
126
+
127
+ def kv_cache_prefix_forward(self, prefix, prefix_lens, past_key_values):
128
+ inputs_embeds = prefix
129
+ past_seen_tokens = 0
130
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + \
131
+ inputs_embeds.shape[1], device=inputs_embeds.device)
132
+ position_ids = cache_position.unsqueeze(0)
133
+ hidden_states = inputs_embeds
134
+ position_embeddings = self.rotary_emb_prefix(hidden_states, position_ids)
135
+ next_decoder_cache = None
136
+ batch_size, max_len, _ = prefix.size()
137
+ input_mask = torch.zeros(batch_size, max_len, max_len, dtype=torch.bool, device=prefix.device)
138
+ for i in range(batch_size):
139
+ input_mask[i, :prefix_lens[i], :prefix_lens[i]] = True
140
+ attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
141
+ for decoder_layer in self.layers_prefix:
142
+ layer_outputs = decoder_layer(
143
+ hidden_states,
144
+ attention_mask=attention_mask,
145
+ position_ids=position_ids,
146
+ past_key_value=past_key_values,
147
+ output_attentions=False,
148
+ use_cache=True,
149
+ cache_position=None,
150
+ position_embeddings=position_embeddings,
151
+ )
152
+ hidden_states = layer_outputs[0]
153
+ next_decoder_cache = layer_outputs[1]
154
+ past_key_values = next_decoder_cache
155
+
156
+ def init_pre_nn(self, config):
157
+ self.layers_pre_nn = nn.ModuleList(
158
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers // 2)]
159
+ )
160
+ self.rotary_emb_pre_nn = LlamaRotaryEmbedding(config=config)
161
+
162
+ def pre_nn_forward(self, hidden, hidden_lens):
163
+ inputs_embeds = hidden
164
+ past_seen_tokens = 0
165
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + \
166
+ inputs_embeds.shape[1], device=inputs_embeds.device)
167
+ position_ids = cache_position.unsqueeze(0)
168
+ hidden_states = inputs_embeds
169
+ position_embeddings = self.rotary_emb_pre_nn(hidden_states, position_ids)
170
+ next_decoder_cache = None
171
+ batch_size, max_len, _ = hidden.size()
172
+ input_mask = torch.zeros(batch_size, max_len, max_len, dtype=torch.bool, device=hidden.device)
173
+ for i in range(batch_size):
174
+ input_mask[i, :hidden_lens[i], :hidden_lens[i]] = True
175
+ attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
176
+ for decoder_layer in self.layers_pre_nn:
177
+ layer_outputs = decoder_layer(
178
+ hidden_states,
179
+ attention_mask=attention_mask,
180
+ position_ids=position_ids,
181
+ past_key_value=None,
182
+ output_attentions=False,
183
+ use_cache=False,
184
+ cache_position=None,
185
+ position_embeddings=position_embeddings,
186
+ )
187
+ hidden_states = layer_outputs[0]
188
+ return hidden_states
189
+
190
+ def forward(self, batch):
191
+ llm_hidden = batch['x']
192
+ llm_hidden_lens = batch['x_lens']
193
+ y = batch['y']
194
+ y[y == IGNORE_ID] = self.vocab_size + 3
195
+ y_lens = batch['y_lens']
196
+ past_key_values = DynamicCache.from_legacy_cache(None)
197
+
198
+ if self.kv_cache_prefix_finetune:
199
+ self.kv_cache_prefix_forward(batch['x_prefix'], batch['x_prefix_lens'], past_key_values)
200
+
201
+ # text_ids: (batch_size, max_len)
202
+ batch_size, max_len = y.size()
203
+
204
+ # Create bos, sos and eos tokens
205
+ bos_token = torch.full((batch_size, 1), self.vocab_size, dtype=torch.long, device=y.device)
206
+ sos_token = torch.full((batch_size, 1), self.vocab_size + 1, dtype=torch.long, device=y.device)
207
+ eos_token = torch.full((batch_size, 1), self.vocab_size + 2, dtype=torch.long, device=y.device)
208
+ padding_token = torch.full((batch_size, 1), self.vocab_size + 3, dtype=torch.long, device=y.device)
209
+
210
+ # Pass through pre_nn
211
+ llm_hidden = self.pre_nn_forward(llm_hidden, llm_hidden_lens)
212
+
213
+ # Concat bos embedding
214
+ bos_emb = self.embedding(bos_token)
215
+ llm_hidden = torch.cat([bos_emb, llm_hidden], dim=1)
216
+ llm_hidden_lens = llm_hidden_lens + 1
217
+
218
+ # Create input x with sos token at the beginning
219
+ x = torch.cat([sos_token, y], dim=1) # (batch_size, max_len + 1)
220
+
221
+ # Create output y with eos token at the end
222
+ y = torch.cat([y, padding_token], dim=1)
223
+ eos_positions = torch.arange(max_len + 1, device=y.device).expand(batch_size, max_len + 1) \
224
+ == y_lens.unsqueeze(1)
225
+ y = y.masked_scatter(eos_positions, eos_token.expand_as(y)[eos_positions])
226
+
227
+ # Embed the input sequence
228
+ x_emb = self.embedding(x) # (batch_size, max_len + 1, d_model)
229
+
230
+ # compute masks
231
+ if self.kv_cache_prefix_finetune:
232
+ x_prefix = batch['x_prefix']
233
+ x_prefix_lens = batch['x_prefix_lens']
234
+ input_lens = llm_hidden.size(1) + max_len + 1
235
+ input_mask = torch.zeros(batch_size, input_lens, x_prefix.size(1) + input_lens, \
236
+ dtype=torch.bool, device=x_emb.device)
237
+ for i in range(batch_size):
238
+ input_mask[i, :llm_hidden_lens[i], :x_prefix_lens[i]] = True
239
+ input_mask[i, :llm_hidden_lens[i], x_prefix.size(1): x_prefix.size(1) + llm_hidden_lens[i]] = True
240
+ input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, :x_prefix_lens[i]] = True
241
+ input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, \
242
+ x_prefix.size(1): x_prefix.size(1) + llm_hidden_lens[i]] = True
243
+ input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, \
244
+ x_prefix.size(1) + llm_hidden.size(1): x_prefix.size(1) + \
245
+ llm_hidden.size(1) + y_lens[i] + 1] \
246
+ = subsequent_mask(y_lens[i] + 1, x_emb.device)
247
+ else:
248
+ input_lens = llm_hidden.size(1) + max_len + 1
249
+ input_mask = torch.zeros(batch_size, input_lens, input_lens, dtype=torch.bool, device=x_emb.device)
250
+ for i in range(batch_size):
251
+ input_mask[i, :llm_hidden_lens[i], :llm_hidden_lens[i]] = True
252
+ input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, :llm_hidden_lens[i]] = True
253
+ input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, \
254
+ llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1] \
255
+ = subsequent_mask(y_lens[i] + 1, x_emb.device)
256
+
257
+ # Pass through the transformer
258
+ inputs_embeds = torch.cat([llm_hidden, x_emb], 1)
259
+ llm_hidden = self.dropout(llm_hidden)
260
+ past_seen_tokens = 0
261
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], \
262
+ device=inputs_embeds.device)
263
+ position_ids = cache_position.unsqueeze(0)
264
+ hidden_states = inputs_embeds
265
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
266
+ attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
267
+ for decoder_layer in self.layers:
268
+ layer_outputs = decoder_layer(
269
+ hidden_states,
270
+ attention_mask=attention_mask,
271
+ position_ids=position_ids,
272
+ past_key_value=past_key_values,
273
+ output_attentions=False,
274
+ use_cache=True,
275
+ cache_position=None,
276
+ position_embeddings=position_embeddings,
277
+ )
278
+ hidden_states = layer_outputs[0]
279
+ hidden_states = self.norm(hidden_states)
280
+
281
+ encoder_out = hidden_states[:, llm_hidden.size(1):]
282
+
283
+ # Project to vocabulary size
284
+ logits = self.out_fnn(encoder_out)
285
+
286
+ if self.encoder_criterion == 'ce':
287
+ loss = self.criterion(logits, y)
288
+
289
+ if self.training:
290
+ self.reporter.log_loss('loss', float(loss))
291
+
292
+ return loss
293
+
294
+ def transformer_infer(self, inputs_embeds, cache_position, past_key_values):
295
+ position_ids = cache_position.unsqueeze(0)
296
+ hidden_states = inputs_embeds
297
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
298
+ next_decoder_cache = None
299
+ for decoder_layer in self.layers:
300
+ layer_outputs = decoder_layer(
301
+ hidden_states,
302
+ attention_mask=None,
303
+ position_ids=position_ids,
304
+ past_key_value=past_key_values,
305
+ output_attentions=False,
306
+ use_cache=True,
307
+ cache_position=None,
308
+ position_embeddings=position_embeddings,
309
+ )
310
+ hidden_states = layer_outputs[0]
311
+ next_decoder_cache = layer_outputs[1]
312
+ return hidden_states
313
+
314
+ def infer(self, hidden, top_k, prefix, penalty_window_size, penalty, max_tokens=1000):
315
+ # Pass through pre_nn
316
+ hidden = self.pre_nn_forward(hidden, [hidden.size(1)])
317
+ # Concat bos embedding
318
+ bos_emb = self.embedding(torch.full((1, 1), self.vocab_size, dtype=torch.long, device=hidden.device))
319
+ hidden = torch.cat([bos_emb, hidden], dim=1)
320
+ # init past key values
321
+ past_key_values = DynamicCache.from_legacy_cache(None)
322
+ # Pass through the prefix nar decoder
323
+ if prefix is not None and self.kv_cache_prefix_finetune:
324
+ self.kv_cache_prefix_forward(prefix, [prefix.size(1)], past_key_values)
325
+ inputs_embeds = hidden
326
+ past_seen_tokens = 0
327
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], \
328
+ device=inputs_embeds.device)
329
+ hidden_states = self.transformer_infer(inputs_embeds, cache_position, past_key_values)
330
+
331
+ # init generated tokens
332
+ cur_token = torch.full((1, 1), self.vocab_size + 1, dtype=torch.long, device=hidden.device)
333
+ generated_tokens = torch.full((1, 1), self.vocab_size + 1, dtype=torch.long, device=hidden.device)
334
+ # generate tokens
335
+ for i in range(max_tokens):
336
+ inputs_embeds = self.embedding(cur_token)
337
+ past_seen_tokens = past_key_values.get_seq_length()
338
+ if prefix is not None:
339
+ past_seen_tokens = past_seen_tokens - prefix.size(1)
340
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], \
341
+ device=inputs_embeds.device)
342
+ hidden_states = self.transformer_infer(inputs_embeds, cache_position, past_key_values)
343
+ hidden_states = self.norm(hidden_states)
344
+
345
+ # Project to vocabulary size
346
+ logits = self.out_fnn(hidden_states)
347
+
348
+ # apply penalty
349
+ if penalty_window_size > 0:
350
+ for token in set(generated_tokens[0][-penalty_window_size:]):
351
+ logits[:, :, token] /= penalty
352
+
353
+ # top k sampling
354
+ output = logits.squeeze(0).squeeze(0)
355
+ probs = torch.nn.functional.softmax(output, dim=-1)
356
+ top_k_probs, top_k_indices = torch.topk(probs, top_k)
357
+ probs = torch.zeros_like(probs).scatter_(0, top_k_indices, top_k_probs)
358
+ probs = probs / probs.sum()
359
+ next_token_id = torch.multinomial(probs, 1).unsqueeze(0)
360
+
361
+ generated_tokens = torch.cat([generated_tokens, next_token_id], dim=-1)
362
+ cur_token = next_token_id
363
+
364
+ # eos
365
+ if next_token_id == self.vocab_size + 2:
366
+ break
367
+ yield next_token_id
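The index arithmetic in `forward` above builds a combined prefix-plus-causal attention layout; a minimal sketch for one unpadded example (editorial, lengths hypothetical) makes the intent explicit: codec tokens attend to the whole LLM-hidden prefix and causally to themselves, and the boolean mask is turned into the additive bias expected by the Llama decoder layers.

import torch

prefix_len, y_len = 4, 3                       # LLM-hidden frames (incl. bos), codec tokens
total = prefix_len + y_len + 1                 # +1 for the sos token prepended to y
mask = torch.zeros(total, total, dtype=torch.bool)
mask[:prefix_len, :prefix_len] = True          # prefix attends to the whole prefix
mask[prefix_len:, :prefix_len] = True          # codec tokens see the prefix
mask[prefix_len:, prefix_len:] = torch.tril(   # causal self-attention over codec tokens
    torch.ones(y_len + 1, y_len + 1, dtype=torch.bool))
attn_bias = ~mask.unsqueeze(0).unsqueeze(1) * torch.finfo(torch.float32).min  # (1, 1, T, T)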
vita/model/vita_tts/decoder/llm2tts.py ADDED
@@ -0,0 +1,161 @@
1
+ import os
2
+ import sys
3
+ import copy
4
+ import json
5
+ import torch
6
+ import random
7
+ import argparse
8
+ import subprocess
9
+ import numpy as np
10
+ import soundfile as sf
11
+ import subprocess
12
+ import concurrent.futures
13
+
14
+ from vita.model.vita_tts.decoder.decoder import LLM2TTSCodecAR
15
+ from vita.model.vita_tts.decoder.ticodec.vqvae_tester import VqvaeTester
16
+
17
+ class llm2TTS():
18
+ def __init__(self, model_path):
19
+ self.model = self.get_model(model_path).cuda().to(
20
+ torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
21
+ )
22
+ self.infer = self.model.infer
23
+
24
+ self.codec_model = VqvaeTester(config_path=model_path + "/codec/model.json",
25
+ model_path=model_path + "/codec/final.pt",
26
+ sample_rate=24000)
27
+ self.codec_model = self.codec_model.cuda()
28
+ self.codec_model.vqvae.generator.remove_weight_norm()
29
+ self.codec_model.vqvae.encoder.remove_weight_norm()
30
+ self.codec_model.eval()
31
+
32
+ def get_model_conf(self, model_path):
33
+ model_conf = model_path + "/decoder/model.json"
34
+ with open(model_conf, "rb") as f:
35
+ print('reading a config file from ' + model_conf)
36
+ confs = json.load(f)
37
+ # for asr, tts, mt
38
+ idim, odim, args = confs
39
+ return argparse.Namespace(**args)
40
+
41
+ def get_model(self, model_path):
42
+ args_load = self.get_model_conf(model_path)
43
+ args_load = vars(args_load)
44
+ args = argparse.Namespace(**args_load)
45
+ odim = args.odim
46
+ idim = args.idim
47
+ model = LLM2TTSCodecAR(idim, odim, args)
48
+
49
+ # Resume from a snapshot
50
+ snapshot_dict = torch.load(model_path + "/decoder/final.pt", map_location=lambda storage, loc: storage)
51
+ if 'model' in snapshot_dict.keys():
52
+ resume_model_dict = snapshot_dict['model']
53
+ else:
54
+ resume_model_dict = snapshot_dict
55
+
56
+ model_dict = model.state_dict()
57
+ for key in model_dict.keys():
58
+ if key in resume_model_dict.keys():
59
+ if model_dict[key].shape == resume_model_dict[key].shape:
60
+ model_dict[key] = resume_model_dict[key]
61
+ else:
62
+ print('Key {} has different shape, {} VS {}'.format(key, model_dict[key].shape,
63
+ resume_model_dict[key].shape))
64
+ else:
65
+ print('Key {} is not in the resume model'.format(key))
66
+ model.load_state_dict(model_dict)
67
+ model.eval()
68
+ return model
69
+
70
+ def find_min_sum_index(self, buffer, syn, N, threshold):
71
+ """
72
+ Find the index with the minimum sum of a sliding window in the given audio segment
73
+ and perform operations based on this index.
74
+
75
+ Parameters:
76
+ - buffer (torch.Tensor): The buffer containing previously processed audio segments.
77
+ - syn (torch.Tensor): The current audio segment to be processed.
78
+ - N (int): The size of the sliding window used to calculate the sum.
79
+ - threshold (float): Threshold value to determine whether to concatenate buffer and current segment or not.
80
+
81
+ Returns:
82
+ - tuple: A tuple containing the updated buffer and the processed audio segment.
83
+
84
+ """
85
+ arr = syn[0, 0, :]
86
+ L = len(arr)
87
+ mid = L // 2
88
+
89
+ kernel = torch.ones(N).to(arr.device)
90
+ window_sums = torch.nn.functional.conv1d(arr.abs().view(1, 1, -1), kernel.view(1, 1, -1), padding=0).squeeze()
91
+
92
+ start_index = mid - (N // 2)
93
+ min_sum, min_index = torch.min(window_sums[start_index:], dim=0)
94
+
95
+ # get the start and end index of the window
96
+ start_index = max(0, min_index.item() + start_index)
97
+ end_index = min(L, min_index.item() + N + start_index)
98
+
99
+ # calculate the real min_sum and min_index
100
+ min_sum_real, min_index_real = torch.min(arr[start_index: end_index].abs(), dim=0)
101
+ min_index = min_index_real.item() + start_index
102
+
103
+ min_sum = min_sum / N
104
+ syn_clone = syn.clone()
105
+
106
+ if min_sum < threshold:
107
+ syn = torch.cat([buffer.clone(), syn[:, :, :min_index]], dim=-1)
108
+ buffer = syn_clone[:, :, min_index:]
109
+ else:
110
+ buffer = torch.cat([buffer, syn_clone], dim=-1)
111
+ syn = None
112
+ return buffer, syn
113
+
114
+ def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10,
115
+ penalty_window_size=-1, penalty=1.1, N=2401, seg_threshold=0.01):
116
+ """
117
+ Run the speech decoder process.
118
+
119
+ Parameters:
120
+ - hidden (torch.Tensor): The hidden states output by the language model that condition the TTS decoder.
121
+ - top_k (int): The number of top-k tokens to consider during inference.
122
+ - prefix (torch.Tensor, optional): Prefix hidden states for the decoder's kv-cache (overridden to None below).
123
+ - codec_chunk_size (int, default=40): The size of each chunk to process in the codec model.
124
+ - codec_padding_size (int, default=10): The amount of padding to add on each side of the codec chunk.
125
+ - penalty_window_size (int, default=-1): The window size for applying the repetition penalty; -1 disables it.
126
+ - penalty (float, default=1.1): The penalty factor.
127
+
128
+ Yields:
129
+ - torch.Tensor: Intermediate audio segments generated by the codec model.
130
+
131
+ """
132
+ codec_upsample_rate = 600
133
+ left_padding = 0
134
+ right_padding = codec_padding_size
135
+ prefix = None
136
+ buffer = torch.zeros([1, 1, 0]).to(hidden.device)
137
+ with torch.no_grad():
138
+ with torch.autocast(device_type="cuda",
139
+ dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32):
140
+ print("Starting TTS...")
141
+ token = torch.full((1, 0), self.model.vocab_size, dtype=torch.long, device=hidden.device)
142
+ for next_token_id in self.infer(hidden, top_k, prefix, penalty_window_size, penalty):
143
+ token = torch.cat([token, next_token_id], dim=-1)
144
+ if token.size(1) == left_padding + codec_chunk_size + right_padding:
145
+ syn = self.codec_model.vqvae(token.unsqueeze(-1),
146
+ torch.tensor(self.codec_model.vqvae.h.global_tokens,
147
+ device=token.device).unsqueeze(0).unsqueeze(0))
148
+ print("Codec Done")
149
+ syn = syn[:, :, left_padding * codec_upsample_rate: -right_padding * codec_upsample_rate]
150
+ left_padding = codec_padding_size
151
+ token = token[:, -(left_padding + right_padding):]
152
+ buffer, syn = self.find_min_sum_index(buffer, syn, N, seg_threshold)
153
+ if syn is not None:
154
+ yield syn
155
+ if token.size(1) > 0:
156
+ print("Codec Done")
157
+ syn = self.codec_model.vqvae(token.unsqueeze(-1),
158
+ torch.tensor(self.codec_model.vqvae.h.global_tokens,
159
+ device=token.device).unsqueeze(0).unsqueeze(0))
160
+ syn = syn[:, :, left_padding * codec_upsample_rate:]
161
+ yield torch.cat([buffer, syn], dim=-1)
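A standalone sketch of the low-energy stitching idea used by `find_min_sum_index` above (editorial; the helper name and values are illustrative): locate the quietest length-N window in the second half of a chunk, then the single quietest sample inside it, so consecutive codec chunks can be joined without audible clicks.

import torch

def quietest_split_point(wave, window):
    # wave: 1-D waveform tensor; returns a sample index with locally minimal energy
    sums = torch.nn.functional.conv1d(
        wave.abs().view(1, 1, -1),
        torch.ones(1, 1, window, device=wave.device)).squeeze()
    start = wave.numel() // 2 - window // 2            # only search the second half
    offset = int(torch.argmin(sums[start:])) + start   # quietest window
    local = int(torch.argmin(wave[offset:offset + window].abs()))
    return offset + local                              # quietest sample inside that window

wave = torch.randn(24000)                              # hypothetical 1 s chunk at 24 kHz
split = quietest_split_point(wave, window=2401)        # N=2401 matches the default above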
vita/model/vita_tts/decoder/ticodec/models.py ADDED
@@ -0,0 +1,716 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import AvgPool1d
5
+ from torch.nn import Conv1d
6
+ from torch.nn import Conv2d
7
+ from torch.nn import ConvTranspose1d
8
+ from torch.nn.utils import remove_weight_norm
9
+ from torch.nn.utils import spectral_norm
10
+ from torch.nn.utils import weight_norm
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size * dilation - dilation) / 2)
16
+
17
+ def init_weights(m, mean=0.0, std=0.01):
18
+ classname = m.__class__.__name__
19
+ if classname.find("Conv") != -1:
20
+ m.weight.data.normal_(mean, std)
21
+
22
+ class GlobalTokenEncoder(nn.Module):
23
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size=3, stride=1):
24
+ super().__init__()
25
+ self.pad = (kernel_size - stride) // 2
26
+ self.conv = nn.Sequential(
27
+ nn.Conv1d(in_channels, hidden_channels, kernel_size, stride, self.pad, bias=False),
28
+ nn.LeakyReLU(LRELU_SLOPE),
29
+ nn.Conv1d(hidden_channels, hidden_channels, kernel_size, stride, self.pad, bias=False),
30
+ nn.LeakyReLU(LRELU_SLOPE),
31
+ nn.Conv1d(hidden_channels, out_channels, kernel_size, stride, self.pad, bias=False),
32
+ nn.LeakyReLU(LRELU_SLOPE),
33
+ )
34
+ self.fn = nn.Sequential(
35
+ # # 2 layers
36
+ # nn.Linear(out_channels, hidden_channels),
37
+ # nn.LeakyReLU(LRELU_SLOPE),
38
+ # nn.Linear(hidden_channels, out_channels),
39
+ # nn.LeakyReLU(LRELU_SLOPE),
40
+ # 1 layer
41
+ nn.Linear(out_channels, out_channels),
42
+ nn.LeakyReLU(LRELU_SLOPE),
43
+ nn.BatchNorm1d(out_channels),
44
+ )
45
+ def forward(self, x):
46
+ """
47
+ x --- [B, in_channels, T]
48
+ out -- [B, out_channels]
49
+ """
50
+ # x_mask = torch.unsqueeze(sequence_mask(
51
+ # x_lengths, x.size(2)), 1).to(x.dtype)
52
+ # x = self.conv(x) * x_mask
53
+ x = self.conv(x)
54
+ # x = torch.sum(x, dim=2) / torch.sum(x_mask, dim=2) # [B, out_channels]
55
+ x = torch.mean(x, dim=2) # [B, out_channels]
56
+ x = self.fn(x)
57
+ return x
58
+
59
+ class ResBlock1(torch.nn.Module):
60
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
61
+ super(ResBlock1, self).__init__()
62
+ self.h = h
63
+ self.convs1 = nn.ModuleList([
64
+ weight_norm(
65
+ Conv1d(
66
+ channels,
67
+ channels,
68
+ kernel_size,
69
+ 1,
70
+ dilation=dilation[0],
71
+ padding=get_padding(kernel_size, dilation[0]))),
72
+ weight_norm(
73
+ Conv1d(
74
+ channels,
75
+ channels,
76
+ kernel_size,
77
+ 1,
78
+ dilation=dilation[1],
79
+ padding=get_padding(kernel_size, dilation[1]))),
80
+ weight_norm(
81
+ Conv1d(
82
+ channels,
83
+ channels,
84
+ kernel_size,
85
+ 1,
86
+ dilation=dilation[2],
87
+ padding=get_padding(kernel_size, dilation[2])))
88
+ ])
89
+ self.convs1.apply(init_weights)
90
+
91
+ self.convs2 = nn.ModuleList([
92
+ weight_norm(
93
+ Conv1d(
94
+ channels,
95
+ channels,
96
+ kernel_size,
97
+ 1,
98
+ dilation=1,
99
+ padding=get_padding(kernel_size, 1))), weight_norm(
100
+ Conv1d(
101
+ channels,
102
+ channels,
103
+ kernel_size,
104
+ 1,
105
+ dilation=1,
106
+ padding=get_padding(kernel_size, 1))), weight_norm(
107
+ Conv1d(
108
+ channels,
109
+ channels,
110
+ kernel_size,
111
+ 1,
112
+ dilation=1,
113
+ padding=get_padding(kernel_size, 1)))
114
+ ])
115
+ self.convs2.apply(init_weights)
116
+
117
+ def forward(self, x):
118
+ for c1, c2 in zip(self.convs1, self.convs2):
119
+ xt = F.leaky_relu(x, LRELU_SLOPE)
120
+ xt = c1(xt)
121
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
122
+ xt = c2(xt)
123
+ x = xt + x
124
+ return x
125
+
126
+ def remove_weight_norm(self):
127
+ for l in self.convs1:
128
+ remove_weight_norm(l)
129
+ for l in self.convs2:
130
+ remove_weight_norm(l)
131
+
132
+
133
+ class ResBlock2(torch.nn.Module):
134
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
135
+ super(ResBlock2, self).__init__()
136
+ self.h = h
137
+ self.convs = nn.ModuleList([
138
+ weight_norm(
139
+ Conv1d(
140
+ channels,
141
+ channels,
142
+ kernel_size,
143
+ 1,
144
+ dilation=dilation[0],
145
+ padding=get_padding(kernel_size, dilation[0]))),
146
+ weight_norm(
147
+ Conv1d(
148
+ channels,
149
+ channels,
150
+ kernel_size,
151
+ 1,
152
+ dilation=dilation[1],
153
+ padding=get_padding(kernel_size, dilation[1])))
154
+ ])
155
+ self.convs.apply(init_weights)
156
+
157
+ def forward(self, x):
158
+ for c in self.convs:
159
+ xt = F.leaky_relu(x, LRELU_SLOPE)
160
+ xt = c(xt)
161
+ x = xt + x
162
+ return x
163
+
164
+ def remove_weight_norm(self):
165
+ for l in self.convs:
166
+ remove_weight_norm(l)
167
+
168
+
169
+ class Generator(torch.nn.Module):
170
+ def __init__(self, h):
171
+ """
172
+ Initializes the Generator module.
173
+
174
+ Parameters:
175
+ - h (object): Configuration object containing hyperparameters for the generator.
176
+ """
177
+ super(Generator, self).__init__()
178
+ self.h = h
179
+ self.num_kernels = len(h.resblock_kernel_sizes)
180
+ self.num_upsamples = len(h.upsample_rates)
181
+ self.conv_pre = weight_norm(
182
+ Conv1d(512, h.upsample_initial_channel, 7, 1, padding=3))
183
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
184
+
185
+ self.ups = nn.ModuleList()
186
+ for i, (u,
187
+ k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
188
+ self.ups.append(
189
+ weight_norm(
190
+ ConvTranspose1d(
191
+ h.upsample_initial_channel // (2**i),
192
+ h.upsample_initial_channel // (2**(i + 1)),
193
+ k,
194
+ u,
195
+ # padding=(u//2 + u%2),
196
+ padding=(k - u) // 2,
197
+ # output_padding=u%2
198
+ )))
199
+
200
+ self.resblocks = nn.ModuleList()
201
+ for i in range(len(self.ups)):
202
+ ch = h.upsample_initial_channel // (2**(i + 1))
203
+ for j, (k, d) in enumerate(
204
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
205
+ self.resblocks.append(resblock(h, ch, k, d))
206
+
207
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
208
+ self.ups.apply(init_weights)
209
+ self.conv_post.apply(init_weights)
210
+
211
+ def forward(self, x, global_features):
212
+ """
213
+ Forward pass of the Generator module.
214
+
215
+ Parameters:
216
+ - x (torch.Tensor): Input tensor of shape [B, C, T], where B is the batch size,
217
+ C is the number of channels, and T is the sequence length.
218
+ - global_features (torch.Tensor): Global features tensor of shape [B, 128].
219
+
220
+ Returns:
221
+ - torch.Tensor: Output tensor of shape [B, 1, T],
222
+ where B is the batch size, and T is the sequence length.
223
+ """
224
+ x = self.conv_pre(x)
225
+ for i in range(self.num_upsamples):
226
+ x = F.leaky_relu(x, LRELU_SLOPE)
227
+ x = self.ups[i](x)
228
+ xs = None
229
+ for j in range(self.num_kernels):
230
+ if xs is None:
231
+ xs = self.resblocks[i * self.num_kernels + j](x)
232
+ else:
233
+ xs += self.resblocks[i * self.num_kernels + j](x)
234
+ x = xs / self.num_kernels
235
+ # if i == self.num_upsamples//2 - 1:
236
+ if x.shape[-2] == global_features.shape[-1]:
237
+ x += global_features.unsqueeze(-1).repeat(1, 1, x.shape[-1])
238
+ x = F.leaky_relu(x, LRELU_SLOPE)
239
+ x = self.conv_post(x)
240
+ x = torch.tanh(x)
241
+
242
+ return x
243
+
244
+ def remove_weight_norm(self):
245
+ """
246
+ Removes weight normalization from all layers in the Generator module.
247
+ """
248
+ print('Removing weight norm...')
249
+ for l in self.ups:
250
+ remove_weight_norm(l)
251
+ for l in self.resblocks:
252
+ l.remove_weight_norm()
253
+ remove_weight_norm(self.conv_pre)
254
+ remove_weight_norm(self.conv_post)
255
+
256
+
257
+ class DiscriminatorP(torch.nn.Module):
258
+ def __init__(self, period, kernel_size=5, stride=3,
259
+ use_spectral_norm=False):
260
+ super(DiscriminatorP, self).__init__()
261
+ self.period = period
262
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
263
+ self.convs = nn.ModuleList([
264
+ norm_f(
265
+ Conv2d(
266
+ 1,
267
+ 32, (kernel_size, 1), (stride, 1),
268
+ padding=(get_padding(5, 1), 0))),
269
+ norm_f(
270
+ Conv2d(
271
+ 32,
272
+ 128, (kernel_size, 1), (stride, 1),
273
+ padding=(get_padding(5, 1), 0))),
274
+ norm_f(
275
+ Conv2d(
276
+ 128,
277
+ 512, (kernel_size, 1), (stride, 1),
278
+ padding=(get_padding(5, 1), 0))),
279
+ norm_f(
280
+ Conv2d(
281
+ 512,
282
+ 1024, (kernel_size, 1), (stride, 1),
283
+ padding=(get_padding(5, 1), 0))),
284
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
285
+ ])
286
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
287
+
288
+ def forward(self, x):
289
+ fmap = []
290
+
291
+ # 1d to 2d
292
+ b, c, t = x.shape
293
+ if t % self.period != 0: # pad first
294
+ n_pad = self.period - (t % self.period)
295
+ x = F.pad(x, (0, n_pad), "reflect")
296
+ t = t + n_pad
297
+ x = x.view(b, c, t // self.period, self.period)
298
+
299
+ for l in self.convs:
300
+ x = l(x)
301
+ x = F.leaky_relu(x, LRELU_SLOPE)
302
+ fmap.append(x)
303
+ x = self.conv_post(x)
304
+ fmap.append(x)
305
+ x = torch.flatten(x, 1, -1)
306
+
307
+ return x, fmap
308
+
309
+
310
+ class MultiPeriodDiscriminator(torch.nn.Module):
311
+ def __init__(self):
312
+ super(MultiPeriodDiscriminator, self).__init__()
313
+ self.discriminators = nn.ModuleList([
314
+ DiscriminatorP(2),
315
+ DiscriminatorP(3),
316
+ DiscriminatorP(5),
317
+ DiscriminatorP(7),
318
+ DiscriminatorP(11),
319
+ ])
320
+
321
+ def forward(self, y, y_hat):
322
+ y_d_rs = []
323
+ y_d_gs = []
324
+ fmap_rs = []
325
+ fmap_gs = []
326
+ for i, d in enumerate(self.discriminators):
327
+ y_d_r, fmap_r = d(y)
328
+ y_d_g, fmap_g = d(y_hat)
329
+ y_d_rs.append(y_d_r)
330
+ fmap_rs.append(fmap_r)
331
+ y_d_gs.append(y_d_g)
332
+ fmap_gs.append(fmap_g)
333
+
334
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
335
+
336
+
337
+ class DiscriminatorS(torch.nn.Module):
338
+ def __init__(self, use_spectral_norm=False):
339
+ super(DiscriminatorS, self).__init__()
340
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
341
+ self.convs = nn.ModuleList([
342
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
343
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
344
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
345
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
346
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
347
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
348
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
349
+ ])
350
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
351
+
352
+ def forward(self, x):
353
+ fmap = []
354
+ for l in self.convs:
355
+ x = l(x)
356
+ x = F.leaky_relu(x, LRELU_SLOPE)
357
+ fmap.append(x)
358
+ x = self.conv_post(x)
359
+ fmap.append(x)
360
+ x = torch.flatten(x, 1, -1)
361
+
362
+ return x, fmap
363
+
364
+
365
+ class MultiScaleDiscriminator(torch.nn.Module):
366
+ def __init__(self):
367
+ super(MultiScaleDiscriminator, self).__init__()
368
+ self.discriminators = nn.ModuleList([
369
+ DiscriminatorS(use_spectral_norm=True),
370
+ DiscriminatorS(),
371
+ DiscriminatorS(),
372
+ ])
373
+ self.meanpools = nn.ModuleList(
374
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
375
+
376
+ def forward(self, y, y_hat):
377
+ y_d_rs = []
378
+ y_d_gs = []
379
+ fmap_rs = []
380
+ fmap_gs = []
381
+ for i, d in enumerate(self.discriminators):
382
+ if i != 0:
383
+ y = self.meanpools[i - 1](y)
384
+ y_hat = self.meanpools[i - 1](y_hat)
385
+ y_d_r, fmap_r = d(y)
386
+ y_d_g, fmap_g = d(y_hat)
387
+ y_d_rs.append(y_d_r)
388
+ fmap_rs.append(fmap_r)
389
+ y_d_gs.append(y_d_g)
390
+ fmap_gs.append(fmap_g)
391
+
392
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
393
+
394
+
395
+ def feature_loss(fmap_r, fmap_g):
396
+ loss = 0
397
+ for dr, dg in zip(fmap_r, fmap_g):
398
+ for rl, gl in zip(dr, dg):
399
+ loss += torch.mean(torch.abs(rl - gl))
400
+
401
+ return loss * 2
402
+
403
+
404
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
405
+ loss = 0
406
+ r_losses = []
407
+ g_losses = []
408
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
409
+ r_loss = torch.mean((1 - dr)**2)
410
+ g_loss = torch.mean(dg**2)
411
+ loss += (r_loss + g_loss)
412
+ r_losses.append(r_loss.item())
413
+ g_losses.append(g_loss.item())
414
+
415
+ return loss, r_losses, g_losses
416
+
417
+
418
+ def generator_loss(disc_outputs):
419
+ loss = 0
420
+ gen_losses = []
421
+ for dg in disc_outputs:
422
+ l = torch.mean((1 - dg)**2)
423
+ gen_losses.append(l)
424
+ loss += l
425
+
426
+ return loss, gen_losses
427
+
428
+
429
+ class Encoder(torch.nn.Module):
430
+ def __init__(self, h):
431
+ super(Encoder, self).__init__()
432
+ self.h = h
433
+ self.num_kernels = len(h.resblock_kernel_sizes)
434
+ self.num_upsamples = len(h.upsample_rates)
435
+ self.conv_pre = weight_norm(Conv1d(1, 32, 7, 1, padding=3))
436
+ self.normalize = nn.ModuleList()
437
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
438
+
439
+ self.ups = nn.ModuleList()
440
+ for i, (u, k) in enumerate(
441
+ list(
442
+ reversed(
443
+ list(zip(h.upsample_rates, h.upsample_kernel_sizes))))):
444
+ self.ups.append(
445
+ weight_norm(
446
+ Conv1d(
447
+ 32 * (2**i),
448
+ 32 * (2**(i + 1)),
449
+ k,
450
+ u,
451
+ padding=((k - u) // 2)
452
+ # padding=(u//2 + u%2)
453
+ )))
454
+ self.resblocks = nn.ModuleList()
455
+ for i in range(len(self.ups)):
456
+ ch = 32 * (2**(i + 1))
457
+ for j, (k, d) in enumerate(
458
+ zip(
459
+ list(reversed(h.resblock_kernel_sizes)),
460
+ list(reversed(h.resblock_dilation_sizes)))):
461
+ self.resblocks.append(resblock(h, ch, k, d))
462
+ self.normalize.append(
463
+ torch.nn.GroupNorm(ch // 16, ch, eps=1e-6, affine=True))
464
+ self.conv_post = Conv1d(512, 512, 3, 1, padding=1)
465
+ self.ups.apply(init_weights)
466
+ self.conv_post.apply(init_weights)
467
+ self.linear = nn.Sequential(
468
+ nn.Linear(128, 128),
469
+ nn.LeakyReLU(LRELU_SLOPE)
470
+ )
471
+ self.gfc = h.global_feature_conv
472
+ self.GlobalTokenEncoder = GlobalTokenEncoder(self.gfc[0], self.gfc[1], self.gfc[2], self.gfc[3], self.gfc[4])
473
+ self.GlobalTokenEncoder.apply(init_weights)
474
+
475
+ def forward(self, x, xx=None):
476
+ x = self.conv_pre(x)
477
+ global_features = None
478
+ for i in range(self.num_upsamples):
479
+ x = F.leaky_relu(x, LRELU_SLOPE)
480
+ x = self.ups[i](x)
481
+ xs = None
482
+ for j in range(self.num_kernels):
483
+ if xs is None:
484
+ xs = self.resblocks[i * self.num_kernels + j](x)
485
+ xs = self.normalize[i * self.num_kernels + j](xs)
486
+ else:
487
+ xs += self.resblocks[i * self.num_kernels + j](x)
488
+ xs = self.normalize[i * self.num_kernels + j](xs)
489
+ x = xs / self.num_kernels
490
+ if i == self.num_upsamples//2 - 1:
491
+ mid_features = x
492
+ global_features = self.GlobalTokenEncoder(x)
493
+ x = F.leaky_relu(x)
494
+ x = self.conv_post(x)
495
+ if xx is not None:
496
+ xx = self.conv_pre(xx)
497
+ global_features2 = None
498
+ for i in range(self.num_upsamples//2):
499
+ xx = F.leaky_relu(xx, LRELU_SLOPE)
500
+ xx = self.ups[i](xx)
501
+ xxs = None
502
+ for j in range(self.num_kernels):
503
+ if xxs is None:
504
+ xxs = self.resblocks[i * self.num_kernels + j](xx)
505
+ xxs = self.normalize[i * self.num_kernels + j](xxs)
506
+ else:
507
+ xxs += self.resblocks[i * self.num_kernels + j](xx)
508
+ xxs = self.normalize[i * self.num_kernels + j](xxs)
509
+ xx = xxs / self.num_kernels
510
+ mid_features2 = xx
511
+ global_features2 = self.GlobalTokenEncoder(xx)
512
+ global_features2 = global_features2.detach()
513
+ return x, global_features, global_features2
514
+ return x, global_features
515
+
516
+ def remove_weight_norm(self):
517
+ print('Removing weight norm...')
518
+ for l in self.ups:
519
+ remove_weight_norm(l)
520
+ for l in self.resblocks:
521
+ l.remove_weight_norm()
522
+ remove_weight_norm(self.conv_pre)
523
+
524
+
525
+ class Quantizer_module(torch.nn.Module):
526
+ def __init__(self, n_e, e_dim):
527
+ super(Quantizer_module, self).__init__()
528
+ self.embedding = nn.Embedding(n_e, e_dim)
529
+ self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
530
+
531
+ def forward(self, x):
532
+ # compute Euclidean distance
533
+ d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) \
534
+ - 2 * torch.matmul(x, self.embedding.weight.T)
535
+ min_indicies = torch.argmin(d, 1)
536
+ z_q = self.embedding(min_indicies)
537
+ return z_q, min_indicies
538
+
539
+
540
+ class Quantizer(torch.nn.Module):
541
+ def __init__(self, h):
542
+ super(Quantizer, self).__init__()
543
+ assert 512 % h.n_code_groups == 0
544
+ self.quantizer_modules = nn.ModuleList([
545
+ Quantizer_module(h.n_codes, 512 // h.n_code_groups)
546
+ for _ in range(h.n_code_groups)
547
+ ])
548
+ self.residul_layer = h.residul_layer
549
+ if h.residul_layer == 2:
550
+ self.quantizer_modules2 = nn.ModuleList([
551
+ Quantizer_module(h.n_codes, 512 // h.n_code_groups)
552
+ for _ in range(h.n_code_groups)
553
+ ])
554
+ if h.residul_layer == 4:
555
+ self.quantizer_modules2 = nn.ModuleList([
556
+ Quantizer_module(h.n_codes, 512 // h.n_code_groups)
557
+ for _ in range(h.n_code_groups)
558
+ ])
559
+ self.quantizer_modules3 = nn.ModuleList([
560
+ Quantizer_module(h.n_codes, 512 // h.n_code_groups)
561
+ for _ in range(h.n_code_groups)
562
+ ])
563
+ self.quantizer_modules4 = nn.ModuleList([
564
+ Quantizer_module(h.n_codes, 512 // h.n_code_groups)
565
+ for _ in range(h.n_code_groups)
566
+ ])
567
+
568
+ self.quantizer_modules_globaltokens = nn.ModuleList([
569
+ Quantizer_module(h.n_codes, 128//h.global_code_num)
570
+ for _ in range(h.global_code_num)
571
+ ])
572
+ # self.quantizer_modules3 = nn.ModuleList([
573
+ # Quantizer_module(h.n_codes, 128//h.global_code_num)
574
+ # for _ in range(h.global_code_num)
575
+ # ])
576
+ self.h = h
577
+ self.codebook_loss_lambda = self.h.codebook_loss_lambda # e.g., 1
578
+ self.commitment_loss_lambda = self.h.commitment_loss_lambda # e.g., 0.25
579
+ # self.residul_layer = 2
580
+ self.n_code_groups = h.n_code_groups
581
+ self.global_code_num = h.global_code_num
582
+
583
+ def for_one_step(self, xin, idx):
584
+ xin = xin.transpose(1, 2)
585
+ x = xin.reshape(-1, 512)
586
+ x = torch.split(x, 512 // self.h.n_code_groups, dim=-1)
587
+ min_indicies = []
588
+ z_q = []
589
+ if idx == 0:
590
+ for _x, m in zip(x, self.quantizer_modules):
591
+ _z_q, _min_indicies = m(_x)
592
+ z_q.append(_z_q)
593
+ min_indicies.append(_min_indicies) #B * T,
594
+ elif idx == 1:
595
+ for _x, m in zip(x, self.quantizer_modules2):
596
+ _z_q, _min_indicies = m(_x)
597
+ z_q.append(_z_q)
598
+ min_indicies.append(_min_indicies) #B * T,
599
+ elif idx == 2:
600
+ for _x, m in zip(x, self.quantizer_modules3):
601
+ _z_q, _min_indicies = m(_x)
602
+ z_q.append(_z_q)
603
+ min_indicies.append(_min_indicies)
604
+ elif idx == 3:
605
+ for _x, m in zip(x, self.quantizer_modules4):
606
+ _z_q, _min_indicies = m(_x)
607
+ z_q.append(_z_q)
608
+ min_indicies.append(_min_indicies)
609
+ z_q = torch.cat(z_q, -1).reshape(xin.shape)
610
+ # loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
611
+ loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \
612
+ + self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2)
613
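+ # straight-through estimator: the forward pass uses the quantized z_q, while gradients flow back to xin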
+ z_q = xin + (z_q - xin).detach()
614
+ z_q = z_q.transpose(1, 2)
615
+ return z_q, loss, min_indicies
616
+
617
+ def for_one_step_gst(self, xin):
618
+ # xin = xin.transpose(1, 2)
619
+ x = xin.reshape(-1, 128) #B * 1, 128
620
+ x = torch.split(x, 128 // self.global_code_num, dim=-1)
621
+ min_indicies = []
622
+ z_q = []
623
+ for _x, m in zip(x, self.quantizer_modules_globaltokens):
624
+ _z_q, _min_indicies = m(_x)
625
+ z_q.append(_z_q)
626
+ min_indicies.append(_min_indicies)
627
+ # for _x, m in zip(x, self.quantizer_modules3):
628
+ # _z_q, _min_indicies = m(_x)
629
+ # z_q.append(_z_q)
630
+ # min_indicies.append(_min_indicies)
631
+ z_q = torch.cat(z_q, -1).reshape(xin.shape)
632
+ # loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
633
+ loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \
634
+ + self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2)
635
+ z_q = xin + (z_q - xin).detach()
636
+ z_q = z_q.squeeze(1)
637
+ return z_q, loss, min_indicies
638
+
639
+ def forward(self, xin, global_style):
640
+ #B, C, T
641
+ quantized_out = 0.0
642
+ residual = xin
643
+ all_losses = []
644
+ all_indices = []
645
+ for i in range(self.residul_layer):
646
+ quantized, loss, indices = self.for_one_step(residual, i) #
647
+ residual = residual - quantized
648
+ quantized_out = quantized_out + quantized
649
+ all_indices.extend(indices) #
650
+ all_losses.append(loss)
651
+ all_losses = torch.stack(all_losses)
652
+ loss = torch.mean(all_losses)
653
+ global_style_quantized, loss_gst_vq, global_style_tokens= self.for_one_step_gst(global_style)
654
+ loss += loss_gst_vq
655
+ # global_style_quantized = global_style
656
+ # global_style_tokens = global_style
657
+ # global_style_quantized = global_style_quantized.squeeze(1)
658
+ # global_style_tokens = global_style_tokens.squeeze(1)
659
+ return quantized_out, loss, all_indices, global_style_quantized, global_style_tokens
660
+
661
+ def embed(self, x):
662
+ #idx: N, T, 4
663
+ #print('x ', x.shape)
664
+ quantized_out = torch.tensor(0.0, device=x.device)
665
+ x = torch.split(x, 1, 2) # split the last dimension; each slice belongs to one index group
666
+ #print('x.shape ', len(x),x[0].shape)
667
+ for i in range(self.residul_layer):
668
+ ret = []
669
+ if i == 0:
670
+ for j in range(self.n_code_groups):
671
+ q = x[j]
672
+ embed = self.quantizer_modules[j]
673
+ q = embed.embedding(q.squeeze(-1))
674
+ ret.append(q)
675
+ ret = torch.cat(ret, -1)
676
+ #print(ret.shape)
677
+ quantized_out = quantized_out + ret
678
+ elif i == 1:
679
+ for j in range(self.n_code_groups):
680
+ q = x[j + self.n_code_groups]
681
+ embed = self.quantizer_modules2[j]
682
+ q = embed.embedding(q.squeeze(-1))
683
+ ret.append(q)
684
+ ret = torch.cat(ret, -1)
685
+ quantized_out = quantized_out + ret
686
+ elif i == 2:
687
+ for j in range(self.n_code_groups):
688
+ q = x[j + self.n_code_groups * 2]
689
+ embed = self.quantizer_modules3[j]
690
+ q = embed.embedding(q.squeeze(-1))
691
+ ret.append(q)
692
+ ret = torch.cat(ret, -1)
693
+ quantized_out = quantized_out + ret
694
+ elif i == 3:
695
+ for j in range(self.n_code_groups):
696
+ q = x[j + self.n_code_groups * 3]
697
+ embed = self.quantizer_modules4[j]
698
+ q = embed.embedding(q.squeeze(-1))
699
+ ret.append(q)
700
+ ret = torch.cat(ret, -1)
701
+ quantized_out = quantized_out + ret
702
+ return quantized_out.transpose(1, 2) #N, C, T
703
+ def embed_gst(self, x):
704
+ quantized_out = torch.tensor(0.0, device=x.device)
705
+ ret = []
706
+ x = torch.split(x, 1, 2)
707
+ for j in range(self.global_code_num):
708
+ q = x[j]
709
+ embed = self.quantizer_modules_globaltokens[j]
710
+ # embed = self.quantizer_modules3[j]
711
+ q = embed.embedding(q.squeeze(-1))
712
+ ret.append(q)
713
+ ret = torch.cat(ret, -1)
714
+ quantized_out = quantized_out + ret
715
+ return quantized_out.transpose(1, 2)
716
+ # return x
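
The Encoder, Quantizer and Generator above are wired together by the VQVAE wrapper in the next file. As a rough round-trip sketch using the classes defined in this file: the SimpleNamespace below stands in for the codec's JSON config `h`, and every hyperparameter value here is an illustrative assumption, not the shipped TiCodec configuration (eval mode avoids batch-norm issues with a batch of one).

from types import SimpleNamespace
import torch

h = SimpleNamespace(
    resblock='1',
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 5, 4, 2],
    upsample_kernel_sizes=[16, 11, 8, 4],
    upsample_initial_channel=512,
    global_feature_conv=[128, 256, 128, 3, 1],   # assumed [in, hidden, out, kernel, stride]
    n_code_groups=2, n_codes=1024, residul_layer=2, global_code_num=8,
    codebook_loss_lambda=1.0, commitment_loss_lambda=0.25,
)
enc, quant, gen = Encoder(h).eval(), Quantizer(h).eval(), Generator(h).eval()
with torch.no_grad():
    wav = torch.randn(1, 1, 16000)                     # [B, 1, T] dummy waveform
    latent, global_feat = enc(wav)                     # [B, 512, T/320], [B, 128]
    q, vq_loss, codes, g_q, g_tokens = quant(latent, global_feat)
    recon = gen(q, g_q)                                # [B, 1, ~T] reconstructed audio
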
vita/model/vita_tts/decoder/ticodec/vqvae.py ADDED
@@ -0,0 +1,57 @@
1
+ import json
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from vita.model.vita_tts.decoder.ticodec.models import Encoder
7
+ from vita.model.vita_tts.decoder.ticodec.models import Generator
8
+ from vita.model.vita_tts.decoder.ticodec.models import Quantizer
9
+
10
+ class AttrDict(dict):
11
+ def __init__(self, *args, **kwargs):
12
+ super(AttrDict, self).__init__(*args, **kwargs)
13
+ self.__dict__ = self
14
+
15
+ class VQVAE(nn.Module):
16
+ def __init__(self,
17
+ config_path,
18
+ ckpt_path,
19
+ with_encoder=False):
20
+ super(VQVAE, self).__init__()
21
+ ckpt = torch.load(ckpt_path)
22
+ with open(config_path) as f:
23
+ data = f.read()
24
+ json_config = json.loads(data)
25
+ self.h = AttrDict(json_config)
26
+ # self.gst = GST()
27
+ # self.gst = Proposed(n_specs=128, token_num=10, E=128, n_layers=4)
28
+ self.quantizer = Quantizer(self.h)
29
+ self.generator = Generator(self.h)
30
+ self.generator.load_state_dict(ckpt['generator'])
31
+ self.quantizer.load_state_dict(ckpt['quantizer'])
32
+ # self.gst.load_state_dict(ckpt['gst'])
33
+ if with_encoder:
34
+ self.encoder = Encoder(self.h)
35
+ self.encoder.load_state_dict(ckpt['encoder'])
36
+
37
+ def forward(self, x, global_style_token):
38
+ # x is the codebook
39
+ # x.shape (B, T, Nq)
40
+ quant_emb = self.quantizer.embed(x)
41
+ global_style_quantized_emb = self.quantizer.embed_gst(global_style_token).squeeze(-1)
42
+ return self.generator(quant_emb, global_style_quantized_emb)
43
+
44
+ def encode(self, x):
45
+ batch_size = x.size(0)
46
+ if len(x.shape) == 3 and x.shape[-1] == 1:
47
+ x = x.squeeze(-1)
48
+ # print(x.shape)
49
+
50
+ c, global_features = self.encoder(x.unsqueeze(1))
51
+ # mid = mid.transpose(1, 2).unsqueeze(1)
52
+ # global_style = self.gst(mid)
53
+ q, loss_q, local_token, g, global_style_token = self.quantizer(c, global_features)
54
+ local_token = [code.reshape(batch_size, -1) for code in local_token]
55
+ global_style_token = torch.stack(global_style_token, -1).unsqueeze(1)
56
+ # shape: [N, T, 4]
57
+ return torch.stack(local_token, -1), global_style_token
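
For reference, a hedged sketch of driving this wrapper end to end; the config and checkpoint paths below are placeholders for the TiCodec assets distributed with the model.

import torch

codec = VQVAE("ticodec/config.json", "ticodec/model.pth", with_encoder=True).cuda().eval()
with torch.no_grad():
    wav = torch.randn(1, 24000, device="cuda")     # one second of dummy 24 kHz audio
    codes, global_token = codec.encode(wav)        # [B, T', Nq] acoustic codes, [B, 1, n_global]
    recon = codec(codes, global_token)             # [B, 1, ~T] resynthesized waveform
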
vita/model/vita_tts/decoder/ticodec/vqvae_tester.py ADDED
@@ -0,0 +1,37 @@
1
+ import os
2
+
3
+ import librosa
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from vita.model.vita_tts.decoder.ticodec.vqvae import VQVAE
8
+
9
+ class VqvaeTester(nn.Module):
10
+ def __init__(self, config_path, model_path, sample_rate=24000):
11
+ super().__init__()
12
+ self.vqvae = VQVAE(config_path, model_path, with_encoder=True)
13
+ self.sample_rate = sample_rate
14
+
15
+ @torch.no_grad()
16
+ def forward(self, wav_path):
17
+ # mono audio
18
+ # wav.shape (T,), loaded at the model's sample rate
19
+ wav, sr = librosa.load(wav_path, sr=self.sample_rate)
20
+ fid = os.path.basename(wav_path)[:-4]
21
+ wav = torch.tensor(wav).unsqueeze(0)
22
+ wav = wav.cuda()
23
+ # vq_codes is acoustic token
24
+ vq_codes, global_token = self.vqvae.encode(wav)
26
+ syn = self.vqvae(vq_codes, global_token)
27
+ return fid, syn
28
+
29
+ @torch.no_grad()
30
+ def vq(self, wav_path):
31
+ wav, sr = librosa.load(wav_path, sr=self.sample_rate)
32
+ fid = os.path.basename(wav_path)[:-4]
33
+ wav = torch.tensor(wav).unsqueeze(0)
34
+ wav = wav.cuda()
35
+ # vq_codes is acoustic token
36
+ vq_codes, global_token = self.vqvae.encode(wav)
37
+ return fid, vq_codes, global_token
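
A typical offline call pattern for the tester; the wav, config and checkpoint paths are illustrative placeholders.

tester = VqvaeTester("ticodec/config.json", "ticodec/model.pth", sample_rate=24000).cuda().eval()
fid, codes, global_token = tester.vq("sample.wav")   # discrete acoustic + global style tokens
fid, syn = tester("sample.wav")                      # round-trip resynthesis, [B, 1, T]
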
vita/model/vita_tts/encoder/attention.py ADDED
@@ -0,0 +1,459 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import numpy
5
+ import pdb
6
+
7
+ class PositionalEncoding(torch.nn.Module):
8
+ """Positional encoding.
9
+ :param int d_model: embedding dim
10
+ :param float dropout_rate: dropout rate
11
+ :param int max_len: maximum input length
12
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
13
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
14
+ """
15
+ def __init__(self,
16
+ d_model: int,
17
+ dropout_rate: float,
18
+ max_len: int = 1500,
19
+ reverse: bool = False):
20
+ """Construct an PositionalEncoding object."""
21
+ super().__init__()
22
+ self.d_model = d_model
23
+ self.xscale = math.sqrt(self.d_model)
24
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
25
+ self.max_len = max_len
26
+
27
+ self.pe = torch.zeros(self.max_len, self.d_model)
28
+ position = torch.arange(0, self.max_len,
29
+ dtype=torch.float32).unsqueeze(1)
30
+ div_term = torch.exp(
31
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
32
+ -(math.log(10000.0) / self.d_model))
33
+ self.pe[:, 0::2] = torch.sin(position * div_term)
34
+ self.pe[:, 1::2] = torch.cos(position * div_term)
35
+ self.pe = self.pe.unsqueeze(0)
36
+
37
+ def forward(self,
38
+ x: torch.Tensor,
39
+ offset: int = 0):
40
+ """Add positional encoding.
41
+ Args:
42
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
43
+ offset (int): position offset
44
+ Returns:
45
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
46
+ torch.Tensor: for compatibility to RelPositionalEncoding
47
+ """
48
+ assert offset + x.size(1) < self.max_len
49
+ self.pe = self.pe.to(x.device)
50
+ pos_emb = self.pe[:, offset:offset + x.size(1)]
51
+ x = x * self.xscale + pos_emb
52
+ return self.dropout(x), self.dropout(pos_emb)
53
+
54
+ def position_encoding(self, offset: int, size: int):
55
+ """ For getting encoding in a streaming fashion
56
+ Attention!!!!!
57
+ we apply dropout only once at the whole utterance level in a non-
58
+ streaming way, but will call this function several times with
59
+ increasing input size in a streaming scenario, so the dropout will
60
+ be applied several times.
61
+ Args:
62
+ offset (int): start offset
63
+ size (int): required size of position encoding
64
+ Returns:
65
+ torch.Tensor: Corresponding encoding
66
+ """
67
+ assert offset + size < self.max_len
68
+ return self.dropout(self.pe[:, offset:offset + size])
69
+
70
+ class RelPositionalEncoding(PositionalEncoding):
71
+ """Relative positional encoding module.
72
+ See : Appendix B in https://arxiv.org/abs/1901.02860
73
+ Args:
74
+ d_model (int): Embedding dimension.
75
+ dropout_rate (float): Dropout rate.
76
+ max_len (int): Maximum input length.
77
+ """
78
+ def __init__(self, d_model: int, dropout_rate: float, chunk_size: int, left_chunks: int, max_len: int = 5000):
79
+ """Initialize class."""
80
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
81
+ self.chunk_size = chunk_size
82
+ self.left_chunks = left_chunks
83
+ self.full_chunk_size = (self.left_chunks + 1) * self.chunk_size
84
+
85
+ self.div_term = torch.exp(
86
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
87
+ -(math.log(10000.0) / self.d_model))
88
+ self.max_len = self.chunk_size * (max_len // self.chunk_size) - self.full_chunk_size
89
+
90
+ def forward(self,
91
+ x: torch.Tensor,
92
+ offset: int = 0):
93
+ """Compute positional encoding.
94
+ Args:
95
+ x (torch.Tensor): Input tensor (batch, time, `*`).
96
+ Returns:
97
+ torch.Tensor: Encoded tensor (batch, time, `*`).
98
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
99
+ """
100
+ self.pe = self.pe.to(x.device)
101
+ x = x * self.xscale
102
+ pos_emb = self.pe[:, offset:offset + x.size(1)]
103
+ return self.dropout(x), self.dropout(pos_emb)
104
+
105
+ def infer(self, xs, pe_index, pe_length):
106
+ # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
107
+ pe_index = pe_index % self.max_len
108
+ xs = xs * self.xscale
109
+
110
+ # pe = torch.zeros(self.full_chunk_size, self.d_model)
111
+ pe = torch.zeros(pe_length, self.d_model)
112
+ position = torch.arange(max(0, pe_index-self.full_chunk_size),
113
+ max(0, pe_index-self.full_chunk_size)
114
+ + pe_length, # self.full_chunk_size,
115
+ dtype=torch.float32).unsqueeze(1)
116
+ pe[:, 0::2] = torch.sin(position * self.div_term)
117
+ pe[:, 1::2] = torch.cos(position * self.div_term)
118
+ pos_emb = pe.unsqueeze(0)
119
+
120
+ pe_index = pe_index + self.chunk_size
121
+ return xs, pos_emb, pe_index
122
+
123
+ class PositionwiseFeedForward(torch.nn.Module):
124
+ """Positionwise feed forward layer.
125
+ :param int idim: input dimension
126
+ :param int hidden_units: number of hidden units
127
+ :param float dropout_rate: dropout rate
128
+ """
129
+
130
+ def __init__(self, idim, hidden_units, dropout_rate):
131
+ """Construct an PositionwiseFeedForward object."""
132
+ super(PositionwiseFeedForward, self).__init__()
133
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
134
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
135
+ self.dropout = torch.nn.Dropout(dropout_rate)
136
+
137
+ def forward(self, x):
138
+ """Forward funciton."""
139
+ return self.w_2(self.dropout(torch.relu(self.w_1(x))))
140
+
141
+ def infer(self, xs, buffer, buffer_index, buffer_out):
142
+ # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
143
+ return self.w_2(torch.relu(self.w_1(xs))), buffer, buffer_index, buffer_out
144
+
145
+ class MultiLayeredConv1d(torch.nn.Module):
146
+ """Multi-layered conv1d for Transformer block.
147
+
148
+ This is a module of multi-leyered conv1d designed
149
+ to replace positionwise feed-forward network
150
+ in Transformer block, which is introduced in
151
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
152
+
153
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
154
+ https://arxiv.org/pdf/1905.09263.pdf
155
+
156
+ """
157
+
158
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
159
+ """Initialize MultiLayeredConv1d module.
160
+
161
+ Args:
162
+ in_chans (int): Number of input channels.
163
+ hidden_chans (int): Number of hidden channels.
164
+ kernel_size (int): Kernel size of conv1d.
165
+ dropout_rate (float): Dropout rate.
166
+
167
+ """
168
+ super(MultiLayeredConv1d, self).__init__()
169
+ self.w_1 = torch.nn.Conv1d(
170
+ in_chans,
171
+ hidden_chans,
172
+ kernel_size,
173
+ stride=1,
174
+ padding=(kernel_size - 1) // 2,
175
+ )
176
+ self.w_2 = torch.nn.Conv1d(
177
+ hidden_chans,
178
+ in_chans,
179
+ kernel_size,
180
+ stride=1,
181
+ padding=(kernel_size - 1) // 2,
182
+ )
183
+ self.dropout = torch.nn.Dropout(dropout_rate)
184
+
185
+ def forward(self, x):
186
+ """Calculate forward propagation.
187
+
188
+ Args:
189
+ x (Tensor): Batch of input tensors (B, ..., in_chans).
190
+
191
+ Returns:
192
+ Tensor: Batch of output tensors (B, ..., in_chans).
193
+
194
+ """
195
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
196
+ return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
197
+
198
+ class Conv1dLinear(torch.nn.Module):
199
+ """Conv1D + Linear for Transformer block.
200
+
201
+ A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
202
+
203
+ """
204
+
205
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
206
+ """Initialize Conv1dLinear module.
207
+
208
+ Args:
209
+ in_chans (int): Number of input channels.
210
+ hidden_chans (int): Number of hidden channels.
211
+ kernel_size (int): Kernel size of conv1d.
212
+ dropout_rate (float): Dropout rate.
213
+
214
+ """
215
+ super(Conv1dLinear, self).__init__()
216
+ self.lorder = (kernel_size - 1)
217
+ self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
218
+ self.w_1 = torch.nn.Sequential(
219
+ torch.nn.Conv1d(
220
+ in_chans,
221
+ in_chans,
222
+ kernel_size,
223
+ stride=1,
224
+ padding=0,
225
+ groups=in_chans
226
+ ),
227
+ torch.nn.Conv1d(
228
+ in_chans,
229
+ hidden_chans,
230
+ 1,
231
+ padding=0
232
+ )
233
+ )
234
+ self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
235
+ self.dropout = torch.nn.Dropout(dropout_rate)
236
+ self.in_chans = in_chans
237
+
238
+ # cnn_buffer = 1, in_chans, self.lorder
239
+ self.buffer_size = 1 * self.in_chans * self.lorder
240
+
241
+ def forward(self, x):
242
+ """Calculate forward propagation.
243
+
244
+ Args:
245
+ x (Tensor): Batch of input tensors (B, ..., in_chans).
246
+
247
+ Returns:
248
+ Tensor: Batch of output tensors (B, ..., in_chans).
249
+
250
+ """
251
+ x = torch.relu(self.w_1(self.left_padding(x.transpose(-1, 1)))).transpose(-1, 1)
252
+ return self.w_2(self.dropout(x))
253
+
254
+ def infer(self, x, buffer, buffer_index, buffer_out):
255
+ # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
256
+ x = x.transpose(-1, 1)
257
+
258
+ cnn_buffer = buffer[buffer_index: buffer_index + self.buffer_size].reshape([1, self.in_chans, self.lorder])
259
+ x = torch.cat([cnn_buffer, x], dim=2)
260
+ buffer_out.append(x[:, :, -self.lorder:].reshape(-1))
261
+ buffer_index = buffer_index + self.buffer_size
262
+
263
+ x = self.w_1(x)
264
+ x = torch.relu(x).transpose(-1, 1)
265
+ x = self.w_2(x)
266
+ return x, buffer, buffer_index, buffer_out
267
+
268
+ class MultiHeadedAttention(nn.Module):
269
+ """Multi-Head Attention layer.
270
+
271
+ :param int n_head: the number of heads
272
+ :param int n_feat: the number of features
273
+ :param float dropout_rate: dropout rate
274
+
275
+ """
276
+ def __init__(self, n_head, n_feat, dropout_rate, chunk_size, left_chunks, pos_enc_class):
277
+ """Construct an MultiHeadedAttention object."""
278
+ super(MultiHeadedAttention, self).__init__()
279
+ assert n_feat % n_head == 0
280
+ # We assume d_v always equals d_k
281
+ self.d_k = n_feat // n_head
282
+ self.h = n_head
283
+ self.linear_q = nn.Linear(n_feat, n_feat)
284
+ self.linear_k = nn.Linear(n_feat, n_feat)
285
+ self.linear_v = nn.Linear(n_feat, n_feat)
286
+ self.linear_out = nn.Linear(n_feat, n_feat)
287
+ self.dropout = nn.Dropout(p=dropout_rate)
288
+ self.min_value = float(numpy.finfo(torch.tensor(0, dtype=torch.float16).numpy().dtype).min)
289
+ # chunk par
290
+ if chunk_size > 0 and left_chunks > 0: #for streaming mode
291
+ self.buffersize = chunk_size * (left_chunks)
292
+ self.left_chunk_size = chunk_size * left_chunks
293
+ else: # for non-streaming mode
294
+ self.buffersize = 1
295
+ self.left_chunk_size = 1
296
+ self.chunk_size = chunk_size
297
+
298
+ # encoding setup
299
+ if pos_enc_class == "rel-enc":
300
+ self.rel_enc = True
301
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
302
+ # these two learnable bias are used in matrix c and matrix d
303
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
304
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
305
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
306
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
307
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
308
+ else:
309
+ self.rel_enc = False
310
+ self.linear_pos = nn.Identity()
311
+ self.pos_bias_u = torch.tensor([0])
312
+ self.pos_bias_v = torch.tensor([0])
313
+
314
+ # buffer
315
+ # key_buffer = 1, self.h, self.buffersize, self.d_k
316
+ self.key_buffer_size = 1 * self.h * self.buffersize * self.d_k
317
+ # value_buffer = 1, self.h, self.buffersize, self.d_k
318
+ self.value_buffer_size = 1 * self.h * self.buffersize * self.d_k
319
+ if self.chunk_size > 0:
320
+ # buffer_mask_size = 1, self.h, self.chunk_size, self.buffersize
321
+ self.buffer_mask_size = 1 * self.h * self.chunk_size * self.buffersize
322
+ else:
323
+ self.buffer_mask = torch.ones([1, self.h, 1, 1], dtype=torch.bool)
324
+
325
+ def rel_shift(self, x, zero_triu: bool = False):
326
+ """Compute relative positinal encoding.
327
+ Args:
328
+ x (torch.Tensor): Input tensor (batch, time, size).
329
+ zero_triu (bool): If true, return the lower triangular part of
330
+ the matrix.
331
+ Returns:
332
+ torch.Tensor: Output tensor.
333
+ """
334
+
335
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
336
+ device=x.device,
337
+ dtype=x.dtype)
338
+ x_padded = torch.cat([zero_pad, x], dim=-1)
339
+
340
+ x_padded = x_padded.view(x.size()[0],
341
+ x.size()[1],
342
+ x.size(3) + 1, x.size(2))
343
+ x = x_padded[:, :, 1:].view_as(x)
344
+
345
+ if zero_triu:
346
+ ones = torch.ones((x.size(2), x.size(3)))
347
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
348
+ return x
349
+
350
+ def forward(self, query, key, value, mask=None, pos_emb=torch.tensor(1.0)):
351
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], Tensor) -> Tensor
352
+ """Compute 'Scaled Dot Product Attention'.
353
+
354
+ :param torch.Tensor query: (batch, time1, size)
355
+ :param torch.Tensor key: (batch, time2, size)
356
+ :param torch.Tensor value: (batch, time2, size)
357
+ :param torch.Tensor mask: (batch, time1, time2)
358
+ :param torch.nn.Dropout dropout:
359
+ :return torch.Tensor: attended and transformed `value` (batch, time1, d_model)
360
+ weighted by the query dot key attention (batch, head, time1, time2)
361
+ """
362
+ n_batch = query.size(0)
363
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
364
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
365
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
366
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
367
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
368
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
369
+
370
+ if self.rel_enc:
371
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
372
+ n_batch_pos = pos_emb.size(0)
373
+ p = self.linear_pos(pos_emb.to(query.dtype)).view(n_batch_pos, -1, self.h, self.d_k)
374
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
375
+ # (batch, head, time1, d_k)
376
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
377
+ # (batch, head, time1, d_k)
378
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
379
+ # compute attention score
380
+ # first compute matrix a and matrix c
381
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
382
+ # (batch, head, time1, time2)
383
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
384
+ # compute matrix b and matrix d
385
+ # (batch, head, time1, time2)
386
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
387
+ # Remove rel_shift since it is useless in speech recognition,
388
+ # and it requires special attention for streaming.
389
+ # matrix_bd = self.rel_shift(matrix_bd)
390
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
391
+ else:
392
+ scores = torch.matmul(q, k.transpose(-2, -1) ) / math.sqrt(self.d_k) # (batch, head, time1, time2)
393
+
394
+ if mask is not None:
395
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
396
+ scores = scores.masked_fill(mask, self.min_value)
397
+ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
398
+ else:
399
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
400
+
401
+ p_attn = self.dropout(attn)
402
+
403
+ x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
404
+ x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
405
+ return self.linear_out(x) # (batch, time1, d_model)
406
+
407
+ def infer(self, query, key, value, pos_emb, buffer, buffer_index, buffer_out):
408
+ # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
409
+ n_batch = query.size(0)
410
+
411
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2) # (batch, head, len_q, d_k)
412
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2) # (batch, head, len_k, d_k)
413
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2) # (batch, head, len_v, d_k)
414
+
415
+ key_value_buffer = buffer[buffer_index]
416
+ if buffer[buffer_index] is None:
417
+ buffer[buffer_index] = [None, None]
418
+ key_buffer = k
419
+ value_buffer = v
420
+ else:
421
+ key_buffer = torch.cat([key_value_buffer[0], k], dim=2)
422
+ value_buffer = torch.cat([key_value_buffer[1], v], dim=2)
423
+ if key_buffer.size(2) > self.buffersize:
424
+ buffer[buffer_index][0] = key_buffer[:, :, -self.buffersize:, :]
425
+ buffer[buffer_index][1] = value_buffer[:, :, -self.buffersize:, :]
426
+ else:
427
+ buffer[buffer_index] = [key_buffer, value_buffer]
428
+ buffer_index += 1
429
+
430
+ if self.rel_enc:
431
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
432
+ n_batch_pos = pos_emb.size(0)
433
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
434
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
435
+ # (batch, head, time1, d_k)
436
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
437
+ # (batch, head, time1, d_k)
438
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
439
+ # compute attention score
440
+ # first compute matrix a and matrix c
441
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
442
+ # (batch, head, time1, time2)
443
+ matrix_ac = torch.matmul(q_with_bias_u, key_buffer.transpose(-2, -1))
444
+ # compute matrix b and matrix d
445
+ # (batch, head, time1, time2)
446
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
447
+ # Remove rel_shift since it is useless in speech recognition,
448
+ # and it requires special attention for streaming.
449
+ # matrix_bd = self.rel_shift(matrix_bd)
450
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
451
+ else:
452
+ # (batch, head, len_q, buffersize)
453
+ scores = torch.matmul(q, key_buffer.transpose(-2, -1) ) / math.sqrt(self.d_k)
454
+
455
+ attn = torch.softmax(scores, dim=-1)
456
+
457
+ x = torch.matmul(attn, value_buffer) # (batch, head, len_q, d_k)
458
+ x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
459
+ return self.linear_out(x), buffer, buffer_index, buffer_out # (batch, time1, d_model)
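
A small shape check for the attention stack above in the non-streaming configuration; the batch, length and model sizes are arbitrary.

import torch

B, T, D, H = 2, 32, 256, 4
pos_enc = RelPositionalEncoding(D, dropout_rate=0.0, chunk_size=16, left_chunks=2)
attn = MultiHeadedAttention(H, D, dropout_rate=0.0, chunk_size=16, left_chunks=2,
                            pos_enc_class="rel-enc")
x = torch.randn(B, T, D)
x_scaled, pos_emb = pos_enc(x)                    # scaled input and [1, T, D] positional embedding
mask = torch.ones(B, T, T, dtype=torch.bool)      # all-ones mask = attend everywhere
out = attn(x_scaled, x_scaled, x_scaled, mask, pos_emb)   # (B, T, D)
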
vita/model/vita_tts/encoder/cmvn.py ADDED
@@ -0,0 +1,107 @@
1
+ import sys
+ import torch
2
+ import json
3
+ import math
4
+
5
+ import numpy as np
6
+
7
+ class GlobalCMVN(torch.nn.Module):
8
+ def __init__(self,
9
+ mean: torch.Tensor,
10
+ istd: torch.Tensor,
11
+ norm_var: bool = True):
12
+ """
13
+ Args:
14
+ mean (torch.Tensor): mean stats
15
+ istd (torch.Tensor): inverse std, std which is 1.0 / std
16
+ """
17
+ super().__init__()
18
+ assert mean.shape == istd.shape
19
+ self.norm_var = norm_var
20
+ # The buffer can be accessed from this module using self.mean
21
+ self.register_buffer("mean", mean)
22
+ self.register_buffer("istd", istd)
23
+
24
+ def forward(self, x: torch.Tensor):
25
+ """
26
+ Args:
27
+ x (torch.Tensor): (batch, max_len, feat_dim)
28
+
29
+ Returns:
30
+ (torch.Tensor): normalized feature
31
+ """
32
+ x = x - self.mean
33
+ if self.norm_var:
34
+ x = x * self.istd
35
+ return x
36
+
37
+ def _load_json_cmvn(json_cmvn_file):
38
+ """ Load the json format cmvn stats file and calculate cmvn
39
+
40
+ Args:
41
+ json_cmvn_file: cmvn stats file in json format
42
+
43
+ Returns:
44
+ a numpy array of [means, vars]
45
+ """
46
+ with open(json_cmvn_file) as f:
47
+ cmvn_stats = json.load(f)
48
+
49
+ means = cmvn_stats['mean_stat']
50
+ variance = cmvn_stats['var_stat']
51
+ count = cmvn_stats['frame_num']
52
+ for i in range(len(means)):
53
+ means[i] /= count
54
+ variance[i] = variance[i] / count - means[i] * means[i]
55
+ if variance[i] < 1.0e-20:
56
+ variance[i] = 1.0e-20
57
+ variance[i] = 1.0 / math.sqrt(variance[i])
58
+ cmvn = np.array([means, variance])
59
+ return cmvn
60
+
61
+ def _load_kaldi_cmvn(kaldi_cmvn_file):
62
+ """ Load the kaldi format cmvn stats file and calculate cmvn
63
+
64
+ Args:
65
+ kaldi_cmvn_file: kaldi text style global cmvn file, which
66
+ is generated by:
67
+ compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
68
+
69
+ Returns:
70
+ a numpy array of [means, vars]
71
+ """
72
+ means = []
73
+ variance = []
74
+ with open(kaldi_cmvn_file, 'r') as fid:
75
+ # kaldi binary file start with '\0B'
76
+ if fid.read(2) == '\0B':
77
+ print('kaldi cmvn binary file is not supported, please '
78
+ 'recompute it by: compute-cmvn-stats --binary=false '
79
+ ' scp:feats.scp global_cmvn')
80
+ sys.exit(1)
81
+ fid.seek(0)
82
+ arr = fid.read().split()
83
+ assert (arr[0] == '[')
84
+ assert (arr[-2] == '0')
85
+ assert (arr[-1] == ']')
86
+ feat_dim = int((len(arr) - 2 - 2) / 2)
87
+ for i in range(1, feat_dim + 1):
88
+ means.append(float(arr[i]))
89
+ count = float(arr[feat_dim + 1])
90
+ for i in range(feat_dim + 2, 2 * feat_dim + 2):
91
+ variance.append(float(arr[i]))
92
+
93
+ for i in range(len(means)):
94
+ means[i] /= count
95
+ variance[i] = variance[i] / count - means[i] * means[i]
96
+ if variance[i] < 1.0e-20:
97
+ variance[i] = 1.0e-20
98
+ variance[i] = 1.0 / math.sqrt(variance[i])
99
+ cmvn = np.array([means, variance])
100
+ return cmvn
101
+
102
+ def load_cmvn(cmvn_file, is_json):
103
+ if is_json:
104
+ cmvn = _load_json_cmvn(cmvn_file)
105
+ else:
106
+ cmvn = _load_kaldi_cmvn(cmvn_file)
107
+ return cmvn[0], cmvn[1]
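
A brief sketch of how these helpers combine into a normalization front-end; the stats file name is a placeholder.

import torch

mean, istd = load_cmvn("global_cmvn.json", is_json=True)    # numpy arrays of shape [feat_dim]
cmvn = GlobalCMVN(torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
feats = torch.randn(4, 100, mean.shape[0])                  # (batch, time, feat_dim)
normalized = cmvn(feats)                                    # mean/variance normalized features
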
vita/model/vita_tts/encoder/encoder.py ADDED
@@ -0,0 +1,155 @@
1
+ import sys
2
+ import time
3
+ import numpy as np
4
+ import torch
5
+ import argparse
6
+
7
+ from typing import Tuple, Dict, Optional
8
+
9
+ from vita.model.vita_tts.encoder.transformer import Transformer
10
+ from vita.model.vita_tts.encoder.subsampling import Subsampling
+ from vita.model.vita_tts.masks import make_pad_mask
11
+
12
+ def add_encoder_args(group):
13
+ """Add Encoder common arguments."""
14
+ group.add_argument(
15
+ "--encoder-layer-config",
16
+ type=str,
17
+ default="tdnn-dtc",
18
+ help="Layer config of encoder. Format layername-layername-...",
19
+ )
20
+ group.add_argument(
21
+ "--encoder-input-dim",
22
+ type=int,
23
+ default=256,
24
+ help="Input dim of encoder. Must equal to the input dim of the first Component (default=40)"
25
+ )
26
+ group.add_argument(
27
+ "--encoder-output-dim",
28
+ type=int,
29
+ default=256,
30
+ help="Output dim of encoder. Must enqual to the output dim of the last Component ! (default=256)"
31
+ )
32
+ group = Transformer.add_arguments(group)
33
+ group = Subsampling.add_arguments(group)
34
+ return group
35
+
36
+ def assign_args_from_dict(args, dict, prefix_key=None):
37
+ if prefix_key is not None:
38
+ dict = dict[prefix_key]
39
+ for k, v in dict.items():
40
+ k_args = k.replace('-', '_')
41
+ if hasattr(args, k_args):
42
+ setattr(args, k_args, dict[k])
43
+ return args
44
+
45
+ class speechEncoder(torch.nn.Module):
46
+ def __init__(
47
+ self,
48
+ input_dim,
49
+ overview_conf = None,
50
+ para_conf = None,
51
+ global_cmvn = None):
52
+ super(speechEncoder, self).__init__()
53
+
54
+ parser = argparse.ArgumentParser()
55
+ add_encoder_args(parser)
56
+ args, _ = parser.parse_known_args()
57
+ assign_args_from_dict(args, overview_conf)
58
+
59
+ self.config = args.encoder_layer_config.split('-')
60
+ encoder_input_dim = args.encoder_input_dim
61
+ encoder_output_dim = args.encoder_output_dim
62
+ prev_output_dim = encoder_input_dim
63
+ prev_component_name = "encoder"
64
+
65
+ self.global_cmvn = global_cmvn
66
+ self.enc = torch.nn.ModuleList([])
67
+ for name in self.config:
68
+ assign_args_from_dict(args, para_conf[name])
69
+ if len(name.split('_')) == 2:
70
+ name = name.split('_')[0]
71
+ elif len(name.split('_')) == 1:
72
+ name = name
73
+ else:
74
+ print("WRONG CONFIG! {} is not valid".format("encoder", name))
75
+ sys.exit()
76
+ if name == "transformer":
77
+ self.enc.append(Transformer(args))
78
+ elif name == "subsampling":
79
+ self.enc.append(Subsampling(args))
80
+ else:
81
+ print("{} is not supported now!".format(name))
82
+ sys.exit()
83
+ component_input_dim = getattr(args, name + "_input_dim")
84
+ if component_input_dim != prev_output_dim:
85
+ print("WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-input-dim ({})"
86
+ .format(prev_component_name, prev_output_dim, name, component_input_dim))
87
+ sys.exit()
88
+ prev_output_dim = getattr(args, name + "_output_dim")
89
+ prev_component_name = name
90
+
91
+ if (prev_output_dim != encoder_output_dim):
92
+ print("WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-output-dim ({}, the last component)"
93
+ .format("encoder", encoder_output_dim, name, prev_output_dim))
94
+ sys.exit()
95
+
96
+ self._output_size = encoder_output_dim
97
+
98
+ num_params = sum(p.numel() for p in self.parameters())
99
+ print('the number of speech encoder params: {}M'.format(num_params/1024/1024))
100
+
101
+ def output_size(self) -> int:
102
+ return self._output_size
103
+
104
+ def forward(self, xs, ilens, decoding_chunk_size=None, num_decoding_left_chunks=None):
105
+ """
106
+ Forward pass through the encoder.
107
+
108
+ Parameters:
109
+ - xs: torch.Tensor, shape (batch_size, sequence_length, input_dim)
110
+ The input tensor containing the sequence of input vectors.
111
+ - batch_size: The number of sequences in the batch.
112
+ - sequence_length: The length of each sequence.
113
+ - input_dim: The dimensionality of each input vector.
114
+
115
+ - ilens: torch.Tensor, shape (batch_size,)
116
+ The lengths of each sequence in the batch, used for padding masks.
117
+
118
+ - decoding_chunk_size: int, optional (default=None)
119
+ The size of chunks to use for decoding
120
+
121
+ - num_decoding_left_chunks: int, optional (default=None)
122
+ The number of left chunks to use for decoding
123
+
124
+ Returns:
125
+ - xs: torch.Tensor, shape (batch_size, sequence_length, encoded_dim)
126
+ The encoded output tensor, where encoded_dim is the dimensionality of the encoded representation.
127
+
128
+ - masks: torch.Tensor, shape (batch_size, 1, sequence_length)
129
+ The padding mask tensor, where True indicates valid elements and False indicates padded elements.
130
+ """
131
+ if decoding_chunk_size is not None and num_decoding_left_chunks is not None:
132
+ for layer in self.enc:
133
+ if hasattr(layer, "chunk_size"):
134
+ layer.chunk_size = decoding_chunk_size
135
+ if hasattr(layer, "left_chunks"):
136
+ layer.left_chunks = num_decoding_left_chunks
137
+ if hasattr(layer, "transformer_dynamic_chunks"):
138
+ layer.transformer_dynamic_chunks = False
139
+
140
+ assert(len(xs.shape)) == 3
141
+ T = xs.size(1)
142
+ masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T)
143
+ if self.global_cmvn is not None:
144
+ xs = self.global_cmvn(xs)
145
+ for module in self.enc:
146
+ xs, ilens, masks = module(xs, ilens, masks)
147
+ return xs, masks
148
+
149
+ def infer(self, xs_pad, buffer, buffer_index, buffer_out, pe_index):
150
+ if self.global_cmvn is not None:
151
+ xs_pad = self.global_cmvn(xs_pad)
152
+ for module in self.enc:
153
+ xs_pad, buffer, buffer_index, buffer_out, pe_index = module.infer(xs_pad,
154
+ buffer, buffer_index, buffer_out, pe_index)
155
+ return xs_pad, buffer, buffer_index, buffer_out, pe_index
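
A hedged construction example for the encoder above. The layer config and every numeric value here are assumptions for illustration; the real settings live in the VITA-TTS config, and a full stack would list transformer components after the subsampling block. It also relies on make_pad_mask from vita/model/vita_tts/masks.py (imported above).

import torch

overview_conf = {"encoder-layer-config": "subsampling",
                 "encoder-input-dim": 80, "encoder-output-dim": 256}
para_conf = {"subsampling": {"subsampling-rate": 4,
                             "subsampling-input-dim": 80,
                             "subsampling-output-dim": 256,
                             "subsampling-dropout-rate": 0.1}}
enc = speechEncoder(80, overview_conf=overview_conf, para_conf=para_conf)
xs = torch.randn(2, 200, 80)              # (batch, frames, fbank dim)
ilens = torch.tensor([200, 160])
ys, masks = enc(xs, ilens)                # ys: (2, ~50, 256) after 4x subsampling
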
vita/model/vita_tts/encoder/subsampling.py ADDED
@@ -0,0 +1,106 @@
1
+ from typing import Tuple, Union
2
+
3
+ import torch
4
+
5
+ class BaseSubsampling(torch.nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.right_context = 0
9
+ self.subsampling_rate = 1
10
+
11
+ def position_encoding(self, offset: Union[int, torch.Tensor],
12
+ size: int) -> torch.Tensor:
13
+ return self.pos_enc.position_encoding(offset, size)
14
+
15
+ class Conv2dSubsampling4(BaseSubsampling):
16
+ """Convolutional 2D subsampling (to 1/4 length).
17
+
18
+ Args:
19
+ idim (int): Input dimension.
20
+ odim (int): Output dimension.
21
+ dropout_rate (float): Dropout rate.
22
+
23
+ """
24
+ def __init__(self, idim: int, odim: int, dropout_rate: float):
25
+ """Construct an Conv2dSubsampling4 object."""
26
+ super().__init__()
27
+ self.conv = torch.nn.Sequential(
28
+ torch.nn.Conv2d(1, odim, 3, 2),
29
+ torch.nn.ReLU(),
30
+ torch.nn.Conv2d(odim, odim, 3, 2),
31
+ torch.nn.ReLU(),
32
+ )
33
+ self.out = torch.nn.Sequential(
34
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
35
+ # The right context for every conv layer is computed by:
36
+ # (kernel_size - 1) * frame_rate_of_this_layer
37
+ self.subsampling_rate = 4
38
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
39
+ self.right_context = 6
40
+
41
+ def forward(
42
+ self,
43
+ x: torch.Tensor,
44
+ x_mask: torch.Tensor
45
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
46
+ """Subsample x.
47
+
48
+ Args:
49
+ x (torch.Tensor): Input tensor (#batch, time, idim).
50
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
51
+
52
+ Returns:
53
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
54
+ where time' = time // 4.
55
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
56
+ where time' = time // 4.
57
+ torch.Tensor: positional encoding
58
+
59
+ """
60
+ x = x.unsqueeze(1) # (b, c=1, t, f)
61
+ x = self.conv(x)
62
+ b, c, t, f = x.size()
63
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
64
+
65
+ return x, x_mask[:, :, 2::2][:, :, 2::2]
66
+
67
+ def infer(self, x, buffer, buffer_index, buffer_out):
68
+ x = x.unsqueeze(1) # (b, c=1, t, f)
69
+ x = self.conv(x)
70
+ b, c, t, f = x.size()
71
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
72
+
73
+ return x, buffer, buffer_index, buffer_out
74
+
75
+ class Subsampling(torch.nn.Module):
76
+ @staticmethod
77
+ def add_arguments(group):
78
+ """Add Subsampling common arguments."""
79
+ group.add_argument('--subsampling-rate', default=4, type=int)
80
+ group.add_argument('--subsampling-input-dim', default=256, type=int)
81
+ group.add_argument('--subsampling-output-dim', default=256, type=int)
82
+ group.add_argument('--subsampling-dropout-rate', default=0.1, type=float)
83
+
84
+ return group
85
+
86
+ def __init__(self, args):
87
+ super().__init__()
88
+ self.subsampling_rate = args.subsampling_rate
89
+ self.subsampling_input_dim = args.subsampling_input_dim
90
+ self.subsampling_output_dim = args.subsampling_output_dim
91
+ self.subsampling_dropout_rate = args.subsampling_dropout_rate
92
+
93
+ if self.subsampling_rate == 4:
94
+ self.core = Conv2dSubsampling4(self.subsampling_input_dim,
95
+ self.subsampling_output_dim,
96
+ self.subsampling_dropout_rate)
97
+
98
+ def forward(self, xs, ilens, masks):
99
+ xs, masks = self.core(xs, masks)
100
+ ilens = masks.squeeze(1).sum(1)
101
+ return xs, ilens, masks
102
+
103
+ def infer(self, x, buffer, buffer_index, buffer_out, pe_index):
104
+ x, buffer, buffer_index, buffer_out = self.core.infer(x,
105
+ buffer, buffer_index, buffer_out)
106
+ return x, buffer, buffer_index, buffer_out, pe_index
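
A quick shape demonstration of the 4x subsampling module defined above (sizes are arbitrary).

import torch

sub = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.1)
x = torch.randn(1, 100, 80)                       # (batch, time, feature)
mask = torch.ones(1, 1, 100, dtype=torch.bool)
y, y_mask = sub(x, mask)                          # y: (1, 24, 256), y_mask: (1, 1, 24)
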
vita/model/vita_tts/encoder/transformer.py ADDED
@@ -0,0 +1,285 @@
1
+ import numpy as np
2
+ import math
3
+ import pdb
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from vita.model.vita_tts.encoder.attention import *
10
+ from vita.model.vita_tts.masks import *
11
+
12
+ IGNORE_ID = -1
13
+
14
+ def strtobool(x):
15
+ return bool(dist_strtobool(x))
16
+
17
+ def repeat(N, fn):
18
+ """Repeat module N times.
19
+
20
+ :param int N: repeat time
21
+ :param function fn: function to generate module
22
+ :return: repeated modules
23
+ :rtype: MultiSequential
24
+ """
25
+ return MultiSequential(*[fn(n) for n in range(N)])
26
+
27
+ class MultiSequential(torch.nn.Sequential):
28
+ """Multi-input multi-output torch.nn.Sequential."""
29
+ def forward(self, x, masks, pos_emb):
30
+
31
+ """Repeat."""
32
+ for m in self:
33
+ x, masks, pos_emb = m(x, masks, pos_emb)
34
+ return x, masks, pos_emb
35
+
36
+ @torch.jit.export
37
+ def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
38
+         # type: (Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
39
+ """Repeat."""
40
+ for m in self:
41
+ x, pos_emb, buffer, buffer_index, buffer_out = m.infer(x, pos_emb, buffer, buffer_index, buffer_out)
42
+ return x, pos_emb, buffer, buffer_index, buffer_out
43
+
44
+ class TransformerLayer(nn.Module):
45
+ """Transformer layer module.
46
+
47
+ :param int size: input dim
48
+ :param self_attn: self attention module
49
+ :param feed_forward: feed forward module
50
+ :param float dropout_rate: dropout rate
51
+ :param bool normalize_before: whether to use layer_norm before the first block
52
+ :param bool concat_after: whether to concat attention layer's input and output
53
+ if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
54
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
55
+
56
+ """
57
+ def __init__(self, size, self_attn, feed_forward, dropout_rate,
58
+ normalize_before=True, concat_after=False):
59
+ """Construct an TransformerLayer object."""
60
+ super(TransformerLayer, self).__init__()
61
+ self.self_attn = self_attn
62
+ self.feed_forward = feed_forward
63
+ self.norm1 = torch.nn.LayerNorm(size)
64
+ self.norm2 = torch.nn.LayerNorm(size)
65
+ self.dropout = nn.Dropout(dropout_rate)
66
+ self.size = size
67
+ self.normalize_before = normalize_before
68
+ self.concat_after = concat_after
69
+ if self.concat_after:
70
+ self.concat_linear = nn.Linear(size + size, size)
71
+ else:
72
+ self.concat_linear = nn.Identity()
73
+
74
+ @torch.jit.unused
75
+ def forward(self, x, mask, pos_emb):
76
+ """Compute encoded features.
77
+
78
+ :param torch.Tensor x: encoded source features (batch, max_time_in, size)
79
+ :param torch.Tensor mask: mask for x (batch, max_time_in)
80
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
81
+ """
82
+ residual = x
83
+ if self.normalize_before:
84
+ x = self.norm1(x)
85
+ if self.concat_after:
86
+ x_concat = torch.cat((x, self.self_attn(x, x, x, mask, pos_emb)), dim=-1)
87
+ x = residual + self.concat_linear(x_concat)
88
+ else:
89
+ x = residual + self.dropout(self.self_attn(x, x, x, mask, pos_emb))
90
+ if not self.normalize_before:
91
+ x = self.norm1(x)
92
+
93
+ residual = x
94
+ if self.normalize_before:
95
+ x = self.norm2(x)
96
+ x = residual + self.dropout(self.feed_forward(x))
97
+ if not self.normalize_before:
98
+ x = self.norm2(x)
99
+
100
+ return x, mask, pos_emb
101
+
102
+ @torch.jit.export
103
+ def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
104
+         # type: (Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
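+         # Streaming note (inferred from the infer() signatures, so an assumption):
+         # `buffer` holds attention / feed-forward caches produced by previous chunks,
+         # `buffer_index` is the read offset into that cache, and `buffer_out`
+         # collects the updated caches to be fed back in on the next call.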
105
+ residual = x.clone()
106
+ if self.normalize_before:
107
+ x = self.norm1(x)
108
+ if self.concat_after:
109
+ x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(x, x, x,
110
+ pos_emb, buffer,
111
+ buffer_index, buffer_out)
112
+ x_concat = torch.cat((x, x_att), dim=-1)
113
+ x = residual + self.concat_linear(x_concat)
114
+ else:
115
+ x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(x, x, x,
116
+ pos_emb, buffer,
117
+ buffer_index, buffer_out)
118
+ x = residual + x_att
119
+ if not self.normalize_before:
120
+ x = self.norm1(x)
121
+
122
+ residual = x.clone()
123
+ if self.normalize_before:
124
+ x = self.norm2(x)
125
+ x_feed, buffer, buffer_index, buffer_out = self.feed_forward.infer(x, buffer, buffer_index, buffer_out)
126
+ x = residual + x_feed
127
+ if not self.normalize_before:
128
+ x = self.norm2(x)
129
+
130
+ return x, pos_emb, buffer, buffer_index, buffer_out
131
+
132
+ class Transformer(torch.nn.Module):
133
+ @staticmethod
134
+ def add_arguments(group):
135
+ """Add TDNN common arguments."""
136
+ group.add_argument('--transformer-input-dim', default=256, type=int)
137
+ group.add_argument('--transformer-output-dim', default=4, type=int)
138
+ group.add_argument('--transformer-attention-dim', default=256, type=int)
139
+ group.add_argument('--transformer-attention-heads', default=4, type=int)
140
+ group.add_argument('--transformer-linear-units', default=1024, type=int)
141
+ group.add_argument('--transformer-num-blocks', default=6, type=int)
142
+ group.add_argument('--transformer-dropout-rate', default=0.1, type=float)
143
+ group.add_argument('--transformer-attention-dropout-rate', default=0.0, type=float)
144
+ group.add_argument('--transformer-positional-dropout-rate', default=0.1, type=float)
145
+ group.add_argument('--transformer-input-layer', default='linear', type=str)
146
+ group.add_argument('--transformer-pos-enc-class', default='abs-enc', type=str)
147
+ group.add_argument('--transformer-normalize-before', default=True, type=strtobool)
148
+ group.add_argument('--transformer-concat-after', default=False, type=strtobool)
149
+ group.add_argument('--transformer-positionwise-layer-type', default='linear', type=str)
150
+ group.add_argument('--transformer-positionwise-conv-kernel_size', default=1, type=int)
151
+ group.add_argument('--transformer-chunk_size', default=-1, type=int)
152
+ group.add_argument('--transformer-left_chunks', default=-1, type=int)
153
+ group.add_argument('--transformer-dynamic-chunks', default=True, type=strtobool)
154
+ return group
155
+
156
+ def __init__(self, args):
157
+ """Construct an Encoder object."""
158
+ super(Transformer, self).__init__()
159
+
160
+ self.input_dim = args.transformer_input_dim
161
+ self.output_dim = args.transformer_output_dim
162
+ self.attention_dim = args.transformer_attention_dim
163
+ self.attention_heads = args.transformer_attention_heads
164
+ self.linear_units = args.transformer_linear_units
165
+ self.num_blocks = args.transformer_num_blocks
166
+ self.dropout_rate = args.transformer_dropout_rate
167
+ self.positional_dropout_rate = args.transformer_positional_dropout_rate
168
+ self.attention_dropout_rate = args.transformer_attention_dropout_rate
169
+ self.input_layer = args.transformer_input_layer
170
+ self.pos_enc_class = args.transformer_pos_enc_class
171
+ self.normalize_before = args.transformer_normalize_before
172
+ self.concat_after = args.transformer_concat_after
173
+ self.positionwise_layer_type = args.transformer_positionwise_layer_type
174
+ self.positionwise_conv_kernel_size = args.transformer_positionwise_conv_kernel_size
175
+ self.chunk_size = args.transformer_chunk_size
176
+ self.left_chunks = args.transformer_left_chunks
177
+ self.transformer_dynamic_chunks = args.transformer_dynamic_chunks
178
+
179
+ if self.pos_enc_class == "abs-enc":
180
+ pos_enc_args = (self.attention_dim, self.positional_dropout_rate)
181
+ pos_enc_class = PositionalEncoding
182
+ elif self.pos_enc_class == "rel-enc":
183
+ pos_enc_args = (self.attention_dim, self.positional_dropout_rate, self.chunk_size, self.left_chunks)
184
+ pos_enc_class = RelPositionalEncoding
185
+
186
+ if self.input_layer == "linear":
187
+ self.embed = torch.nn.Sequential(
188
+ torch.nn.Linear(self.input_dim, self.attention_dim),
189
+ torch.nn.LayerNorm(self.attention_dim),
190
+ torch.nn.Dropout(self.dropout_rate),
191
+ torch.nn.ReLU()
192
+ )
193
+ elif self.input_layer == "embed":
194
+ self.embed = torch.nn.Sequential(
195
+ torch.nn.Embedding(self.input_dim, self.attention_dim, padding_idx=IGNORE_ID)
196
+ )
197
+ elif self.input_layer == "none":
198
+ self.embed = torch.nn.Sequential(
199
+ torch.nn.Identity()
200
+ )
201
+ else:
202
+ raise ValueError("unknown input_layer: " + self.input_layer)
203
+ self.pe = pos_enc_class(*pos_enc_args)
204
+ self.embed_layer_num = len(self.embed)
205
+
206
+ if self.positionwise_layer_type == "linear":
207
+ positionwise_layer = PositionwiseFeedForward
208
+ positionwise_layer_args = (self.attention_dim, self.linear_units, self.dropout_rate)
209
+ elif self.positionwise_layer_type == "conv1d":
210
+ positionwise_layer = MultiLayeredConv1d
211
+ positionwise_layer_args = (self.attention_dim, self.linear_units,
212
+ self.positionwise_conv_kernel_size, self.dropout_rate)
213
+ elif self.positionwise_layer_type == "conv1d-linear":
214
+ positionwise_layer = Conv1dLinear
215
+ positionwise_layer_args = (self.attention_dim, self.linear_units,
216
+ self.positionwise_conv_kernel_size, self.dropout_rate)
217
+ else:
218
+ raise NotImplementedError("Support only linear or conv1d.")
219
+
220
+ self.encoders = repeat(
221
+ self.num_blocks,
222
+ lambda lnum: TransformerLayer(
223
+ self.attention_dim,
224
+ MultiHeadedAttention(self.attention_heads, self.attention_dim,
225
+ self.attention_dropout_rate, self.chunk_size,
226
+ self.left_chunks, self.pos_enc_class),
227
+ positionwise_layer(*positionwise_layer_args),
228
+ self.dropout_rate,
229
+ self.normalize_before,
230
+ self.concat_after
231
+ )
232
+ )
233
+ if self.normalize_before:
234
+ self.after_norm = torch.nn.LayerNorm(self.attention_dim)
235
+
236
+ @torch.jit.unused
237
+ def forward(self, xs, ilens=None, masks=None):
238
+ """Embed positions in tensor.
239
+
240
+ :param torch.Tensor xs: input tensor
241
+ :param torch.Tensor masks: input mask
242
+ :return: position embedded tensor and mask
243
+ :rtype Tuple[torch.Tensor, torch.Tensor]:
244
+ """
245
+         if self.transformer_dynamic_chunks:  # and self.training:
246
+ chunk_masks = add_optional_chunk_mask(xs, masks,
247
+ True,
248
+ True,
249
+ 0,
250
+ 0,
251
+ -1)
252
+ else:
253
+ chunk_masks = add_optional_chunk_mask(xs, masks,
254
+ False,
255
+ False,
256
+ self.chunk_size,
257
+ self.chunk_size,
258
+ self.left_chunks).to(xs.device)
259
+ xs = self.embed(xs)
260
+ xs, pos_emb = self.pe(xs)
261
+ xs, chunk_masks, pos_emb = self.encoders(xs, chunk_masks, pos_emb)
262
+ if self.normalize_before:
263
+ xs = self.after_norm(xs)
264
+ return xs, ilens, masks
265
+
266
+ @torch.jit.export
267
+ def infer(self, xs, buffer, buffer_index, buffer_out, pe_index):
268
+ xs = self.embed(xs)
269
+
270
+ # pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
271
+ # xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
272
+ # buffer_out.append(pe_index.reshape(-1).to(torch.float32))
273
+ # buffer_index = buffer_index + 1
274
+ if buffer[0] is None:
275
+ pe_length = xs.size(1)
276
+ else:
277
+ pe_length = buffer[0][0].size(2) + xs.size(1)
278
+ xs, pos_emb, pe_index = self.pe.infer(xs, pe_index, pe_length)
279
+         pos_emb = pos_emb.to(xs.device)
280
+ xs, pos_emb, buffer, buffer_index, buffer_out = self.encoders.infer(xs, pos_emb,
281
+ buffer, buffer_index, buffer_out)
282
+
283
+ if self.normalize_before:
284
+ xs = self.after_norm(xs)
285
+ return xs, buffer, buffer_index, buffer_out, pe_index
vita/model/vita_tts/masks.py ADDED
@@ -0,0 +1,195 @@
1
+ import torch
2
+
3
+ def casual_chunk_mask(ilens, chunk_size, left_chunks=1):
4
+ # type: (List[int], int, int) -> Tensor
5
+ # param ilens: list, (B, )
6
+ # param chunk_size: int
7
+ # return chunk_mask: torch.Tensor, (B, T, T)
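+     # Example: casual_chunk_mask([4], chunk_size=2, left_chunks=1) ->
+     #   [[1, 1, 0, 0],
+     #    [1, 1, 0, 0],
+     #    [1, 1, 1, 1],
+     #    [1, 1, 1, 1]]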
8
+ B = len(ilens)
9
+ T = max(ilens)
10
+ chunk_mask = torch.zeros(B, T, T)
11
+ for b in range(0, B):
12
+ if chunk_size == -1 :
13
+ chunk_mask[b, 0:ilens[b], 0:ilens[b]] = 1
14
+ else:
15
+ for t in range(0, ilens[b], chunk_size):
16
+ ty_start = t
17
+ ty_end = min(t + chunk_size, ilens[b])
18
+ tx_start = max(t - chunk_size * left_chunks, 0)
19
+ tx_end = min(t + chunk_size, ilens[b])
20
+ chunk_mask[b, ty_start:ty_end, tx_start:tx_end] = 1
21
+ return chunk_mask
22
+
23
+ def subsequent_chunk_mask(
24
+ size: int,
25
+ chunk_size: int,
26
+ num_left_chunks: int = -1
27
+ ) -> torch.Tensor:
28
+ """Create mask for subsequent steps (size, size) with chunk size,
29
+ this is for streaming encoder
30
+
31
+ Args:
32
+ size (int): size of mask
33
+ chunk_size (int): size of chunk
34
+ num_left_chunks (int): number of left chunks
35
+ <0: use full chunk
36
+ >=0: use num_left_chunks
37
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
38
+
39
+ Returns:
40
+ torch.Tensor: mask
41
+
42
+ Examples:
43
+ >>> subsequent_chunk_mask(4, 2)
44
+ [[1, 1, 0, 0],
45
+ [1, 1, 0, 0],
46
+ [1, 1, 1, 1],
47
+ [1, 1, 1, 1]]
48
+ """
49
+ ret = torch.zeros(size, size, dtype=torch.bool)
50
+ for i in range(size):
51
+ if num_left_chunks < 0:
52
+ start = 0
53
+ else:
54
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
55
+ ending = min((i // chunk_size + 1) * chunk_size, size)
56
+ ret[i, start:ending] = torch.ones(ending-start, dtype=torch.bool)
57
+ return ret
58
+
59
+ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
60
+ use_dynamic_chunk: bool,
61
+ use_dynamic_left_chunk: bool,
62
+ decoding_chunk_size: int, static_chunk_size: int,
63
+ num_decoding_left_chunks: int):
64
+ """ Apply optional mask for encoder.
65
+
66
+ Args:
67
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
68
+ mask (torch.Tensor): mask for xs, (B, 1, L)
69
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
70
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
71
+ training.
72
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
73
+ 0: default for training, use random dynamic chunk.
74
+ <0: for decoding, use full chunk.
75
+ >0: for decoding, use fixed chunk size as set.
76
+ static_chunk_size (int): chunk size for static chunk training/decoding
77
+ if it's greater than 0, if use_dynamic_chunk is true,
78
+ this parameter will be ignored
79
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
80
+ the chunk size is decoding_chunk_size.
81
+ >=0: use num_decoding_left_chunks
82
+ <0: use all left chunks
83
+
84
+ Returns:
85
+ torch.Tensor: chunk mask of the input xs.
86
+ """
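+     # Example calls (mirroring Transformer.forward above):
+     #   dynamic-chunk training:
+     #     add_optional_chunk_mask(xs, masks, True, True, 0, 0, -1)
+     #   static-chunk decoding:
+     #     add_optional_chunk_mask(xs, masks, False, False, chunk_size, chunk_size, left_chunks)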
87
+ # Whether to use chunk mask or not
88
+ if use_dynamic_chunk:
89
+ max_len = xs.size(1)
90
+ if decoding_chunk_size < 0:
91
+ chunk_size = max_len
92
+ num_left_chunks = -1
93
+ elif decoding_chunk_size > 0:
94
+ chunk_size = decoding_chunk_size
95
+ num_left_chunks = num_decoding_left_chunks
96
+ else:
97
+ # chunk size is either [1, 25] or full context(max_len).
98
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
99
+ # delay, the maximum frame is 100 / 4 = 25.
100
+ chunk_size = torch.randint(1, max_len, (1, )).item()
101
+ num_left_chunks = -1
102
+ if chunk_size > max_len // 2:
103
+ chunk_size = max_len
104
+ else:
105
+ chunk_size = chunk_size % 25 + 1
106
+ if use_dynamic_left_chunk:
107
+ max_left_chunks = (max_len - 1) // chunk_size
108
+ num_left_chunks = torch.randint(0, max_left_chunks,
109
+ (1, )).item()
110
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
111
+ num_left_chunks) # (L, L)
112
+ chunk_masks = chunk_masks.to(xs.device)
113
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
114
+ chunk_masks = masks & chunk_masks # (B, L, L)
115
+ elif static_chunk_size > 0:
116
+ num_left_chunks = num_decoding_left_chunks
117
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
118
+ num_left_chunks) # (L, L)
119
+ chunk_masks = chunk_masks.unsqueeze(0).to(masks.device) # (1, L, L)
120
+ chunk_masks = masks & chunk_masks # (B, L, L)
121
+ else:
122
+ chunk_masks = masks
123
+ return chunk_masks
124
+
125
+ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
126
+ """Make mask tensor containing indices of padded part.
127
+
128
+ See description of make_non_pad_mask.
129
+
130
+ Args:
131
+ lengths (torch.Tensor): Batch of lengths (B,).
132
+ Returns:
133
+ torch.Tensor: Mask tensor containing indices of padded part.
134
+
135
+ Examples:
136
+ >>> lengths = [5, 3, 2]
137
+ >>> make_pad_mask(lengths)
138
+ masks = [[0, 0, 0, 0 ,0],
139
+ [0, 0, 0, 1, 1],
140
+ [0, 0, 1, 1, 1]]
141
+ """
142
+ batch_size = int(lengths.size(0))
143
+ max_len = int(lengths.max().item())
144
+ seq_range = torch.arange(0,
145
+ max_len,
146
+ dtype=torch.int64,
147
+ device=lengths.device)
148
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
149
+ seq_length_expand = lengths.unsqueeze(-1)
150
+ mask = seq_range_expand >= seq_length_expand
151
+ return mask
152
+
153
+ def subsequent_mask(
154
+ size: int,
155
+ device: torch.device = torch.device("cpu"),
156
+ ) -> torch.Tensor:
157
+ """Create mask for subsequent steps (size, size).
158
+
159
+ This mask is used only in decoder which works in an auto-regressive mode.
160
+ This means the current step could only do attention with its left steps.
161
+
162
+ In encoder, fully attention is used when streaming is not necessary and
163
+ the sequence is not long. In this case, no attention mask is needed.
164
+
165
+ When streaming is need, chunk-based attention is used in encoder. See
166
+ subsequent_chunk_mask for the chunk-based attention mask.
167
+
168
+ Args:
169
+ size (int): size of mask
170
+         device (torch.device): "cpu" or "cuda" or torch.Tensor.device
171
+ dtype (torch.device): result dtype
172
+
173
+ Returns:
174
+ torch.Tensor: mask
175
+
176
+ Examples:
177
+ >>> subsequent_mask(3)
178
+ [[1, 0, 0],
179
+ [1, 1, 0],
180
+ [1, 1, 1]]
181
+ """
182
+ ret = torch.ones(size, size, device=device, dtype=torch.bool)
183
+ return torch.tril(ret, out=ret)
184
+
185
+ def target_mask(ys_in_pad, ignore_id):
186
+ # type: (Tensor, int) -> Tensor
187
+ # Create mask for decoder self-attention.
188
+ # :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
189
+ # :param int ignore_id: index of padding
190
+ # :param torch.dtype dtype: result dtype
191
+ # :rtype: torch.Tensor
192
+
193
+ ys_mask = ys_in_pad != ignore_id
194
+ m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
195
+ return ys_mask.unsqueeze(-2) & m
vita/model/vita_tts/pipeline.py ADDED
@@ -0,0 +1,131 @@
1
+ import torch
2
+ import yaml
3
+ import os
4
+ import re
5
+
6
+ from vita.model.vita_tts.utils import init_encoder_llm, load_checkpoint
7
+
8
+ class inferencePipeline():
9
+ def __init__(self, args):
10
+ self.args = args
11
+
12
+ with open(self.args.model_path + "/audiollm/train.yaml", 'r') as fin:
13
+ configs = yaml.safe_load(fin)
14
+ configs['cmvn_file'] = self.args.model_path + "/audiollm/global_cmvn"
15
+ configs['model_conf']['llm_path'] = self.args.llm_path
16
+
17
+ # Init asr model from configs
18
+ self.model = init_encoder_llm(configs)
19
+
20
+ load_checkpoint(self.model, self.args.model_path + "/audiollm/final.pt")
21
+ device = torch.device('cuda')
22
+ self.model = self.model.to(device)
23
+ self.model.eval()
24
+
25
+ def speech_dialogue(self,
26
+ audio: tuple,
27
+ role: str=None,
28
+ stat: str='sl',
29
+ past_key_values=None,
30
+ last_id=None,
31
+ past_tokens=None,
32
+ adapter_cache=None,
33
+ encoder_cache=None,
34
+ pe_index=0):
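+         # Dialogue state flags used below (names come from the model protocol;
+         # the interpretation is an assumption): 'pre' = only install the system
+         # role, 'sl' = keep listening to incoming audio, 'cs' = the LLM is
+         # emitting response tokens, in which case text / hidden_state are returned.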
35
+ with torch.no_grad():
36
+ ## input fbank
37
+ feats = audio
38
+ if feats is not None:
39
+ feats = feats.to('cuda')
40
+ feats_lengths = torch.tensor([feats.size(1)]).to('cuda')
41
+ else:
42
+ feats_lengths = None
43
+
44
+ extra_inputs = {}
45
+ extra_inputs['top_p'] = self.args.top_p
46
+ extra_inputs['top_k'] = self.args.top_k
47
+ extra_inputs['temperature'] = self.args.temperature
48
+ extra_inputs['past_key_values'] = past_key_values
49
+ extra_inputs['stat'] = stat
50
+ extra_inputs['last_id'] = last_id
51
+ extra_inputs['adapter_cache'] = adapter_cache
52
+ extra_inputs['encoder_cache'] = encoder_cache
53
+ extra_inputs['pe_index'] = pe_index
54
+ if role is not None and past_key_values is None:
55
+ # add <|im_end|> in chat_prefix
56
+ extra_inputs['role'] = '<|im_start|>system\n' + role # + '<|im_end|>'
57
+
58
+ with torch.autocast(device_type="cuda",
59
+ dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32):
60
+ # preprocess system role first
61
+ if stat == 'pre':
62
+ past_key_values = self.model.set_system_role(extra_inputs)
63
+ stat = 'sl'
64
+ else:
65
+ (last_id, stat, past_key_values, adapter_cache,
66
+ encoder_cache, pe_index, hidden_state) = self.model.recognize(
67
+ feats,
68
+ feats_lengths,
69
+ extra_inputs=extra_inputs)
70
+
71
+ outputs = dict(
72
+ past_key_values=past_key_values,
73
+ stat=stat,
74
+ last_id=last_id,
75
+ adapter_cache=adapter_cache,
76
+ encoder_cache=encoder_cache,
77
+ pe_index=pe_index,
78
+ )
79
+
80
+ if stat == 'cs':
81
+ if past_tokens is None:
82
+ past_tokens = []
83
+ past_tokens.append(last_id[0][0])
84
+ text = self.model.tokenizer.decode(past_tokens, skip_special_tokens=True)
85
+ outputs['hidden_state'] = hidden_state
86
+ outputs['text'] = text
87
+ outputs['past_tokens'] = past_tokens
88
+
89
+ return outputs
90
+
91
+ def post_process(self, text):
92
+ """
93
+ Post-processes the input text to standardize various characters and formatting.
94
+
95
+ Parameters:
96
+ - text (str): The input text string to be post-processed.
97
+
98
+ Actions:
99
+ 1. Replaces various Chinese and English punctuation marks with standardized ones.
100
+ 2. Removes newline, tab, and other unwanted whitespace characters.
101
+ 3. Removes special characters like asterisks, underscores, backticks, and tildes.
102
+ 4. Condenses whitespace following periods and colons.
103
+ 5. Adjusts the format of numbered lists to use appropriate separators
104
+ 6. Ensures the text ends with an appropriate punctuation mark
105
+
106
+ Returns:
107
+ - str: The post-processed text string.
108
+ """
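+         # Illustrative example (derived from the rules below):
+         #   post_process('1. 今天天气很好，') -> '1:今天天气很好。'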
109
+ text = text.replace('、', ',')
110
+ text = text.replace('(', ',')
111
+ text = text.replace(')', ',')
112
+         text = text.replace('（', '，')
113
+         text = text.replace('）', '，')
114
+
115
+ text = re.sub(r'[\n\r\t]', '', text)
116
+ text = re.sub(r'[*_`~]', '', text)
117
+
118
+ text = re.sub(r'(\.|\:)\s+', r'\1', text)
119
+
120
+ if re.search(r'[\u4e00-\u9fa5]', text):
121
+ text = re.sub(r'(\d+)\.\s*([\u4e00-\u9fa5A-Za-z])', r'\1:\2', text)
122
+ else:
123
+ text = re.sub(r'(\d+)\.\s*([\w])', r'\1:\2', text)
124
+
125
+ if text and text[-1] not in ["。", "?", "!", ".", "?", "!"]:
126
+ if text[-1] in [",", ",", ";", ";", ":", ":", "、"]:
127
+ text = text[:-1] + "。"
128
+ else:
129
+ text += "。"
130
+
131
+ return text
vita/model/vita_tts/utils.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ import re
3
+ import os
+ import yaml
4
+
5
+ from vita.model.vita_tts.audioLLM import AudioLLM
6
+
7
+ from vita.model.vita_tts.encoder.cmvn import GlobalCMVN, load_cmvn
8
+ from vita.model.vita_tts.encoder.encoder import speechEncoder
9
+
10
+ def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
11
+ if torch.cuda.is_available():
12
+ print('Checkpoint: loading from checkpoint %s for GPU' % path)
13
+ checkpoint = torch.load(path)
14
+ else:
15
+ print('Checkpoint: loading from checkpoint %s for CPU' % path)
16
+ checkpoint = torch.load(path, map_location='cpu')
17
+
18
+ # load parm from checkpoint
19
+ model.load_state_dict(checkpoint, strict=False)
20
+
21
+ info_path = re.sub('.pt$', '.yaml', path)
22
+ configs = {}
23
+ # get configs
24
+ if os.path.exists(info_path):
25
+ with open(info_path, 'r') as fin:
26
+ configs = yaml.safe_load(fin)
27
+ return configs
28
+
29
+ def init_encoder_llm(configs):
30
+ if configs['cmvn_file'] is not None:
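+     # Expected keys in `configs` (as read below): 'cmvn_file', 'is_json_cmvn',
+     # 'input_dim', 'output_dim', 'encoder_conf' and 'model_conf'; these are
+     # typically loaded from the audiollm train.yaml by inferencePipeline.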
31
+ # read cmvn
32
+ mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn'])
33
+ # init cmvn layer
34
+ global_cmvn = GlobalCMVN(
35
+ torch.from_numpy(mean).float(),
36
+ torch.from_numpy(istd).float())
37
+ else:
38
+ global_cmvn = None
39
+
40
+ input_dim = configs['input_dim']
41
+ vocab_size = configs['output_dim']
42
+
43
+ # init speech encoder
44
+ encoder = speechEncoder(input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
45
+ # init audioLLM
46
+ model = AudioLLM(encoder=encoder, **configs['model_conf'])
47
+
48
+ return model