Spaces:
Running
on
Zero
Running
on
Zero
upload vita-1.5 app.py
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +479 -8
- vita/config/__init__.py +10 -0
- vita/config/dataset_config.py +8 -0
- vita/constants.py +14 -0
- vita/conversation.py +401 -0
- vita/model/__init__.py +5 -0
- vita/model/builder.py +287 -0
- vita/model/language_model/vita_fo_qwen2.py +227 -0
- vita/model/language_model/vita_mixtral.py +420 -0
- vita/model/language_model/vita_nemo.py +282 -0
- vita/model/language_model/vita_qwen2.py +304 -0
- vita/model/multimodal_encoder/builder.py +83 -0
- vita/model/multimodal_encoder/clip/clip_encoder.py +78 -0
- vita/model/multimodal_encoder/eva_clip/eva_clip_encoder.py +66 -0
- vita/model/multimodal_encoder/eva_clip/eva_clip_processors.py +69 -0
- vita/model/multimodal_encoder/eva_clip/eva_vit.py +982 -0
- vita/model/multimodal_encoder/internvit/configuration_intern_vit.py +125 -0
- vita/model/multimodal_encoder/internvit/flash_attention.py +101 -0
- vita/model/multimodal_encoder/internvit/internvit_encoder.py +105 -0
- vita/model/multimodal_encoder/internvit/modeling_intern_vit.py +394 -0
- vita/model/multimodal_encoder/siglip/siglip_encoder.py +149 -0
- vita/model/multimodal_encoder/whale/adapter.py +137 -0
- vita/model/multimodal_encoder/whale/cmvn.py +89 -0
- vita/model/multimodal_encoder/whale/init_model.py +192 -0
- vita/model/multimodal_encoder/whale/module/component/mamba.py +131 -0
- vita/model/multimodal_encoder/whale/module/component/subsampling.py +74 -0
- vita/model/multimodal_encoder/whale/module/component/transformer.py +428 -0
- vita/model/multimodal_encoder/whale/module/encoder/encoder.py +171 -0
- vita/model/multimodal_encoder/whale/module/layer/attention.py +571 -0
- vita/model/multimodal_encoder/whale/module/layer/conv1d.py +88 -0
- vita/model/multimodal_encoder/whale/module/layer/dtcblock.py +95 -0
- vita/model/multimodal_encoder/whale/module/layer/fsmn.py +129 -0
- vita/model/multimodal_encoder/whale/utils.py +146 -0
- vita/model/multimodal_projector/builder.py +185 -0
- vita/model/vita_arch.py +639 -0
- vita/model/vita_tts/adapter.py +157 -0
- vita/model/vita_tts/audioLLM.py +433 -0
- vita/model/vita_tts/decoder/decoder.py +367 -0
- vita/model/vita_tts/decoder/llm2tts.py +161 -0
- vita/model/vita_tts/decoder/ticodec/models.py +716 -0
- vita/model/vita_tts/decoder/ticodec/vqvae.py +57 -0
- vita/model/vita_tts/decoder/ticodec/vqvae_tester.py +37 -0
- vita/model/vita_tts/encoder/attention.py +459 -0
- vita/model/vita_tts/encoder/cmvn.py +107 -0
- vita/model/vita_tts/encoder/encoder.py +155 -0
- vita/model/vita_tts/encoder/subsampling.py +106 -0
- vita/model/vita_tts/encoder/transformer.py +285 -0
- vita/model/vita_tts/masks.py +195 -0
- vita/model/vita_tts/pipeline.py +131 -0
- vita/model/vita_tts/utils.py +48 -0
app.py
CHANGED
@@ -1,14 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
import spaces
|
3 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
@spaces.GPU
|
9 |
-
def
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
demo
|
14 |
-
demo.launch()
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import copy
|
6 |
import gradio as gr
|
7 |
+
import re
|
8 |
+
import torchaudio
|
9 |
+
import io
|
10 |
+
import cv2
|
11 |
+
import math
|
12 |
import spaces
|
13 |
+
from numba import jit
|
14 |
+
from huggingface_hub import snapshot_download
|
15 |
+
|
16 |
+
from vita.constants import DEFAULT_AUDIO_TOKEN, DEFAULT_IMAGE_TOKEN, MAX_IMAGE_LENGTH, MIN_IMAGE_LENGTH, IMAGE_TOKEN_INDEX, AUDIO_TOKEN_INDEX
|
17 |
+
from vita.conversation import conv_templates, SeparatorStyle
|
18 |
+
from vita.util.mm_utils import tokenizer_image_token, tokenizer_image_audio_token
|
19 |
+
from PIL import Image
|
20 |
+
from decord import VideoReader, cpu
|
21 |
+
from vita.model.builder import load_pretrained_model
|
22 |
+
from vita.model.vita_tts.decoder.llm2tts import llm2TTS
|
23 |
+
from vita.model.language_model.vita_qwen2 import VITAQwen2Config, VITAQwen2ForCausalLM
|
24 |
+
|
25 |
+
decoder_topk = 2
|
26 |
+
codec_chunk_size = 40
|
27 |
+
codec_padding_size = 10
|
28 |
+
|
29 |
+
PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛""„‟…‧﹏."
|
30 |
+
|
31 |
+
MODEL_NAME = "VITA-MLLM/VITA-1.5"
|
32 |
+
model_path = snapshot_download(MODEL_NAME, local_dir="VITA_ckpt")
|
33 |
+
tokenizer, model, feature_extractor, context_len = load_pretrained_model(
|
34 |
+
model_path, model_base=None, model_name="VITA-1.5", model_type="qwen2p5_instruct"
|
35 |
+
)
|
36 |
+
llm_embedding = model.get_input_embeddings().cuda()
|
37 |
+
tts = llm2TTS(os.path.join(model_path, 'vita_tts_ckpt/'))
|
38 |
+
|
39 |
+
@jit
|
40 |
+
def float_to_int16(audio: np.ndarray) -> np.ndarray:
|
41 |
+
am = int(math.ceil(float(np.abs(audio).max())) * 32768)
|
42 |
+
am = 32767 * 32768 // am
|
43 |
+
return np.multiply(audio, am).astype(np.int16)
|
44 |
+
|
45 |
+
|
46 |
+
def remove_special_characters(input_str):
|
47 |
+
# Remove special tokens
|
48 |
+
special_tokens = ['☞', '☟', '☜', '<unk>', '<|im_end|>']
|
49 |
+
for token in special_tokens:
|
50 |
+
input_str = input_str.replace(token, '')
|
51 |
+
return input_str
|
52 |
+
|
53 |
+
|
54 |
+
def replace_equation(sentence):
|
55 |
+
special_notations = {
|
56 |
+
"sin": " sine ",
|
57 |
+
"cos": " cosine ",
|
58 |
+
"tan": " tangent ",
|
59 |
+
"cot": " cotangent ",
|
60 |
+
"sec": " secant ",
|
61 |
+
"csc": " cosecant ",
|
62 |
+
"log": " logarithm ",
|
63 |
+
"exp": "e^",
|
64 |
+
"sqrt": "根号 ",
|
65 |
+
"abs": "绝对值 ",
|
66 |
+
}
|
67 |
+
|
68 |
+
special_operators = {
|
69 |
+
"+": "加",
|
70 |
+
"-": "减",
|
71 |
+
"*": "乘",
|
72 |
+
"/": "除",
|
73 |
+
"=": "等于",
|
74 |
+
'!=': '不等于',
|
75 |
+
'>': '大于',
|
76 |
+
'<': '小于',
|
77 |
+
'>=': '大于等于',
|
78 |
+
'<=': '小于等于',
|
79 |
+
}
|
80 |
+
|
81 |
+
greek_letters = {
|
82 |
+
"α": "alpha ",
|
83 |
+
"β": "beta ",
|
84 |
+
"γ": "gamma ",
|
85 |
+
"δ": "delta ",
|
86 |
+
"ε": "epsilon ",
|
87 |
+
"ζ": "zeta ",
|
88 |
+
"η": "eta ",
|
89 |
+
"θ": "theta ",
|
90 |
+
"ι": "iota ",
|
91 |
+
"κ": "kappa ",
|
92 |
+
"λ": "lambda ",
|
93 |
+
"μ": "mu ",
|
94 |
+
"ν": "nu ",
|
95 |
+
"ξ": "xi ",
|
96 |
+
"ο": "omicron ",
|
97 |
+
"π": "派 ",
|
98 |
+
"ρ": "rho ",
|
99 |
+
"σ": "sigma ",
|
100 |
+
"τ": "tau ",
|
101 |
+
"υ": "upsilon ",
|
102 |
+
"φ": "phi ",
|
103 |
+
"χ": "chi ",
|
104 |
+
"ψ": "psi ",
|
105 |
+
"ω": "omega "
|
106 |
+
}
|
107 |
+
|
108 |
+
sentence = sentence.replace('**', ' ')
|
109 |
+
|
110 |
+
sentence = re.sub(r'(?<![\d)])-(\d+)', r'负\1', sentence)
|
111 |
+
|
112 |
+
for key in special_notations:
|
113 |
+
sentence = sentence.replace(key, special_notations[key])
|
114 |
+
for key in special_operators:
|
115 |
+
sentence = sentence.replace(key, special_operators[key])
|
116 |
+
for key in greek_letters:
|
117 |
+
sentence = sentence.replace(key, greek_letters[key])
|
118 |
+
|
119 |
+
|
120 |
+
sentence = re.sub(r'\(?(\d+)\)?\((\d+)\)', r'\1乘\2', sentence)
|
121 |
+
sentence = re.sub(r'\(?(\w+)\)?\^\(?(\w+)\)?', r'\1的\2次方', sentence)
|
122 |
+
|
123 |
+
return sentence
|
124 |
+
|
125 |
+
|
126 |
+
def is_video(file_path):
|
127 |
+
video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'}
|
128 |
+
_, ext = os.path.splitext(file_path)
|
129 |
+
return ext.lower() in video_extensions
|
130 |
+
|
131 |
+
def is_image(file_path):
|
132 |
+
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff'}
|
133 |
+
_, ext = os.path.splitext(file_path)
|
134 |
+
return ext.lower() in image_extensions
|
135 |
+
|
136 |
+
def is_wav(file_path):
|
137 |
+
wav_extensions = {'.wav'}
|
138 |
+
_, ext = os.path.splitext(file_path)
|
139 |
+
return ext.lower() in wav_extensions
|
140 |
+
|
141 |
+
def load_model_embemding(model_path):
|
142 |
+
config_path = os.path.join(model_path, 'origin_config.json')
|
143 |
+
config = VITAQwen2Config.from_pretrained(config_path)
|
144 |
+
model = VITAQwen2ForCausalLM.from_pretrained(model_path, config=config, low_cpu_mem_usage=True)
|
145 |
+
embedding = model.get_input_embeddings()
|
146 |
+
del model
|
147 |
+
return embedding
|
148 |
+
|
149 |
+
def split_into_sentences(text):
|
150 |
+
sentence_endings = re.compile(r'[,。?\n!?、,?.!]')
|
151 |
+
sentences = sentence_endings.split(text)
|
152 |
+
return [sentence.strip() for sentence in sentences if sentence.strip()]
|
153 |
+
|
154 |
+
def convert_webm_to_mp4(input_file, output_file):
|
155 |
+
try:
|
156 |
+
cap = cv2.VideoCapture(input_file)
|
157 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
158 |
+
out = cv2.VideoWriter(output_file, fourcc, 20.0, (int(cap.get(3)), int(cap.get(4))))
|
159 |
+
|
160 |
+
while cap.isOpened():
|
161 |
+
ret, frame = cap.read()
|
162 |
+
if not ret:
|
163 |
+
break
|
164 |
+
out.write(frame)
|
165 |
+
|
166 |
+
cap.release()
|
167 |
+
out.release()
|
168 |
+
except Exception as e:
|
169 |
+
print(f"Error: {e}")
|
170 |
+
raise
|
171 |
+
|
172 |
+
|
173 |
+
def _get_rawvideo_dec(video_path, max_frames=MAX_IMAGE_LENGTH, min_frames=MIN_IMAGE_LENGTH, video_framerate=1, s=None, e=None):
|
174 |
+
if s is None or e is None:
|
175 |
+
start_time, end_time = None, None
|
176 |
+
else:
|
177 |
+
start_time = int(s)
|
178 |
+
end_time = int(e)
|
179 |
+
start_time = max(start_time, 0)
|
180 |
+
end_time = max(end_time, 0)
|
181 |
+
if start_time > end_time:
|
182 |
+
start_time, end_time = end_time, start_time
|
183 |
+
elif start_time == end_time:
|
184 |
+
end_time = start_time + 1
|
185 |
+
|
186 |
+
if os.path.exists(video_path):
|
187 |
+
vreader = VideoReader(video_path, ctx=cpu(0))
|
188 |
+
else:
|
189 |
+
raise FileNotFoundError
|
190 |
+
|
191 |
+
fps = vreader.get_avg_fps()
|
192 |
+
f_start = 0 if start_time is None else int(start_time * fps)
|
193 |
+
f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
|
194 |
+
num_frames = f_end - f_start + 1
|
195 |
+
|
196 |
+
if num_frames > 0:
|
197 |
+
sample_fps = int(video_framerate)
|
198 |
+
t_stride = int(round(float(fps) / sample_fps))
|
199 |
+
all_pos = list(range(f_start, f_end + 1, t_stride))
|
200 |
+
|
201 |
+
if len(all_pos) > max_frames:
|
202 |
+
sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
|
203 |
+
elif len(all_pos) < min_frames:
|
204 |
+
sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)]
|
205 |
+
else:
|
206 |
+
sample_pos = all_pos
|
207 |
+
|
208 |
+
patch_images = [Image.fromarray(f).convert("RGB") for f in vreader.get_batch(sample_pos).asnumpy()]
|
209 |
+
return patch_images, len(patch_images)
|
210 |
+
else:
|
211 |
+
print(f"video path: {video_path} error.")
|
212 |
|
213 |
+
def _parse_text(text):
|
214 |
+
lines = text.split("\n")
|
215 |
+
lines = [line for line in lines if line != ""]
|
216 |
+
count = 0
|
217 |
+
|
218 |
+
for i, line in enumerate(lines):
|
219 |
+
if "```" in line:
|
220 |
+
count += 1
|
221 |
+
items = line.split("`")
|
222 |
+
if count % 2 == 1:
|
223 |
+
lines[i] = f'<pre><code class="language-{items[-1]}">'
|
224 |
+
else:
|
225 |
+
lines[i] = "<br></code></pre>"
|
226 |
+
else:
|
227 |
+
if i > 0 and count % 2 == 1:
|
228 |
+
line = line.replace("`", r"\`")
|
229 |
+
line = line.replace("<", "<")
|
230 |
+
line = line.replace(">", ">")
|
231 |
+
line = line.replace(" ", " ")
|
232 |
+
line = line.replace("*", "*")
|
233 |
+
line = line.replace("_", "_")
|
234 |
+
line = line.replace("-", "-")
|
235 |
+
line = line.replace(".", ".")
|
236 |
+
line = line.replace("!", "!")
|
237 |
+
line = line.replace("(", "(")
|
238 |
+
line = line.replace(")", ")")
|
239 |
+
line = line.replace("$", "$")
|
240 |
+
lines[i] = "<br>" + line
|
241 |
+
|
242 |
+
return "".join(lines)
|
243 |
+
|
244 |
+
|
245 |
+
@spaces.GPU
|
246 |
+
def predict(_chatbot, task_history):
|
247 |
+
chat_query = task_history[-1][0]
|
248 |
+
print(task_history)
|
249 |
+
|
250 |
+
conv_mode = "qwen2p5_instruct"
|
251 |
+
conv = conv_templates[conv_mode].copy()
|
252 |
+
|
253 |
+
all_audio_path = []
|
254 |
+
all_visual_tensor = []
|
255 |
+
|
256 |
+
qs = ''
|
257 |
+
input_mode = 'lang'
|
258 |
+
for i, (q, a) in enumerate(task_history):
|
259 |
+
if isinstance(q, (tuple, list)):
|
260 |
+
if is_image(q[0]):
|
261 |
+
images = [Image.open(q[0]).convert("RGB")]
|
262 |
+
all_visual_tensor.extend(images)
|
263 |
+
input_mode = 'image'
|
264 |
+
qs += DEFAULT_IMAGE_TOKEN * len(images) + '\n'
|
265 |
+
elif is_video(q[0]):
|
266 |
+
video_frames, slice_len = _get_rawvideo_dec(q[0])
|
267 |
+
all_visual_tensor.extend(video_frames)
|
268 |
+
input_mode = 'video'
|
269 |
+
qs += DEFAULT_IMAGE_TOKEN * slice_len + '\n'
|
270 |
+
elif is_wav(q[0]):
|
271 |
+
if a is not None and a.startswith('☜'):
|
272 |
+
continue
|
273 |
+
else:
|
274 |
+
all_audio_path.append(q[0])
|
275 |
+
new_q = qs + DEFAULT_AUDIO_TOKEN
|
276 |
+
qs = ''
|
277 |
+
conv.append_message(conv.roles[0], new_q)
|
278 |
+
conv.append_message(conv.roles[1], a)
|
279 |
+
else:
|
280 |
+
new_q = qs + q
|
281 |
+
qs = ''
|
282 |
+
conv.append_message(conv.roles[0], new_q)
|
283 |
+
conv.append_message(conv.roles[1], a)
|
284 |
+
|
285 |
+
prompt = conv.get_prompt(input_mode)
|
286 |
+
|
287 |
+
if all_audio_path != []:
|
288 |
+
input_ids = tokenizer_image_audio_token(
|
289 |
+
prompt, tokenizer,
|
290 |
+
image_token_index=IMAGE_TOKEN_INDEX,
|
291 |
+
audio_token_index=AUDIO_TOKEN_INDEX
|
292 |
+
)
|
293 |
+
audio_list = []
|
294 |
+
for single_audio_path in all_audio_path:
|
295 |
+
try:
|
296 |
+
audio, original_sr = torchaudio.load(single_audio_path)
|
297 |
+
target_sr = 16000
|
298 |
+
if original_sr != target_sr:
|
299 |
+
resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
|
300 |
+
audio = resampler(audio)
|
301 |
+
audio_features = feature_extractor(audio, sampling_rate=target_sr, return_tensors="pt")["input_features"]
|
302 |
+
audio_list.append(audio_features.squeeze(0))
|
303 |
+
except Exception as e:
|
304 |
+
print(f"Error processing {single_audio_path}: {e}")
|
305 |
+
else:
|
306 |
+
input_ids = tokenizer_image_token(
|
307 |
+
prompt, tokenizer,
|
308 |
+
image_token_index=IMAGE_TOKEN_INDEX
|
309 |
+
)
|
310 |
+
|
311 |
+
if all_visual_tensor == [] and all_audio_path == []:
|
312 |
+
datapromt = {
|
313 |
+
"prompt_token_ids": input_ids,
|
314 |
+
}
|
315 |
+
elif all_visual_tensor != [] and all_audio_path == []:
|
316 |
+
datapromt = {
|
317 |
+
"prompt_token_ids": input_ids,
|
318 |
+
"multi_modal_data": {
|
319 |
+
"image": all_visual_tensor
|
320 |
+
},
|
321 |
+
}
|
322 |
+
elif all_visual_tensor == [] and all_audio_path != []:
|
323 |
+
datapromt = {
|
324 |
+
"prompt_token_ids": input_ids,
|
325 |
+
"multi_modal_data": {
|
326 |
+
"audio": audio_list
|
327 |
+
},
|
328 |
+
}
|
329 |
+
else:
|
330 |
+
datapromt = {
|
331 |
+
"prompt_token_ids": input_ids,
|
332 |
+
"multi_modal_data": {
|
333 |
+
"image": all_visual_tensor,
|
334 |
+
"audio": audio_list
|
335 |
+
},
|
336 |
+
}
|
337 |
+
|
338 |
+
print(datapromt)
|
339 |
+
|
340 |
+
with torch.inference_mode():
|
341 |
+
output_ids = model.generate(
|
342 |
+
input_ids,
|
343 |
+
images=all_visual_tensor,
|
344 |
+
audios=audio_list,
|
345 |
+
do_sample=False,
|
346 |
+
temperature=0.01,
|
347 |
+
top_p=None,
|
348 |
+
num_beams=1,
|
349 |
+
output_scores=True,
|
350 |
+
return_dict_in_generate=True,
|
351 |
+
max_new_tokens=1024,
|
352 |
+
use_cache=True,
|
353 |
+
)
|
354 |
+
|
355 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0]
|
356 |
+
outputs = outputs.strip()
|
357 |
+
|
358 |
+
task_history[-1] = (chat_query, outputs)
|
359 |
+
remove_special_characters_output = remove_special_characters(outputs)
|
360 |
+
_chatbot[-1] = (chat_query, _parse_text(remove_special_characters_output))
|
361 |
+
print("query", chat_query)
|
362 |
+
print("task_history", task_history)
|
363 |
+
print(_chatbot)
|
364 |
+
print("answer: ", outputs)
|
365 |
+
yield _chatbot
|
366 |
+
|
367 |
+
|
368 |
+
def add_text(history, task_history, text):
|
369 |
+
task_text = text
|
370 |
+
if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
|
371 |
+
task_text = text[:-1]
|
372 |
+
history = history + [(_parse_text(text), None)]
|
373 |
+
task_history = task_history + [(task_text, None)]
|
374 |
+
return history, task_history, ""
|
375 |
+
|
376 |
+
def add_file(history, task_history, file):
|
377 |
+
history = history + [((file.name,), None)]
|
378 |
+
task_history = task_history + [((file.name,), None)]
|
379 |
+
return history, task_history
|
380 |
+
|
381 |
+
def add_audio(history, task_history, file):
|
382 |
+
print(file)
|
383 |
+
if file is None:
|
384 |
+
return history, task_history
|
385 |
+
history = history + [((file,), None)]
|
386 |
+
task_history = task_history + [((file,), None)]
|
387 |
+
return history, task_history
|
388 |
+
|
389 |
+
def add_video(history, task_history, file):
|
390 |
+
print(file)
|
391 |
+
if file is None:
|
392 |
+
return history, task_history
|
393 |
+
new_file_name = file.replace(".webm",".mp4")
|
394 |
+
if file.endswith(".webm"):
|
395 |
+
convert_webm_to_mp4(file, new_file_name)
|
396 |
+
task_history = task_history + [((new_file_name,), None)]
|
397 |
+
return history, task_history
|
398 |
+
|
399 |
+
|
400 |
+
def reset_user_input():
|
401 |
+
return gr.update(value="")
|
402 |
+
|
403 |
+
def reset_state(task_history):
|
404 |
+
task_history.clear()
|
405 |
+
return []
|
406 |
|
407 |
@spaces.GPU
|
408 |
+
def stream_audio_output(history, task_history):
|
409 |
+
text = task_history[-1][-1]
|
410 |
+
if not text:
|
411 |
+
# import pdb;pdb.set_trace()
|
412 |
+
yield None,None
|
413 |
+
llm_resounse = replace_equation(remove_special_characters(text))
|
414 |
+
#print('tts_text', llm_resounse)
|
415 |
+
for idx, text in enumerate(split_into_sentences(llm_resounse)):
|
416 |
+
embeddings = llm_embedding(torch.tensor(tokenizer.encode(text)).cuda())
|
417 |
+
for seg in tts.run(embeddings.reshape(-1, 896).unsqueeze(0), decoder_topk,
|
418 |
+
None,
|
419 |
+
codec_chunk_size, codec_padding_size):
|
420 |
+
if idx == 0:
|
421 |
+
try:
|
422 |
+
split_idx = torch.nonzero(seg.abs() > 0.03, as_tuple=True)[-1][0]
|
423 |
+
seg = seg[:, :, split_idx:]
|
424 |
+
except:
|
425 |
+
print('Do not need to split')
|
426 |
+
pass
|
427 |
+
|
428 |
+
if seg is not None and len(seg) > 0:
|
429 |
+
seg = seg.to(torch.float32).cpu().numpy()
|
430 |
+
yield 24000, float_to_int16(seg).T
|
431 |
+
|
432 |
+
|
433 |
+
with gr.Blocks(title="VideoMLLM") as demo:
|
434 |
+
gr.Markdown("""<center><font size=8>VITA</center>""")
|
435 |
+
chatbot = gr.Chatbot(label='VITA', elem_classes="control-height", height=500)
|
436 |
+
query = gr.Textbox(lines=2, label='Text Input')
|
437 |
+
task_history = gr.State([])
|
438 |
+
with gr.Row():
|
439 |
+
add_text_button = gr.Button("Submit Text (提交文本)")
|
440 |
+
add_audio_button = gr.Button("Submit Audio (提交音频)")
|
441 |
+
with gr.Row():
|
442 |
+
with gr.Column(scale=2):
|
443 |
+
addfile_btn = gr.UploadButton("📁 Upload (上传文件[视频,图片])", file_types=["video", "image"])
|
444 |
+
video_input = gr.Video(sources=[ "webcam"], height=400, width=700, container=True, interactive=True, show_download_button=True, label="📹 Video Recording (视频录制)")
|
445 |
+
|
446 |
+
with gr.Column(scale=1):
|
447 |
+
empty_bin = gr.Button("🧹 Clear History (清除历史)")
|
448 |
+
record_btn = gr.Audio(sources=[ "microphone","upload"], type="filepath", label="🎤 Record or Upload Audio (录音或上传音频)", show_download_button=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
|
449 |
+
audio_output = gr.Audio(
|
450 |
+
label="Output Audio",
|
451 |
+
value=None,
|
452 |
+
format= "wav",
|
453 |
+
autoplay=True,
|
454 |
+
streaming=True,
|
455 |
+
interactive=False,
|
456 |
+
show_label=True,
|
457 |
+
waveform_options=gr.WaveformOptions(
|
458 |
+
sample_rate=24000,
|
459 |
+
),
|
460 |
+
)
|
461 |
+
|
462 |
+
|
463 |
+
add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then(
|
464 |
+
reset_user_input, [], [query]
|
465 |
+
).then(
|
466 |
+
predict, [chatbot, task_history], [chatbot], show_progress=True
|
467 |
+
).then(
|
468 |
+
stream_audio_output,[chatbot, task_history], [audio_output],
|
469 |
+
)
|
470 |
+
|
471 |
+
|
472 |
+
video_input.stop_recording(add_video, [chatbot, task_history, video_input], [chatbot, task_history], show_progress=True)
|
473 |
+
empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
|
474 |
+
addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
|
475 |
+
|
476 |
+
|
477 |
+
|
478 |
+
add_audio_button.click(add_audio, [chatbot, task_history,record_btn], [chatbot, task_history], show_progress=True).then(
|
479 |
+
predict, [chatbot, task_history], [chatbot], show_progress=True
|
480 |
+
).then(
|
481 |
+
stream_audio_output,[chatbot, task_history], [audio_output],
|
482 |
+
)
|
483 |
+
|
484 |
|
485 |
+
demo.launch(server_port=18806)
|
|
vita/config/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .dataset_config import *
|
2 |
+
|
3 |
+
NaturalCap0 = [ShareGPT4V0]
|
4 |
+
NaturalCap = [ShareGPT4V]
|
5 |
+
|
6 |
+
DataConfig = {
|
7 |
+
"Pretrain_video": NaturalCap0,
|
8 |
+
}
|
9 |
+
|
10 |
+
NoPatchSets = ["khair", "jester"]
|
vita/config/dataset_config.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioFolder = ""
|
2 |
+
FolderDict = {
|
3 |
+
#### NaturalCap
|
4 |
+
"sharegpt4": "",
|
5 |
+
}
|
6 |
+
#### NaturalCap
|
7 |
+
ShareGPT4V = {"chat_path": ""}
|
8 |
+
ShareGPT4V0 = {"chat_path": ""}
|
vita/constants.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model Constants
|
2 |
+
MAX_IMAGE_LENGTH = 16 # 8#16#32#64
|
3 |
+
MIN_IMAGE_LENGTH = 4
|
4 |
+
IGNORE_INDEX = -100
|
5 |
+
IMAGE_TOKEN_INDEX = -200
|
6 |
+
AUDIO_TOKEN_INDEX = -500
|
7 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
8 |
+
DEFAULT_VIDEO_TOKEN = "<video>"
|
9 |
+
DEFAULT_AUDIO_TOKEN = "<audio>"
|
10 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
11 |
+
LOGDIR = "gradio-logs"
|
12 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
13 |
+
DEFAULT_DATA_RATIO = 1.0#0.124#0.5 #0.2 #1.0
|
14 |
+
GLOBAL_WEIGHTS_PATH = "/path/to/model_weights"
|
vita/conversation.py
ADDED
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import Enum, auto
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
|
6 |
+
class SeparatorStyle(Enum):
|
7 |
+
"""Different separator style."""
|
8 |
+
|
9 |
+
TWO = auto()
|
10 |
+
PLAIN = auto()
|
11 |
+
Nemo = auto()
|
12 |
+
Qwen2p5Instruct = auto()
|
13 |
+
MixtralZh = auto()
|
14 |
+
MixtralTwo = auto()
|
15 |
+
|
16 |
+
|
17 |
+
@dataclasses.dataclass
|
18 |
+
class Conversation:
|
19 |
+
"""A class that keeps all conversation history."""
|
20 |
+
|
21 |
+
system: str
|
22 |
+
roles: List[str]
|
23 |
+
messages: List[List[str]]
|
24 |
+
offset: int
|
25 |
+
sep_style: SeparatorStyle
|
26 |
+
sep: str = "###"
|
27 |
+
sep2: str = None
|
28 |
+
version: str = "Unknown"
|
29 |
+
|
30 |
+
skip_next: bool = False
|
31 |
+
|
32 |
+
def get_prompt(self, modality=None):
|
33 |
+
messages = self.messages
|
34 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
35 |
+
messages = self.messages.copy()
|
36 |
+
init_role, init_msg = messages[0].copy()
|
37 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
38 |
+
if "mmtag" in self.version:
|
39 |
+
messages[0] = (init_role, init_msg)
|
40 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
41 |
+
messages.insert(1, (self.roles[1], "Received."))
|
42 |
+
else:
|
43 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
44 |
+
|
45 |
+
if self.sep_style == SeparatorStyle.TWO:
|
46 |
+
seps = [self.sep, self.sep2]
|
47 |
+
ret = self.system + seps[0]
|
48 |
+
for i, (role, message) in enumerate(messages):
|
49 |
+
if message:
|
50 |
+
if type(message) is tuple:
|
51 |
+
message, _, _ = message
|
52 |
+
ret += role + ": " + message + seps[i % 2]
|
53 |
+
else:
|
54 |
+
ret += role + ":"
|
55 |
+
|
56 |
+
elif self.sep_style == SeparatorStyle.MixtralZh:
|
57 |
+
seps = [self.sep, self.sep2]
|
58 |
+
ret = "system:" + self.system + seps[0]
|
59 |
+
for i, (role, message) in enumerate(messages):
|
60 |
+
if message:
|
61 |
+
if type(message) is tuple:
|
62 |
+
message, _, _ = message
|
63 |
+
ret += "\n" + role + ":" + message + seps[i % 2]
|
64 |
+
else:
|
65 |
+
ret += "\n" + role + ":"
|
66 |
+
|
67 |
+
elif self.sep_style == SeparatorStyle.MixtralTwo:
|
68 |
+
seps = [self.sep, self.sep2]
|
69 |
+
has_image = False
|
70 |
+
for i, (role, message) in enumerate(messages):
|
71 |
+
if message and "<image>" in message:
|
72 |
+
has_image = True
|
73 |
+
break
|
74 |
+
if has_image:
|
75 |
+
assert modality == "image" or modality == "video"
|
76 |
+
if modality == "image":
|
77 |
+
self.system = self.system[0]
|
78 |
+
elif modality == "video":
|
79 |
+
self.system = self.system[1]
|
80 |
+
else:
|
81 |
+
raise ValueError
|
82 |
+
else:
|
83 |
+
assert modality == "lang"
|
84 |
+
self.system = self.system[2]
|
85 |
+
ret = "system:" + self.system + seps[0]
|
86 |
+
for i, (role, message) in enumerate(messages):
|
87 |
+
if message:
|
88 |
+
if type(message) is tuple:
|
89 |
+
message, _, _ = message
|
90 |
+
ret += "\n" + role + ":" + message + seps[i % 2]
|
91 |
+
else:
|
92 |
+
ret += "\n" + role + ":"
|
93 |
+
|
94 |
+
elif self.sep_style == SeparatorStyle.Nemo:
|
95 |
+
wrap_inst = lambda msg: f"[INST]{msg}[/INST]"
|
96 |
+
seps = [self.sep, self.sep2]
|
97 |
+
has_image = False
|
98 |
+
for i, (role, message) in enumerate(messages):
|
99 |
+
if message and "<image>" in message:
|
100 |
+
has_image = True
|
101 |
+
break
|
102 |
+
if has_image:
|
103 |
+
assert modality == "image" or modality == "video"
|
104 |
+
if modality == "image":
|
105 |
+
self.system = self.system[0]
|
106 |
+
elif modality == "video":
|
107 |
+
self.system = self.system[1]
|
108 |
+
else:
|
109 |
+
raise ValueError
|
110 |
+
else:
|
111 |
+
assert modality == "lang"
|
112 |
+
self.system = self.system[2]
|
113 |
+
ret = ""
|
114 |
+
for i, (role, message) in enumerate(messages):
|
115 |
+
if message:
|
116 |
+
if type(message) is tuple:
|
117 |
+
message, _, _ = message
|
118 |
+
if i == 0:
|
119 |
+
message = self.system + '\n' + message
|
120 |
+
if i % 2 == 0:
|
121 |
+
ret += wrap_inst(message)
|
122 |
+
else:
|
123 |
+
ret += message + seps[i % 2]
|
124 |
+
else:
|
125 |
+
ret += ""
|
126 |
+
|
127 |
+
elif self.sep_style == SeparatorStyle.Qwen2p5Instruct:
|
128 |
+
wrap_qa = lambda msg: f"<|im_start|>{msg}<|im_end|>\n"
|
129 |
+
wrap_qa2 = lambda msg: f"<|im_start|>{msg}<|im_end|>"
|
130 |
+
seps = [self.sep, self.sep2]
|
131 |
+
has_image = False
|
132 |
+
for i, (role, message) in enumerate(messages):
|
133 |
+
if message and "<image>" in message:
|
134 |
+
has_image = True
|
135 |
+
break
|
136 |
+
if has_image:
|
137 |
+
assert modality == "image" or modality == "video"
|
138 |
+
if modality == "image":
|
139 |
+
self.system = self.system[0]
|
140 |
+
elif modality == "video":
|
141 |
+
self.system = self.system[1]
|
142 |
+
else:
|
143 |
+
raise ValueError
|
144 |
+
else:
|
145 |
+
assert modality == "lang"
|
146 |
+
self.system = self.system[2]
|
147 |
+
ret = wrap_qa("system\n" + self.system)
|
148 |
+
for i, (role, message) in enumerate(messages):
|
149 |
+
if message:
|
150 |
+
if type(message) is tuple:
|
151 |
+
message, _, _ = message
|
152 |
+
if i < len(messages) - 1:
|
153 |
+
ret += wrap_qa(role + '\n' + message)
|
154 |
+
else:
|
155 |
+
ret += wrap_qa2(role + '\n' + message)
|
156 |
+
else:
|
157 |
+
ret += "<|im_start|>" + role + '\n'
|
158 |
+
|
159 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
160 |
+
seps = [self.sep, self.sep2]
|
161 |
+
ret = self.system
|
162 |
+
for i, (role, message) in enumerate(messages):
|
163 |
+
if message:
|
164 |
+
if type(message) is tuple:
|
165 |
+
message, _, _ = message
|
166 |
+
ret += message + seps[i % 2]
|
167 |
+
else:
|
168 |
+
ret += ""
|
169 |
+
else:
|
170 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
171 |
+
|
172 |
+
return ret
|
173 |
+
|
174 |
+
def append_message(self, role, message):
|
175 |
+
self.messages.append([role, message])
|
176 |
+
|
177 |
+
def get_images(self, return_pil=False):
|
178 |
+
images = []
|
179 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
180 |
+
if i % 2 == 0:
|
181 |
+
if type(msg) is tuple:
|
182 |
+
import base64
|
183 |
+
from io import BytesIO
|
184 |
+
from PIL import Image
|
185 |
+
|
186 |
+
msg, image, image_process_mode = msg
|
187 |
+
if image_process_mode == "Pad":
|
188 |
+
|
189 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
190 |
+
width, height = pil_img.size
|
191 |
+
if width == height:
|
192 |
+
return pil_img
|
193 |
+
elif width > height:
|
194 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
195 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
196 |
+
return result
|
197 |
+
else:
|
198 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
199 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
200 |
+
return result
|
201 |
+
|
202 |
+
image = expand2square(image)
|
203 |
+
elif image_process_mode in ["Default", "Crop"]:
|
204 |
+
pass
|
205 |
+
elif image_process_mode == "Resize":
|
206 |
+
image = image.resize((336, 336))
|
207 |
+
else:
|
208 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
209 |
+
|
210 |
+
if return_pil:
|
211 |
+
images.append(image)
|
212 |
+
else:
|
213 |
+
buffered = BytesIO()
|
214 |
+
image.save(buffered, format="PNG")
|
215 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
216 |
+
images.append(img_b64_str)
|
217 |
+
return images
|
218 |
+
|
219 |
+
def to_gradio_chatbot(self):
|
220 |
+
ret = []
|
221 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
222 |
+
if i % 2 == 0:
|
223 |
+
if type(msg) is tuple:
|
224 |
+
import base64
|
225 |
+
from io import BytesIO
|
226 |
+
|
227 |
+
msg, image, image_process_mode = msg
|
228 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
229 |
+
aspect_ratio = max_hw / min_hw
|
230 |
+
max_len, min_len = 800, 400
|
231 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
232 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
233 |
+
W, H = image.size
|
234 |
+
if H > W:
|
235 |
+
H, W = longest_edge, shortest_edge
|
236 |
+
else:
|
237 |
+
H, W = shortest_edge, longest_edge
|
238 |
+
image = image.resize((W, H))
|
239 |
+
buffered = BytesIO()
|
240 |
+
image.save(buffered, format="JPEG")
|
241 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
242 |
+
img_str = (
|
243 |
+
f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
244 |
+
)
|
245 |
+
msg = img_str + msg.replace("<image>", "").strip()
|
246 |
+
ret.append([msg, None])
|
247 |
+
else:
|
248 |
+
ret.append([msg, None])
|
249 |
+
else:
|
250 |
+
ret[-1][-1] = msg
|
251 |
+
return ret
|
252 |
+
|
253 |
+
def copy(self):
|
254 |
+
return Conversation(
|
255 |
+
system=self.system,
|
256 |
+
roles=self.roles,
|
257 |
+
messages=[[x, y] for x, y in self.messages],
|
258 |
+
offset=self.offset,
|
259 |
+
sep_style=self.sep_style,
|
260 |
+
sep=self.sep,
|
261 |
+
sep2=self.sep2,
|
262 |
+
version=self.version,
|
263 |
+
)
|
264 |
+
|
265 |
+
def dict(self):
|
266 |
+
if len(self.get_images()) > 0:
|
267 |
+
return {
|
268 |
+
"system": self.system,
|
269 |
+
"roles": self.roles,
|
270 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
271 |
+
"offset": self.offset,
|
272 |
+
"sep": self.sep,
|
273 |
+
"sep2": self.sep2,
|
274 |
+
}
|
275 |
+
return {
|
276 |
+
"system": self.system,
|
277 |
+
"roles": self.roles,
|
278 |
+
"messages": self.messages,
|
279 |
+
"offset": self.offset,
|
280 |
+
"sep": self.sep,
|
281 |
+
"sep2": self.sep2,
|
282 |
+
}
|
283 |
+
|
284 |
+
|
285 |
+
conv_mixtral_zh = Conversation(
|
286 |
+
system="你是一个人工智能机器人。\n- 你是研究社区开发的大语言模型。你的设计宗旨是有益、诚实且无害。\n- 你支持使用用户选择的多种语言流利地进行交流并解答用户的问题。\n- 如果用户更正你生成的错误答案,你会向用户致歉并与用户探讨正确的答案。",
|
287 |
+
roles=("user", "bot"),
|
288 |
+
version="mixtral_zh",
|
289 |
+
messages=(),
|
290 |
+
offset=0,
|
291 |
+
sep_style=SeparatorStyle.MixtralZh,
|
292 |
+
sep="</s>",
|
293 |
+
sep2="</s>",
|
294 |
+
)
|
295 |
+
|
296 |
+
conv_mixtral_two = Conversation(
|
297 |
+
system=[
|
298 |
+
"You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the image given by the user, and it is strictly forbidden to answer the question without the content of the image. Please note that you are seeing the image, not the video.",
|
299 |
+
"You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the video given by the user, and it is strictly forbidden to answer the question without the content of the video. Please note that you are seeing the video, not the image.",
|
300 |
+
"You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user.",
|
301 |
+
],
|
302 |
+
roles=("user", "bot"),
|
303 |
+
version="mixtral_two",
|
304 |
+
messages=(),
|
305 |
+
offset=0,
|
306 |
+
sep_style=SeparatorStyle.MixtralTwo,
|
307 |
+
sep="</s>",
|
308 |
+
sep2="</s>",
|
309 |
+
)
|
310 |
+
|
311 |
+
conv_nemo = Conversation(
|
312 |
+
system=[
|
313 |
+
"You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the image given by the user, and it is strictly forbidden to answer the question without the content of the image. Please note that you are seeing the image, not the video.",
|
314 |
+
"You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the video given by the user, and it is strictly forbidden to answer the question without the content of the video. Please note that you are seeing the video, not the image.",
|
315 |
+
"You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user.",
|
316 |
+
],
|
317 |
+
roles=("USER", "ASSISTANT"),
|
318 |
+
version="nemo",
|
319 |
+
messages=(),
|
320 |
+
offset=0,
|
321 |
+
sep_style=SeparatorStyle.Nemo,
|
322 |
+
sep="[/INST]",
|
323 |
+
sep2="</s>",
|
324 |
+
)
|
325 |
+
|
326 |
+
conv_qwen2p5_instruct = Conversation(
|
327 |
+
system=[
|
328 |
+
"You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the image given by the user, and it is strictly forbidden to answer the question without the content of the image. Please note that you are seeing the image, not the video.",
|
329 |
+
"You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the video given by the user, and it is strictly forbidden to answer the question without the content of the video. Please note that you are seeing the video, not the image.",
|
330 |
+
"You are an AI robot and your name is VITA. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user.",
|
331 |
+
],
|
332 |
+
roles=("user", "assistant"),
|
333 |
+
version="qwen2p5_instruct",
|
334 |
+
messages=(),
|
335 |
+
offset=0,
|
336 |
+
sep_style=SeparatorStyle.Qwen2p5Instruct,
|
337 |
+
sep="<|im_start|>",
|
338 |
+
sep2="<|im_start|>",
|
339 |
+
)
|
340 |
+
|
341 |
+
conv_phi3 = Conversation(
|
342 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
343 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
344 |
+
roles=("USER", "ASSISTANT"),
|
345 |
+
version="phi3",
|
346 |
+
messages=(),
|
347 |
+
offset=0,
|
348 |
+
sep_style=SeparatorStyle.TWO,
|
349 |
+
sep=" ",
|
350 |
+
sep2="<|endoftext|>",
|
351 |
+
)
|
352 |
+
|
353 |
+
conv_minicpm = Conversation(
|
354 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
355 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
356 |
+
roles=("USER", "ASSISTANT"),
|
357 |
+
version="minicpm",
|
358 |
+
messages=(),
|
359 |
+
offset=0,
|
360 |
+
sep_style=SeparatorStyle.TWO,
|
361 |
+
sep=" ",
|
362 |
+
sep2="</s>",
|
363 |
+
)
|
364 |
+
|
365 |
+
conv_llama = Conversation(
|
366 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
367 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
368 |
+
roles=("USER", "ASSISTANT"),
|
369 |
+
version="llama",
|
370 |
+
messages=(),
|
371 |
+
offset=0,
|
372 |
+
sep_style=SeparatorStyle.TWO,
|
373 |
+
sep=" ",
|
374 |
+
sep2="<|end_of_text|>",
|
375 |
+
)
|
376 |
+
|
377 |
+
conv_plain = Conversation(
|
378 |
+
system="",
|
379 |
+
roles=("", ""),
|
380 |
+
messages=(),
|
381 |
+
offset=0,
|
382 |
+
sep_style=SeparatorStyle.PLAIN,
|
383 |
+
sep="\n",
|
384 |
+
)
|
385 |
+
|
386 |
+
default_conversation = conv_mixtral_two
|
387 |
+
conv_templates = {
|
388 |
+
"default": conv_mixtral_two,
|
389 |
+
"nemo": conv_nemo,
|
390 |
+
"qwen2p5_instruct": conv_qwen2p5_instruct,
|
391 |
+
"mixtral_zh": conv_mixtral_zh,
|
392 |
+
"mixtral_two": conv_mixtral_two,
|
393 |
+
"phi3": conv_phi3,
|
394 |
+
"plain": conv_plain,
|
395 |
+
"minicpm": conv_minicpm,
|
396 |
+
"llama": conv_llama,
|
397 |
+
}
|
398 |
+
|
399 |
+
if __name__ == "__main__":
|
400 |
+
print(default_conversation.get_prompt())
|
401 |
+
|
vita/model/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .language_model.vita_mixtral import VITAMixtralConfig, VITAMixtralForCausalLM
|
2 |
+
from .language_model.vita_nemo import VITAMistralConfig, VITAMistralForCausalLM
|
3 |
+
from .language_model.vita_qwen2 import VITAQwen2Config, VITAQwen2ForCausalLM
|
4 |
+
from .language_model.vita_fo_qwen2 import VITAFOQwen2Config, VITAFOQwen2ForCausalLM
|
5 |
+
|
vita/model/builder.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig, logging
|
6 |
+
|
7 |
+
from vita.constants import GLOBAL_WEIGHTS_PATH
|
8 |
+
from vita.model import *
|
9 |
+
|
10 |
+
logging.set_verbosity_error()
|
11 |
+
warnings.filterwarnings("ignore")
|
12 |
+
|
13 |
+
|
14 |
+
def load_pretrained_model(
|
15 |
+
model_path,
|
16 |
+
model_base,
|
17 |
+
model_name,
|
18 |
+
model_type,
|
19 |
+
load_8bit=False,
|
20 |
+
load_4bit=False,
|
21 |
+
device_map="auto",
|
22 |
+
device="cuda",
|
23 |
+
**kwargs,
|
24 |
+
):
|
25 |
+
if model_type not in {"mixtral-8x7b", "nemo", "qwen2p5_instruct", "qwen2p5_fo_instruct"}:
|
26 |
+
raise ValueError(f"Unknown Model Type {model_type}")
|
27 |
+
|
28 |
+
kwargs = {"device_map": device_map, **kwargs}
|
29 |
+
|
30 |
+
if device != "cuda":
|
31 |
+
kwargs["device_map"] = {"": device}
|
32 |
+
|
33 |
+
if load_8bit:
|
34 |
+
kwargs["load_in_8bit"] = True
|
35 |
+
elif load_4bit:
|
36 |
+
kwargs["load_in_4bit"] = True
|
37 |
+
kwargs["quantization_config"] = BitsAndBytesConfig(
|
38 |
+
load_in_4bit=True,
|
39 |
+
bnb_4bit_compute_dtype=torch.float16,
|
40 |
+
bnb_4bit_use_double_quant=True,
|
41 |
+
bnb_4bit_quant_type="nf4",
|
42 |
+
)
|
43 |
+
else:
|
44 |
+
kwargs["torch_dtype"] = torch.float16
|
45 |
+
|
46 |
+
# Load VITA model
|
47 |
+
if "lora" in model_name.lower() and model_base is None:
|
48 |
+
warnings.warn(
|
49 |
+
"There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument."
|
50 |
+
)
|
51 |
+
if "lora" in model_name.lower() and model_base is not None:
|
52 |
+
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
53 |
+
|
54 |
+
print("Loading VITA from base model...")
|
55 |
+
if model_type == "mixtral-8x7b":
|
56 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
57 |
+
model = VITAMixtralForCausalLM.from_pretrained(
|
58 |
+
model_path, low_cpu_mem_usage=True, **kwargs
|
59 |
+
)
|
60 |
+
|
61 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
62 |
+
if model.lm_head.weight.shape[0] != token_num:
|
63 |
+
model.lm_head.weight = torch.nn.Parameter(
|
64 |
+
torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)
|
65 |
+
)
|
66 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(
|
67 |
+
torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)
|
68 |
+
)
|
69 |
+
|
70 |
+
print("Loading additional VITA weights...")
|
71 |
+
if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
|
72 |
+
non_lora_trainables = torch.load(
|
73 |
+
os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu"
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
# this is probably from HF Hub
|
77 |
+
from huggingface_hub import hf_hub_download
|
78 |
+
|
79 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
80 |
+
cache_file = hf_hub_download(
|
81 |
+
repo_id=repo_id, filename=filename, subfolder=subfolder
|
82 |
+
)
|
83 |
+
return torch.load(cache_file, map_location="cpu")
|
84 |
+
|
85 |
+
non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
|
86 |
+
|
87 |
+
non_lora_trainables = {
|
88 |
+
(k[11:] if k.startswith("base_model.") else k): v
|
89 |
+
for k, v in non_lora_trainables.items()
|
90 |
+
}
|
91 |
+
if any(k.startswith("model.model.") for k in non_lora_trainables):
|
92 |
+
non_lora_trainables = {
|
93 |
+
(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()
|
94 |
+
}
|
95 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
96 |
+
|
97 |
+
from peft import PeftModel
|
98 |
+
|
99 |
+
print("Loading LoRA weights...")
|
100 |
+
model = PeftModel.from_pretrained(model, model_path)
|
101 |
+
print("Merging LoRA weights...")
|
102 |
+
model = model.merge_and_unload()
|
103 |
+
print("Model is loaded...")
|
104 |
+
elif model_base is not None:
|
105 |
+
# this may be mm projector only
|
106 |
+
print("Loading VITA from base model...")
|
107 |
+
|
108 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
109 |
+
if model_type == "mixtral-8x7b":
|
110 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
|
111 |
+
model = VITAMixtralForCausalLM.from_pretrained(
|
112 |
+
model_base, low_cpu_mem_usage=True, **kwargs
|
113 |
+
)
|
114 |
+
|
115 |
+
# load vision encoder
|
116 |
+
from types import SimpleNamespace
|
117 |
+
model_args = {
|
118 |
+
"vision_tower": f"{GLOBAL_WEIGHTS_PATH}/InternViT-300M-448px",
|
119 |
+
"pretrain_mm_mlp_adapter": None,
|
120 |
+
"mm_projector_type": "mlp2x_gelu",
|
121 |
+
}
|
122 |
+
model_args = SimpleNamespace(**model_args)
|
123 |
+
model.get_model().initialize_vision_modules(model_args=model_args)
|
124 |
+
|
125 |
+
# load audio encoder
|
126 |
+
from types import SimpleNamespace
|
127 |
+
model_args = {
|
128 |
+
'audio_encoder': f"{GLOBAL_WEIGHTS_PATH}/audio-encoder-2wh_zh_en_audioset_Mixtral-8x7B_New-base-tunning",
|
129 |
+
'freeze_audio_encoder': True,
|
130 |
+
'freeze_audio_encoder_adapter': True
|
131 |
+
}
|
132 |
+
model_args = SimpleNamespace(**model_args)
|
133 |
+
model.get_model().initialize_audio_modules(model_args=model_args)
|
134 |
+
audio_encoder = model.get_audio_encoder()
|
135 |
+
device = torch.device('cuda:0')
|
136 |
+
audio_encoder = audio_encoder.to(device)
|
137 |
+
|
138 |
+
mm_projector_weights = torch.load(
|
139 |
+
os.path.join(model_path, "mm_projector.bin"), map_location="cpu"
|
140 |
+
)
|
141 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
142 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
143 |
+
model.model.mm_projector.to(device="cuda", dtype=torch.float16)
|
144 |
+
model.model.vision_tower.to(device="cuda", dtype=torch.float16)
|
145 |
+
else:
|
146 |
+
if model_type == "mixtral-8x7b":
|
147 |
+
# import pdb; pdb.set_trace()
|
148 |
+
device_map = {
|
149 |
+
"model.embed_tokens": 0,
|
150 |
+
"model.layers.0": 0,
|
151 |
+
"model.layers.1": 0,
|
152 |
+
"model.layers.2": 0,
|
153 |
+
"model.layers.3": 0,
|
154 |
+
"model.layers.4": 0,
|
155 |
+
"model.layers.5": 0,
|
156 |
+
"model.layers.6": 0,
|
157 |
+
"model.layers.7": 0,
|
158 |
+
"model.layers.8": 0,
|
159 |
+
"model.layers.9": 0,
|
160 |
+
"model.layers.10": 0,
|
161 |
+
"model.layers.11": 0,
|
162 |
+
"model.layers.12": 0,
|
163 |
+
"model.layers.13": 0,
|
164 |
+
"model.layers.14": 0,
|
165 |
+
"model.layers.15": 0,
|
166 |
+
"model.layers.16": 1,
|
167 |
+
"model.layers.17": 1,
|
168 |
+
"model.layers.18": 1,
|
169 |
+
"model.layers.19": 1,
|
170 |
+
"model.layers.20": 1,
|
171 |
+
"model.layers.21": 1,
|
172 |
+
"model.layers.22": 1,
|
173 |
+
"model.layers.23": 1,
|
174 |
+
"model.layers.24": 1,
|
175 |
+
"model.layers.25": 1,
|
176 |
+
"model.layers.26": 1,
|
177 |
+
"model.layers.27": 1,
|
178 |
+
"model.layers.28": 1,
|
179 |
+
"model.layers.29": 1,
|
180 |
+
"model.layers.30": 1,
|
181 |
+
"model.layers.31": 1,
|
182 |
+
"model.norm": 1,
|
183 |
+
"model.vision_tower": 1,
|
184 |
+
"model.mm_projector": 1,
|
185 |
+
"model.audio_encoder": 1,
|
186 |
+
"lm_head": 1,
|
187 |
+
}
|
188 |
+
device_map["model.audio_encoder"] = 0
|
189 |
+
kwargs.update(device_map=device_map)
|
190 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
191 |
+
model = VITAMixtralForCausalLM.from_pretrained(
|
192 |
+
model_path, low_cpu_mem_usage=True, **kwargs
|
193 |
+
)
|
194 |
+
# model.hf_device_map
|
195 |
+
elif model_type == "nemo":
|
196 |
+
# import pdb; pdb.set_trace()
|
197 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
198 |
+
model = VITAMistralForCausalLM.from_pretrained(
|
199 |
+
model_path, low_cpu_mem_usage=True, **kwargs
|
200 |
+
)
|
201 |
+
elif model_type == "qwen2p5_instruct":
|
202 |
+
# import pdb; pdb.set_trace()
|
203 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
204 |
+
model = VITAQwen2ForCausalLM.from_pretrained(
|
205 |
+
model_path, low_cpu_mem_usage=True, **kwargs
|
206 |
+
)
|
207 |
+
elif model_type == "qwen2p5_fo_instruct":
|
208 |
+
# import pdb; pdb.set_trace()
|
209 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
210 |
+
model = VITAFOQwen2ForCausalLM.from_pretrained(
|
211 |
+
model_path, low_cpu_mem_usage=True, **kwargs
|
212 |
+
)
|
213 |
+
|
214 |
+
model.resize_token_embeddings(len(tokenizer))
|
215 |
+
|
216 |
+
vision_tower = model.get_vision_tower()
|
217 |
+
if not vision_tower.is_loaded:
|
218 |
+
vision_tower.load_model()
|
219 |
+
|
220 |
+
num_params = sum(p.numel() for p in vision_tower.parameters())
|
221 |
+
print("the number of vision encoder params: {}M".format(num_params / 1024 / 1024))
|
222 |
+
|
223 |
+
if getattr(model.config, "unfreeze_vision_tower", False):
|
224 |
+
if "lora" in model_name.lower():
|
225 |
+
assert model_base is not None
|
226 |
+
vision_non_lora_trainables = {
|
227 |
+
k[19:]: v
|
228 |
+
for k, v in non_lora_trainables.items()
|
229 |
+
if k.startswith("model.vision_tower.")
|
230 |
+
}
|
231 |
+
vision_tower.load_state_dict(vision_non_lora_trainables, strict=False)
|
232 |
+
else:
|
233 |
+
assert model_base is None
|
234 |
+
from safetensors.torch import load_file
|
235 |
+
|
236 |
+
vision_weights = {}
|
237 |
+
for file_name in os.listdir(model_path):
|
238 |
+
if file_name.endswith("safetensors"):
|
239 |
+
vision_weights.update(
|
240 |
+
{
|
241 |
+
k[19:]: v
|
242 |
+
for k, v in load_file(os.path.join(model_path, file_name)).items()
|
243 |
+
if k.startswith("model.vision_tower.")
|
244 |
+
}
|
245 |
+
)
|
246 |
+
vision_tower.load_state_dict(vision_weights, strict=True)
|
247 |
+
|
248 |
+
# import pdb; pdb.set_trace()
|
249 |
+
# if (not getattr(model.config, "freeze_audio_encoder", True)) and (not getattr(model.config, "freeze_audio_encoder_adapter", True)):
|
250 |
+
# from safetensors.torch import load_file
|
251 |
+
# audio_weights = {}
|
252 |
+
# for file_name in os.listdir(model_path):
|
253 |
+
# if file_name.endswith('safetensors'):
|
254 |
+
# audio_weights.update(
|
255 |
+
# {k[20:]: v for k, v in load_file(os.path.join(model_path, file_name)).items() if
|
256 |
+
# k.startswith('model.audio_encoder.')})
|
257 |
+
# audio_encoder.load_state_dict(audio_weights, strict=True)
|
258 |
+
# audio_encoder.eval()
|
259 |
+
# import pdb; pdb.set_trace()
|
260 |
+
|
261 |
+
# import pdb; pdb.set_trace()
|
262 |
+
# from safetensors.torch import load_file
|
263 |
+
# audio_weights = {}
|
264 |
+
# for file_name in os.listdir(model_path):
|
265 |
+
# if file_name.endswith('safetensors'):
|
266 |
+
# audio_weights.update(
|
267 |
+
# {k[20:]: v for k, v in load_file(os.path.join(model_path, file_name)).items() if
|
268 |
+
# k.startswith('model.audio_encoder.')})
|
269 |
+
# import pdb; pdb.set_trace()
|
270 |
+
|
271 |
+
vision_tower.to(dtype=torch.float16)
|
272 |
+
image_processor = vision_tower.image_processor
|
273 |
+
|
274 |
+
#import pdb; pdb.set_trace()
|
275 |
+
if hasattr(model.config, "max_sequence_length"):
|
276 |
+
context_len = model.config.max_sequence_length
|
277 |
+
else:
|
278 |
+
context_len = 2048
|
279 |
+
|
280 |
+
if model.generation_config.pad_token_id is None:
|
281 |
+
model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
282 |
+
|
283 |
+
if model_type == "phi-3":
|
284 |
+
model.generation_config.eos_token_id = tokenizer.eos_token_id
|
285 |
+
|
286 |
+
return tokenizer, model, image_processor, context_len
|
287 |
+
|
vita/model/language_model/vita_fo_qwen2.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from PIL import Image
|
6 |
+
from torch.nn import CrossEntropyLoss
|
7 |
+
from transformers import (
|
8 |
+
AutoConfig,
|
9 |
+
AutoModelForCausalLM,
|
10 |
+
Qwen2Config,
|
11 |
+
Qwen2ForCausalLM,
|
12 |
+
Qwen2Model,
|
13 |
+
)
|
14 |
+
from transformers.cache_utils import Cache, DynamicCache
|
15 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast, MoeCausalLMOutputWithPast
|
16 |
+
from transformers.generation.utils import GenerateOutput
|
17 |
+
|
18 |
+
from ..vita_arch import VITAMetaForCausalLM, VITAMetaModel
|
19 |
+
from ...constants import IGNORE_INDEX
|
20 |
+
from .vita_qwen2 import custom_forward
|
21 |
+
|
22 |
+
|
23 |
+
Qwen2ForCausalLM.forward = custom_forward
|
24 |
+
|
25 |
+
|
26 |
+
class VITAFOQwen2Config(Qwen2Config):
|
27 |
+
model_type = "vita-fo-Qwen2"
|
28 |
+
|
29 |
+
|
30 |
+
class VITAFOQwen2Model(VITAMetaModel, Qwen2Model):
|
31 |
+
config_class = VITAFOQwen2Config
|
32 |
+
|
33 |
+
def __init__(self, config: Qwen2Config):
|
34 |
+
super(VITAFOQwen2Model, self).__init__(config)
|
35 |
+
|
36 |
+
|
37 |
+
class VITAFOQwen2ForCausalLM(Qwen2ForCausalLM, VITAMetaForCausalLM):
|
38 |
+
config_class = VITAFOQwen2Config
|
39 |
+
|
40 |
+
def __init__(self, config):
|
41 |
+
super(Qwen2ForCausalLM, self).__init__(config)
|
42 |
+
self.model = VITAFOQwen2Model(config)
|
43 |
+
self.vocab_size = config.vocab_size
|
44 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
45 |
+
self.predict_usr_state = 0#2
|
46 |
+
if self.predict_usr_state:
|
47 |
+
self.predictor_head = torch.nn.Linear(config.hidden_size, self.predict_usr_state + 1) # +1 for the dummy class
|
48 |
+
else:
|
49 |
+
self.predictor_head = None
|
50 |
+
# Initialize weights and apply final processing
|
51 |
+
self.post_init()
|
52 |
+
|
53 |
+
def get_model(self):
|
54 |
+
return self.model
|
55 |
+
|
56 |
+
def forward(
|
57 |
+
self,
|
58 |
+
input_ids: torch.LongTensor = None,
|
59 |
+
attention_mask: Optional[torch.Tensor] = None,
|
60 |
+
position_ids: Optional[torch.LongTensor] = None,
|
61 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
62 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
63 |
+
labels: Optional[torch.LongTensor] = None,
|
64 |
+
use_cache: Optional[bool] = None,
|
65 |
+
output_attentions: Optional[bool] = None,
|
66 |
+
output_hidden_states: Optional[bool] = None,
|
67 |
+
images: Optional[torch.FloatTensor] = None,
|
68 |
+
audios: Optional[dict] = None,
|
69 |
+
sf_masks: Optional[torch.Tensor] = None,
|
70 |
+
return_dict: Optional[bool] = None,
|
71 |
+
cache_position: Optional[torch.LongTensor] = None,
|
72 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
73 |
+
if inputs_embeds is None:
|
74 |
+
(
|
75 |
+
input_ids,
|
76 |
+
position_ids,
|
77 |
+
attention_mask,
|
78 |
+
past_key_values,
|
79 |
+
inputs_embeds,
|
80 |
+
labels,
|
81 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
82 |
+
input_ids, position_ids, attention_mask, past_key_values, labels, images, audios, sf_masks
|
83 |
+
)
|
84 |
+
if labels is not None:
|
85 |
+
state_labels = labels
|
86 |
+
labels = torch.where(labels>=0, labels, IGNORE_INDEX)
|
87 |
+
output_hidden_states = True
|
88 |
+
outputs = super().forward(
|
89 |
+
input_ids=input_ids,
|
90 |
+
attention_mask=attention_mask,
|
91 |
+
position_ids=position_ids,
|
92 |
+
past_key_values=past_key_values,
|
93 |
+
inputs_embeds=inputs_embeds,
|
94 |
+
labels=labels,
|
95 |
+
use_cache=use_cache,
|
96 |
+
output_attentions=output_attentions,
|
97 |
+
output_hidden_states=output_hidden_states,
|
98 |
+
return_dict=return_dict,
|
99 |
+
cache_position=cache_position,
|
100 |
+
)
|
101 |
+
# state loss
|
102 |
+
if self.predictor_head is not None:
|
103 |
+
state_logits = self.predictor_head(outputs[2][-1]).view(-1, self.predict_usr_state+1) # +1 for the dummy class
|
104 |
+
if labels is not None:
|
105 |
+
loss = outputs[0]
|
106 |
+
weight = torch.Tensor([1, 5, 1]).to(torch.bfloat16).to(inputs_embeds.device)
|
107 |
+
loss_fct = torch.nn.CrossEntropyLoss(weight=weight)
|
108 |
+
s_labels= torch.where(
|
109 |
+
state_labels < IGNORE_INDEX,
|
110 |
+
IGNORE_INDEX-state_labels-1,
|
111 |
+
IGNORE_INDEX).view(-1)
|
112 |
+
#assert all(label in [0, 1, IGNORE_INDEX] for label in s_labels), "s_labels must contain only 0, 1, or -100"
|
113 |
+
state_loss = loss_fct(state_logits, s_labels)
|
114 |
+
loss = loss + state_loss
|
115 |
+
outputs['loss'] = loss
|
116 |
+
return outputs
|
117 |
+
|
118 |
+
@torch.no_grad()
|
119 |
+
def generate(
|
120 |
+
self,
|
121 |
+
inputs: Optional[torch.Tensor] = None,
|
122 |
+
images: Optional[torch.Tensor] = None,
|
123 |
+
audios: Optional[torch.Tensor] = None,
|
124 |
+
sf_masks: Optional[torch.Tensor] = None,
|
125 |
+
**kwargs,
|
126 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
127 |
+
position_ids = kwargs.pop("position_ids", None)
|
128 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
129 |
+
if "inputs_embeds" in kwargs:
|
130 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
131 |
+
|
132 |
+
if images is not None or audios is not None:
|
133 |
+
(
|
134 |
+
inputs,
|
135 |
+
position_ids,
|
136 |
+
attention_mask,
|
137 |
+
_,
|
138 |
+
inputs_embeds,
|
139 |
+
_
|
140 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
141 |
+
inputs,
|
142 |
+
position_ids,
|
143 |
+
attention_mask,
|
144 |
+
None,
|
145 |
+
None,
|
146 |
+
images,
|
147 |
+
audios,
|
148 |
+
sf_masks,
|
149 |
+
)
|
150 |
+
else:
|
151 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
152 |
+
|
153 |
+
return super().generate(
|
154 |
+
position_ids=position_ids,
|
155 |
+
attention_mask=attention_mask,
|
156 |
+
inputs_embeds=inputs_embeds,
|
157 |
+
**kwargs
|
158 |
+
)
|
159 |
+
|
160 |
+
def prepare_inputs_for_generation(
|
161 |
+
self,
|
162 |
+
input_ids,
|
163 |
+
past_key_values=None,
|
164 |
+
inputs_embeds=None,
|
165 |
+
attention_mask=None,
|
166 |
+
**kwargs,
|
167 |
+
):
|
168 |
+
images = kwargs.pop("images", None)
|
169 |
+
audios = kwargs.pop("audios", None)
|
170 |
+
sf_masks = kwargs.pop("sf_masks", None)
|
171 |
+
|
172 |
+
_inputs = super().prepare_inputs_for_generation(
|
173 |
+
input_ids,
|
174 |
+
past_key_values=past_key_values,
|
175 |
+
inputs_embeds=inputs_embeds,
|
176 |
+
attention_mask=attention_mask,
|
177 |
+
**kwargs,
|
178 |
+
)
|
179 |
+
|
180 |
+
if images is not None:
|
181 |
+
_inputs["images"] = images
|
182 |
+
if audios is not None:
|
183 |
+
_inputs["audios"] = audios
|
184 |
+
if sf_masks is not None:
|
185 |
+
_inputs["sf_masks"] = sf_masks
|
186 |
+
return _inputs
|
187 |
+
|
188 |
+
def expand2square(self, pil_img, background_color):
|
189 |
+
width, height = pil_img.size
|
190 |
+
if width == height:
|
191 |
+
return pil_img
|
192 |
+
elif width > height:
|
193 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
194 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
195 |
+
return result
|
196 |
+
else:
|
197 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
198 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
199 |
+
return result
|
200 |
+
|
201 |
+
def process_images(self, images, model_cfg):
|
202 |
+
vision_tower = self.get_vision_tower()
|
203 |
+
if not vision_tower.is_loaded:
|
204 |
+
vision_tower.load_model()
|
205 |
+
image_processor = vision_tower.image_processor
|
206 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
207 |
+
new_images = []
|
208 |
+
if image_aspect_ratio == "pad":
|
209 |
+
for image in images:
|
210 |
+
image = self.expand2square(
|
211 |
+
image, tuple(int(x * 255) for x in image_processor.image_mean)
|
212 |
+
)
|
213 |
+
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
214 |
+
new_images.append(image)
|
215 |
+
else:
|
216 |
+
return image_processor(images, return_tensors="pt")["pixel_values"]
|
217 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
218 |
+
new_images = torch.stack(new_images, dim=0)
|
219 |
+
return new_images
|
220 |
+
|
221 |
+
|
222 |
+
AutoConfig.register("vita-fo-Qwen2", VITAFOQwen2Config)
|
223 |
+
AutoModelForCausalLM.register(VITAFOQwen2Config, VITAFOQwen2ForCausalLM)
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
|
vita/model/language_model/vita_mixtral.py
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from PIL import Image
|
6 |
+
from torch.nn import CrossEntropyLoss
|
7 |
+
from transformers import (
|
8 |
+
AutoConfig,
|
9 |
+
AutoModelForCausalLM,
|
10 |
+
MixtralConfig,
|
11 |
+
MixtralForCausalLM,
|
12 |
+
MixtralModel,
|
13 |
+
)
|
14 |
+
from transformers.cache_utils import Cache, DynamicCache
|
15 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast, MoeCausalLMOutputWithPast
|
16 |
+
|
17 |
+
from ..vita_arch import VITAMetaForCausalLM, VITAMetaModel
|
18 |
+
|
19 |
+
|
20 |
+
def load_balancing_loss_func(
|
21 |
+
gate_logits: torch.Tensor,
|
22 |
+
num_experts: torch.Tensor = None,
|
23 |
+
top_k=2,
|
24 |
+
attention_mask: Optional[torch.Tensor] = None,
|
25 |
+
) -> float:
|
26 |
+
r"""
|
27 |
+
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
|
28 |
+
|
29 |
+
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
|
30 |
+
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
|
31 |
+
experts is too unbalanced.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
|
35 |
+
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
|
36 |
+
shape [batch_size X sequence_length, num_experts].
|
37 |
+
attention_mask (`torch.Tensor`, None):
|
38 |
+
The attention_mask used in forward function
|
39 |
+
shape [batch_size X sequence_length] if not None.
|
40 |
+
num_experts (`int`, *optional*):
|
41 |
+
Number of experts
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
The auxiliary loss.
|
45 |
+
"""
|
46 |
+
if gate_logits is None or not isinstance(gate_logits, tuple):
|
47 |
+
return 0
|
48 |
+
|
49 |
+
if isinstance(gate_logits, tuple):
|
50 |
+
compute_device = gate_logits[0].device
|
51 |
+
concatenated_gate_logits = torch.cat(
|
52 |
+
[layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
|
53 |
+
)
|
54 |
+
|
55 |
+
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
|
56 |
+
|
57 |
+
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
58 |
+
|
59 |
+
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
|
60 |
+
|
61 |
+
if attention_mask is None:
|
62 |
+
# Compute the percentage of tokens routed to each experts
|
63 |
+
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
|
64 |
+
|
65 |
+
# Compute the average probability of routing to these experts
|
66 |
+
router_prob_per_expert = torch.mean(routing_weights, dim=0)
|
67 |
+
else:
|
68 |
+
batch_size, sequence_length = attention_mask.shape
|
69 |
+
num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
|
70 |
+
|
71 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
|
72 |
+
expert_attention_mask = (
|
73 |
+
attention_mask[None, :, :, None, None]
|
74 |
+
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
|
75 |
+
.reshape(-1, top_k, num_experts)
|
76 |
+
.to(compute_device)
|
77 |
+
)
|
78 |
+
|
79 |
+
# Compute the percentage of tokens routed to each experts
|
80 |
+
tokens_per_expert = torch.sum(
|
81 |
+
expert_mask.float() * expert_attention_mask, dim=0
|
82 |
+
) / torch.sum(expert_attention_mask, dim=0)
|
83 |
+
|
84 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
|
85 |
+
router_per_expert_attention_mask = (
|
86 |
+
attention_mask[None, :, :, None]
|
87 |
+
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
|
88 |
+
.reshape(-1, num_experts)
|
89 |
+
.to(compute_device)
|
90 |
+
)
|
91 |
+
|
92 |
+
# Compute the average probability of routing to these experts
|
93 |
+
router_prob_per_expert = torch.sum(
|
94 |
+
routing_weights * router_per_expert_attention_mask, dim=0
|
95 |
+
) / torch.sum(router_per_expert_attention_mask, dim=0)
|
96 |
+
|
97 |
+
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
|
98 |
+
return overall_loss * num_experts
|
99 |
+
|
100 |
+
|
101 |
+
def custom_forward(
|
102 |
+
self,
|
103 |
+
input_ids: torch.LongTensor = None,
|
104 |
+
attention_mask: Optional[torch.Tensor] = None,
|
105 |
+
position_ids: Optional[torch.LongTensor] = None,
|
106 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
107 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
108 |
+
labels: Optional[torch.LongTensor] = None,
|
109 |
+
use_cache: Optional[bool] = None,
|
110 |
+
output_attentions: Optional[bool] = None,
|
111 |
+
output_hidden_states: Optional[bool] = None,
|
112 |
+
output_router_logits: Optional[bool] = None,
|
113 |
+
return_dict: Optional[bool] = None,
|
114 |
+
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
|
115 |
+
r"""
|
116 |
+
Args:
|
117 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
118 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
119 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
120 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
|
124 |
+
Example:
|
125 |
+
|
126 |
+
```python
|
127 |
+
>>> from transformers import AutoTokenizer, MixtralForCausalLM
|
128 |
+
|
129 |
+
>>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
|
130 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
|
131 |
+
|
132 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
133 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
134 |
+
|
135 |
+
>>> # Generate
|
136 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
137 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
138 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
139 |
+
```"""
|
140 |
+
|
141 |
+
output_attentions = (
|
142 |
+
output_attentions if output_attentions is not None else self.config.output_attentions
|
143 |
+
)
|
144 |
+
output_router_logits = (
|
145 |
+
output_router_logits
|
146 |
+
if output_router_logits is not None
|
147 |
+
else self.config.output_router_logits
|
148 |
+
)
|
149 |
+
|
150 |
+
output_hidden_states = (
|
151 |
+
output_hidden_states
|
152 |
+
if output_hidden_states is not None
|
153 |
+
else self.config.output_hidden_states
|
154 |
+
)
|
155 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
156 |
+
|
157 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
158 |
+
outputs = self.model(
|
159 |
+
input_ids=input_ids,
|
160 |
+
attention_mask=attention_mask,
|
161 |
+
position_ids=position_ids,
|
162 |
+
past_key_values=past_key_values,
|
163 |
+
inputs_embeds=inputs_embeds,
|
164 |
+
use_cache=use_cache,
|
165 |
+
output_attentions=output_attentions,
|
166 |
+
output_hidden_states=output_hidden_states,
|
167 |
+
output_router_logits=output_router_logits,
|
168 |
+
return_dict=return_dict,
|
169 |
+
)
|
170 |
+
|
171 |
+
hidden_states = outputs[0]
|
172 |
+
logits = self.lm_head(hidden_states)
|
173 |
+
# logits = logits.float()
|
174 |
+
|
175 |
+
loss = None
|
176 |
+
if labels is not None:
|
177 |
+
# Shift so that tokens < n predict n
|
178 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
179 |
+
shift_labels = labels[..., 1:].contiguous()
|
180 |
+
# Flatten the tokens
|
181 |
+
loss_fct = CrossEntropyLoss()
|
182 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
183 |
+
shift_labels = shift_labels.view(-1)
|
184 |
+
# Enable model parallelism
|
185 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
186 |
+
loss = loss_fct(shift_logits, shift_labels)
|
187 |
+
|
188 |
+
aux_loss = None
|
189 |
+
if output_router_logits:
|
190 |
+
aux_loss = load_balancing_loss_func(
|
191 |
+
outputs.router_logits if return_dict else outputs[-1],
|
192 |
+
self.num_experts,
|
193 |
+
self.num_experts_per_tok,
|
194 |
+
attention_mask,
|
195 |
+
)
|
196 |
+
if labels is not None:
|
197 |
+
loss += self.router_aux_loss_coef * aux_loss.to(
|
198 |
+
loss.device
|
199 |
+
) # make sure to reside in the same device
|
200 |
+
|
201 |
+
if not return_dict:
|
202 |
+
output = (logits,) + outputs[1:]
|
203 |
+
if output_router_logits:
|
204 |
+
output = (aux_loss,) + output
|
205 |
+
return (loss,) + output if loss is not None else output
|
206 |
+
|
207 |
+
return MoeCausalLMOutputWithPast(
|
208 |
+
loss=loss,
|
209 |
+
aux_loss=aux_loss,
|
210 |
+
logits=logits,
|
211 |
+
past_key_values=outputs.past_key_values,
|
212 |
+
hidden_states=outputs.hidden_states,
|
213 |
+
attentions=outputs.attentions,
|
214 |
+
router_logits=outputs.router_logits,
|
215 |
+
)
|
216 |
+
|
217 |
+
|
218 |
+
MixtralForCausalLM.forward = custom_forward
|
219 |
+
|
220 |
+
|
221 |
+
class VITAMixtralConfig(MixtralConfig):
|
222 |
+
model_type = "vita-mixtral"
|
223 |
+
|
224 |
+
|
225 |
+
class VITAMixtralModel(VITAMetaModel, MixtralModel):
|
226 |
+
config_class = VITAMixtralConfig
|
227 |
+
|
228 |
+
def __init__(self, config: MixtralConfig):
|
229 |
+
super(VITAMixtralModel, self).__init__(config)
|
230 |
+
|
231 |
+
|
232 |
+
class VITAMixtralForCausalLM(MixtralForCausalLM, VITAMetaForCausalLM):
|
233 |
+
config_class = VITAMixtralConfig
|
234 |
+
|
235 |
+
def __init__(self, config):
|
236 |
+
super(MixtralForCausalLM, self).__init__(config)
|
237 |
+
self.model = VITAMixtralModel(config)
|
238 |
+
self.vocab_size = config.vocab_size
|
239 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
240 |
+
self.router_aux_loss_coef = config.router_aux_loss_coef
|
241 |
+
self.num_experts = config.num_local_experts
|
242 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
243 |
+
# Initialize weights and apply final processing
|
244 |
+
self.post_init()
|
245 |
+
|
246 |
+
def get_model(self):
|
247 |
+
return self.model
|
248 |
+
|
249 |
+
def forward(
|
250 |
+
self,
|
251 |
+
input_ids: torch.LongTensor = None,
|
252 |
+
attention_mask: Optional[torch.Tensor] = None,
|
253 |
+
position_ids: Optional[torch.LongTensor] = None,
|
254 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
255 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
256 |
+
labels: Optional[torch.LongTensor] = None,
|
257 |
+
use_cache: Optional[bool] = None,
|
258 |
+
output_attentions: Optional[bool] = None,
|
259 |
+
output_hidden_states: Optional[bool] = None,
|
260 |
+
images: Optional[torch.FloatTensor] = None,
|
261 |
+
audios: Optional[dict] = None,
|
262 |
+
sf_masks: Optional[torch.Tensor] = None,
|
263 |
+
output_router_logits: Optional[bool] = None,
|
264 |
+
return_dict: Optional[bool] = None,
|
265 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
266 |
+
if inputs_embeds is None:
|
267 |
+
(
|
268 |
+
input_ids,
|
269 |
+
position_ids,
|
270 |
+
attention_mask,
|
271 |
+
past_key_values,
|
272 |
+
inputs_embeds,
|
273 |
+
labels,
|
274 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
275 |
+
input_ids, position_ids, attention_mask, past_key_values, labels, images, audios, sf_masks
|
276 |
+
)
|
277 |
+
|
278 |
+
return super().forward(
|
279 |
+
input_ids=input_ids,
|
280 |
+
attention_mask=attention_mask,
|
281 |
+
position_ids=position_ids,
|
282 |
+
past_key_values=past_key_values,
|
283 |
+
inputs_embeds=inputs_embeds,
|
284 |
+
labels=labels,
|
285 |
+
use_cache=use_cache,
|
286 |
+
output_attentions=output_attentions,
|
287 |
+
output_hidden_states=output_hidden_states,
|
288 |
+
output_router_logits=output_router_logits,
|
289 |
+
return_dict=return_dict,
|
290 |
+
)
|
291 |
+
|
292 |
+
def prepare_inputs_for_generation_original(
|
293 |
+
self,
|
294 |
+
input_ids,
|
295 |
+
past_key_values=None,
|
296 |
+
attention_mask=None,
|
297 |
+
inputs_embeds=None,
|
298 |
+
output_router_logits=False,
|
299 |
+
**kwargs,
|
300 |
+
):
|
301 |
+
# Omit tokens covered by past_key_values
|
302 |
+
if past_key_values is not None:
|
303 |
+
if isinstance(past_key_values, Cache):
|
304 |
+
cache_length = past_key_values.get_seq_length()
|
305 |
+
past_length = past_key_values.seen_tokens
|
306 |
+
max_cache_length = past_key_values.get_max_length()
|
307 |
+
else:
|
308 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
309 |
+
max_cache_length = None
|
310 |
+
|
311 |
+
# Keep only the unprocessed tokens:
|
312 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
313 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
314 |
+
# input)
|
315 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
316 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
317 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
318 |
+
# input_ids based on the past_length.
|
319 |
+
elif past_length < input_ids.shape[1]:
|
320 |
+
input_ids = input_ids[:, past_length:]
|
321 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
322 |
+
else:
|
323 |
+
remove_prefix_length = input_ids.shape[1] - 1
|
324 |
+
input_ids = input_ids[:, remove_prefix_length:]
|
325 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
326 |
+
if (
|
327 |
+
max_cache_length is not None
|
328 |
+
and attention_mask is not None
|
329 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
330 |
+
):
|
331 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
332 |
+
|
333 |
+
position_ids = kwargs.get("position_ids", None)
|
334 |
+
if attention_mask is not None and position_ids is None:
|
335 |
+
# create position_ids on the fly for batch generation
|
336 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
337 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
338 |
+
if past_key_values:
|
339 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
340 |
+
|
341 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
342 |
+
if inputs_embeds is not None and past_key_values is None:
|
343 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
344 |
+
else:
|
345 |
+
model_inputs = {"input_ids": input_ids}
|
346 |
+
|
347 |
+
model_inputs.update(
|
348 |
+
{
|
349 |
+
"position_ids": position_ids,
|
350 |
+
"past_key_values": past_key_values,
|
351 |
+
"use_cache": kwargs.get("use_cache"),
|
352 |
+
"attention_mask": attention_mask,
|
353 |
+
"output_router_logits": output_router_logits,
|
354 |
+
}
|
355 |
+
)
|
356 |
+
return model_inputs
|
357 |
+
|
358 |
+
def prepare_inputs_for_generation(
|
359 |
+
self,
|
360 |
+
input_ids,
|
361 |
+
past_key_values=None,
|
362 |
+
inputs_embeds=None,
|
363 |
+
attention_mask=None,
|
364 |
+
output_router_logits=False,
|
365 |
+
**kwargs,
|
366 |
+
):
|
367 |
+
images = kwargs.pop("images", None)
|
368 |
+
audios = kwargs.pop("audios", None)
|
369 |
+
|
370 |
+
_inputs = self.prepare_inputs_for_generation_original(
|
371 |
+
input_ids,
|
372 |
+
past_key_values=past_key_values,
|
373 |
+
inputs_embeds=inputs_embeds,
|
374 |
+
attention_mask=attention_mask,
|
375 |
+
output_router_logits=output_router_logits,
|
376 |
+
**kwargs,
|
377 |
+
)
|
378 |
+
|
379 |
+
if images is not None:
|
380 |
+
_inputs["images"] = images
|
381 |
+
if audios is not None:
|
382 |
+
_inputs["audios"] = audios
|
383 |
+
return _inputs
|
384 |
+
|
385 |
+
def expand2square(self, pil_img, background_color):
|
386 |
+
width, height = pil_img.size
|
387 |
+
if width == height:
|
388 |
+
return pil_img
|
389 |
+
elif width > height:
|
390 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
391 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
392 |
+
return result
|
393 |
+
else:
|
394 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
395 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
396 |
+
return result
|
397 |
+
|
398 |
+
def process_images(self, images, model_cfg):
|
399 |
+
vision_tower = self.get_vision_tower()
|
400 |
+
if not vision_tower.is_loaded:
|
401 |
+
vision_tower.load_model()
|
402 |
+
image_processor = vision_tower.image_processor
|
403 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
404 |
+
new_images = []
|
405 |
+
if image_aspect_ratio == "pad":
|
406 |
+
for image in images:
|
407 |
+
image = self.expand2square(
|
408 |
+
image, tuple(int(x * 255) for x in image_processor.image_mean)
|
409 |
+
)
|
410 |
+
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
411 |
+
new_images.append(image)
|
412 |
+
else:
|
413 |
+
return image_processor(images, return_tensors="pt")["pixel_values"]
|
414 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
415 |
+
new_images = torch.stack(new_images, dim=0)
|
416 |
+
return new_images
|
417 |
+
|
418 |
+
|
419 |
+
AutoConfig.register("vita-mixtral", VITAMixtralConfig)
|
420 |
+
AutoModelForCausalLM.register(VITAMixtralConfig, VITAMixtralForCausalLM)
|
vita/model/language_model/vita_nemo.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from PIL import Image
|
6 |
+
from torch.nn import CrossEntropyLoss
|
7 |
+
from transformers import (
|
8 |
+
AutoConfig,
|
9 |
+
AutoModelForCausalLM,
|
10 |
+
MistralConfig,
|
11 |
+
MistralForCausalLM,
|
12 |
+
MistralModel,
|
13 |
+
)
|
14 |
+
from transformers.cache_utils import Cache, DynamicCache
|
15 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast, MoeCausalLMOutputWithPast
|
16 |
+
from transformers.generation.utils import GenerateOutput
|
17 |
+
|
18 |
+
from ..vita_arch import VITAMetaForCausalLM, VITAMetaModel
|
19 |
+
|
20 |
+
|
21 |
+
def custom_forward(
|
22 |
+
self,
|
23 |
+
input_ids: torch.LongTensor = None,
|
24 |
+
attention_mask: Optional[torch.Tensor] = None,
|
25 |
+
position_ids: Optional[torch.LongTensor] = None,
|
26 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
27 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
28 |
+
labels: Optional[torch.LongTensor] = None,
|
29 |
+
use_cache: Optional[bool] = None,
|
30 |
+
output_attentions: Optional[bool] = None,
|
31 |
+
output_hidden_states: Optional[bool] = None,
|
32 |
+
return_dict: Optional[bool] = None,
|
33 |
+
cache_position: Optional[torch.LongTensor] = None,
|
34 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
35 |
+
r"""
|
36 |
+
Args:
|
37 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
38 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
39 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
40 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
|
44 |
+
Example:
|
45 |
+
|
46 |
+
```python
|
47 |
+
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
48 |
+
|
49 |
+
>>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
|
50 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
51 |
+
|
52 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
53 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
54 |
+
|
55 |
+
>>> # Generate
|
56 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
57 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
58 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
59 |
+
```"""
|
60 |
+
|
61 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
62 |
+
output_hidden_states = (
|
63 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
64 |
+
)
|
65 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
66 |
+
|
67 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
68 |
+
outputs = self.model(
|
69 |
+
input_ids=input_ids,
|
70 |
+
attention_mask=attention_mask,
|
71 |
+
position_ids=position_ids,
|
72 |
+
past_key_values=past_key_values,
|
73 |
+
inputs_embeds=inputs_embeds,
|
74 |
+
use_cache=use_cache,
|
75 |
+
output_attentions=output_attentions,
|
76 |
+
output_hidden_states=output_hidden_states,
|
77 |
+
return_dict=return_dict,
|
78 |
+
cache_position=cache_position,
|
79 |
+
)
|
80 |
+
|
81 |
+
hidden_states = outputs[0]
|
82 |
+
logits = self.lm_head(hidden_states)
|
83 |
+
# logits = logits.float()
|
84 |
+
|
85 |
+
loss = None
|
86 |
+
if labels is not None:
|
87 |
+
# Shift so that tokens < n predict n
|
88 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
89 |
+
shift_labels = labels[..., 1:].contiguous()
|
90 |
+
# Flatten the tokens
|
91 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
92 |
+
shift_labels = shift_labels.view(-1)
|
93 |
+
# Ensure tensors are on the same device
|
94 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
95 |
+
loss_fct = CrossEntropyLoss()
|
96 |
+
loss = loss_fct(shift_logits, shift_labels)
|
97 |
+
|
98 |
+
if not return_dict:
|
99 |
+
output = (logits,) + outputs[1:]
|
100 |
+
return (loss,) + output if loss is not None else output
|
101 |
+
|
102 |
+
return CausalLMOutputWithPast(
|
103 |
+
loss=loss,
|
104 |
+
logits=logits,
|
105 |
+
past_key_values=outputs.past_key_values,
|
106 |
+
hidden_states=outputs.hidden_states,
|
107 |
+
attentions=outputs.attentions,
|
108 |
+
)
|
109 |
+
|
110 |
+
MistralForCausalLM.forward = custom_forward
|
111 |
+
|
112 |
+
|
113 |
+
class VITAMistralConfig(MistralConfig):
|
114 |
+
model_type = "vita-Mistral"
|
115 |
+
|
116 |
+
|
117 |
+
class VITAMistralModel(VITAMetaModel, MistralModel):
|
118 |
+
config_class = VITAMistralConfig
|
119 |
+
|
120 |
+
def __init__(self, config: MistralConfig):
|
121 |
+
super(VITAMistralModel, self).__init__(config)
|
122 |
+
|
123 |
+
|
124 |
+
class VITAMistralForCausalLM(MistralForCausalLM, VITAMetaForCausalLM):
|
125 |
+
config_class = VITAMistralConfig
|
126 |
+
|
127 |
+
def __init__(self, config):
|
128 |
+
super(MistralForCausalLM, self).__init__(config)
|
129 |
+
self.model = VITAMistralModel(config)
|
130 |
+
self.vocab_size = config.vocab_size
|
131 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
132 |
+
|
133 |
+
# Initialize weights and apply final processing
|
134 |
+
self.post_init()
|
135 |
+
|
136 |
+
def get_model(self):
|
137 |
+
return self.model
|
138 |
+
|
139 |
+
def forward(
|
140 |
+
self,
|
141 |
+
input_ids: torch.LongTensor = None,
|
142 |
+
attention_mask: Optional[torch.Tensor] = None,
|
143 |
+
position_ids: Optional[torch.LongTensor] = None,
|
144 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
145 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
146 |
+
labels: Optional[torch.LongTensor] = None,
|
147 |
+
use_cache: Optional[bool] = None,
|
148 |
+
output_attentions: Optional[bool] = None,
|
149 |
+
output_hidden_states: Optional[bool] = None,
|
150 |
+
images: Optional[torch.FloatTensor] = None,
|
151 |
+
audios: Optional[dict] = None,
|
152 |
+
return_dict: Optional[bool] = None,
|
153 |
+
cache_position: Optional[torch.LongTensor] = None,
|
154 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
155 |
+
if inputs_embeds is None:
|
156 |
+
(
|
157 |
+
input_ids,
|
158 |
+
position_ids,
|
159 |
+
attention_mask,
|
160 |
+
past_key_values,
|
161 |
+
inputs_embeds,
|
162 |
+
labels,
|
163 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
164 |
+
input_ids, position_ids, attention_mask, past_key_values, labels, images, audios
|
165 |
+
)
|
166 |
+
|
167 |
+
return super().forward(
|
168 |
+
input_ids=input_ids,
|
169 |
+
attention_mask=attention_mask,
|
170 |
+
position_ids=position_ids,
|
171 |
+
past_key_values=past_key_values,
|
172 |
+
inputs_embeds=inputs_embeds,
|
173 |
+
labels=labels,
|
174 |
+
use_cache=use_cache,
|
175 |
+
output_attentions=output_attentions,
|
176 |
+
output_hidden_states=output_hidden_states,
|
177 |
+
return_dict=return_dict,
|
178 |
+
cache_position=cache_position,
|
179 |
+
)
|
180 |
+
|
181 |
+
@torch.no_grad()
|
182 |
+
def generate(
|
183 |
+
self,
|
184 |
+
inputs: Optional[torch.Tensor] = None,
|
185 |
+
images: Optional[torch.Tensor] = None,
|
186 |
+
audios: Optional[torch.Tensor] = None,
|
187 |
+
**kwargs,
|
188 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
189 |
+
position_ids = kwargs.pop("position_ids", None)
|
190 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
191 |
+
if "inputs_embeds" in kwargs:
|
192 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
193 |
+
|
194 |
+
if images is not None or audios is not None:
|
195 |
+
(
|
196 |
+
inputs,
|
197 |
+
position_ids,
|
198 |
+
attention_mask,
|
199 |
+
_,
|
200 |
+
inputs_embeds,
|
201 |
+
_
|
202 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
203 |
+
inputs,
|
204 |
+
position_ids,
|
205 |
+
attention_mask,
|
206 |
+
None,
|
207 |
+
None,
|
208 |
+
images,
|
209 |
+
audios
|
210 |
+
)
|
211 |
+
else:
|
212 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
213 |
+
|
214 |
+
return super().generate(
|
215 |
+
position_ids=position_ids,
|
216 |
+
attention_mask=attention_mask,
|
217 |
+
inputs_embeds=inputs_embeds,
|
218 |
+
**kwargs
|
219 |
+
)
|
220 |
+
|
221 |
+
def prepare_inputs_for_generation(
|
222 |
+
self,
|
223 |
+
input_ids,
|
224 |
+
past_key_values=None,
|
225 |
+
inputs_embeds=None,
|
226 |
+
attention_mask=None,
|
227 |
+
**kwargs,
|
228 |
+
):
|
229 |
+
images = kwargs.pop("images", None)
|
230 |
+
audios = kwargs.pop("audios", None)
|
231 |
+
|
232 |
+
_inputs = super().prepare_inputs_for_generation(
|
233 |
+
input_ids,
|
234 |
+
past_key_values=past_key_values,
|
235 |
+
inputs_embeds=inputs_embeds,
|
236 |
+
attention_mask=attention_mask,
|
237 |
+
**kwargs,
|
238 |
+
)
|
239 |
+
|
240 |
+
if images is not None:
|
241 |
+
_inputs["images"] = images
|
242 |
+
if audios is not None:
|
243 |
+
_inputs["audios"] = audios
|
244 |
+
return _inputs
|
245 |
+
|
246 |
+
def expand2square(self, pil_img, background_color):
|
247 |
+
width, height = pil_img.size
|
248 |
+
if width == height:
|
249 |
+
return pil_img
|
250 |
+
elif width > height:
|
251 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
252 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
253 |
+
return result
|
254 |
+
else:
|
255 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
256 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
257 |
+
return result
|
258 |
+
|
259 |
+
def process_images(self, images, model_cfg):
|
260 |
+
vision_tower = self.get_vision_tower()
|
261 |
+
if not vision_tower.is_loaded:
|
262 |
+
vision_tower.load_model()
|
263 |
+
image_processor = vision_tower.image_processor
|
264 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
265 |
+
new_images = []
|
266 |
+
if image_aspect_ratio == "pad":
|
267 |
+
for image in images:
|
268 |
+
image = self.expand2square(
|
269 |
+
image, tuple(int(x * 255) for x in image_processor.image_mean)
|
270 |
+
)
|
271 |
+
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
272 |
+
new_images.append(image)
|
273 |
+
else:
|
274 |
+
return image_processor(images, return_tensors="pt")["pixel_values"]
|
275 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
276 |
+
new_images = torch.stack(new_images, dim=0)
|
277 |
+
return new_images
|
278 |
+
|
279 |
+
|
280 |
+
AutoConfig.register("vita-Mistral", VITAMistralConfig)
|
281 |
+
AutoModelForCausalLM.register(VITAMistralConfig, VITAMistralForCausalLM)
|
282 |
+
|
vita/model/language_model/vita_qwen2.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from PIL import Image
|
6 |
+
from torch.nn import CrossEntropyLoss
|
7 |
+
from transformers import (
|
8 |
+
AutoConfig,
|
9 |
+
AutoModelForCausalLM,
|
10 |
+
Qwen2Config,
|
11 |
+
Qwen2ForCausalLM,
|
12 |
+
Qwen2Model,
|
13 |
+
)
|
14 |
+
from transformers.cache_utils import Cache, DynamicCache
|
15 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast, MoeCausalLMOutputWithPast
|
16 |
+
from transformers.generation.utils import GenerateOutput
|
17 |
+
|
18 |
+
from ..vita_arch import VITAMetaForCausalLM, VITAMetaModel
|
19 |
+
|
20 |
+
|
21 |
+
def custom_forward(
|
22 |
+
self,
|
23 |
+
input_ids: torch.LongTensor = None,
|
24 |
+
attention_mask: Optional[torch.Tensor] = None,
|
25 |
+
position_ids: Optional[torch.LongTensor] = None,
|
26 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
27 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
28 |
+
labels: Optional[torch.LongTensor] = None,
|
29 |
+
use_cache: Optional[bool] = None,
|
30 |
+
output_attentions: Optional[bool] = None,
|
31 |
+
output_hidden_states: Optional[bool] = None,
|
32 |
+
return_dict: Optional[bool] = None,
|
33 |
+
cache_position: Optional[torch.LongTensor] = None,
|
34 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
35 |
+
r"""
|
36 |
+
Args:
|
37 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
38 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
39 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
40 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
|
44 |
+
Example:
|
45 |
+
|
46 |
+
```python
|
47 |
+
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
|
48 |
+
|
49 |
+
>>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
50 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
51 |
+
|
52 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
53 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
54 |
+
|
55 |
+
>>> # Generate
|
56 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
57 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
58 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
59 |
+
```"""
|
60 |
+
|
61 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
62 |
+
output_hidden_states = (
|
63 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
64 |
+
)
|
65 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
66 |
+
|
67 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
68 |
+
outputs = self.model(
|
69 |
+
input_ids=input_ids,
|
70 |
+
attention_mask=attention_mask,
|
71 |
+
position_ids=position_ids,
|
72 |
+
past_key_values=past_key_values,
|
73 |
+
inputs_embeds=inputs_embeds,
|
74 |
+
use_cache=use_cache,
|
75 |
+
output_attentions=output_attentions,
|
76 |
+
output_hidden_states=output_hidden_states,
|
77 |
+
return_dict=return_dict,
|
78 |
+
cache_position=cache_position,
|
79 |
+
)
|
80 |
+
|
81 |
+
hidden_states = outputs[0]
|
82 |
+
logits = self.lm_head(hidden_states)
|
83 |
+
# logits = logits.float()
|
84 |
+
|
85 |
+
loss = None
|
86 |
+
if labels is not None:
|
87 |
+
# Shift so that tokens < n predict n
|
88 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
89 |
+
shift_labels = labels[..., 1:].contiguous()
|
90 |
+
# Flatten the tokens
|
91 |
+
loss_fct = CrossEntropyLoss()
|
92 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
93 |
+
shift_labels = shift_labels.view(-1)
|
94 |
+
# Enable model parallelism
|
95 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
96 |
+
loss = loss_fct(shift_logits, shift_labels)
|
97 |
+
|
98 |
+
if not return_dict:
|
99 |
+
output = (logits,) + outputs[1:]
|
100 |
+
return (loss,) + output if loss is not None else output
|
101 |
+
|
102 |
+
#import pdb; pdb.set_trace()
|
103 |
+
return CausalLMOutputWithPast(
|
104 |
+
loss=loss,
|
105 |
+
logits=logits,
|
106 |
+
past_key_values=outputs.past_key_values,
|
107 |
+
hidden_states=outputs.hidden_states,
|
108 |
+
attentions=outputs.attentions,
|
109 |
+
)
|
110 |
+
|
111 |
+
|
112 |
+
Qwen2ForCausalLM.forward = custom_forward
|
113 |
+
|
114 |
+
|
115 |
+
class VITAQwen2Config(Qwen2Config):
|
116 |
+
model_type = "vita-Qwen2"
|
117 |
+
|
118 |
+
|
119 |
+
class VITAQwen2Model(VITAMetaModel, Qwen2Model):
|
120 |
+
config_class = VITAQwen2Config
|
121 |
+
|
122 |
+
def __init__(self, config: Qwen2Config):
|
123 |
+
super(VITAQwen2Model, self).__init__(config)
|
124 |
+
|
125 |
+
|
126 |
+
class VITAQwen2ForCausalLM(Qwen2ForCausalLM, VITAMetaForCausalLM):
|
127 |
+
config_class = VITAQwen2Config
|
128 |
+
|
129 |
+
def __init__(self, config):
|
130 |
+
super(Qwen2ForCausalLM, self).__init__(config)
|
131 |
+
self.model = VITAQwen2Model(config)
|
132 |
+
self.vocab_size = config.vocab_size
|
133 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
134 |
+
|
135 |
+
# Initialize weights and apply final processing
|
136 |
+
self.post_init()
|
137 |
+
|
138 |
+
def get_model(self):
|
139 |
+
return self.model
|
140 |
+
|
141 |
+
def forward(
|
142 |
+
self,
|
143 |
+
input_ids: torch.LongTensor = None,
|
144 |
+
attention_mask: Optional[torch.Tensor] = None,
|
145 |
+
position_ids: Optional[torch.LongTensor] = None,
|
146 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
147 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
148 |
+
labels: Optional[torch.LongTensor] = None,
|
149 |
+
use_cache: Optional[bool] = None,
|
150 |
+
output_attentions: Optional[bool] = None,
|
151 |
+
output_hidden_states: Optional[bool] = None,
|
152 |
+
images: Optional[torch.FloatTensor] = None,
|
153 |
+
audios: Optional[dict] = None,
|
154 |
+
sf_masks: Optional[torch.Tensor] = None,
|
155 |
+
return_dict: Optional[bool] = None,
|
156 |
+
cache_position: Optional[torch.LongTensor] = None,
|
157 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
158 |
+
if inputs_embeds is None:
|
159 |
+
(
|
160 |
+
input_ids,
|
161 |
+
position_ids,
|
162 |
+
attention_mask,
|
163 |
+
past_key_values,
|
164 |
+
inputs_embeds,
|
165 |
+
labels,
|
166 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
167 |
+
input_ids, position_ids, attention_mask, past_key_values, labels, images, audios, sf_masks
|
168 |
+
)
|
169 |
+
|
170 |
+
return super().forward(
|
171 |
+
input_ids=input_ids,
|
172 |
+
attention_mask=attention_mask,
|
173 |
+
position_ids=position_ids,
|
174 |
+
past_key_values=past_key_values,
|
175 |
+
inputs_embeds=inputs_embeds,
|
176 |
+
labels=labels,
|
177 |
+
use_cache=use_cache,
|
178 |
+
output_attentions=output_attentions,
|
179 |
+
output_hidden_states=output_hidden_states,
|
180 |
+
return_dict=return_dict,
|
181 |
+
cache_position=cache_position,
|
182 |
+
)
|
183 |
+
|
184 |
+
@torch.no_grad()
|
185 |
+
def generate(
|
186 |
+
self,
|
187 |
+
inputs: Optional[torch.Tensor] = None,
|
188 |
+
images: Optional[torch.Tensor] = None,
|
189 |
+
audios: Optional[torch.Tensor] = None,
|
190 |
+
sf_masks: Optional[torch.Tensor] = None,
|
191 |
+
shared_v_pid_stride: Optional[int] = None,
|
192 |
+
**kwargs,
|
193 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
194 |
+
position_ids = kwargs.pop("position_ids", None)
|
195 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
196 |
+
if "inputs_embeds" in kwargs:
|
197 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
198 |
+
|
199 |
+
if images is not None or audios is not None:
|
200 |
+
(
|
201 |
+
inputs,
|
202 |
+
position_ids,
|
203 |
+
attention_mask,
|
204 |
+
_,
|
205 |
+
inputs_embeds,
|
206 |
+
_
|
207 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
208 |
+
inputs,
|
209 |
+
position_ids,
|
210 |
+
attention_mask,
|
211 |
+
None,
|
212 |
+
None,
|
213 |
+
images,
|
214 |
+
audios,
|
215 |
+
sf_masks,
|
216 |
+
shared_v_pid_stride,
|
217 |
+
)
|
218 |
+
else:
|
219 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
220 |
+
|
221 |
+
return super().generate(
|
222 |
+
position_ids=position_ids,
|
223 |
+
attention_mask=attention_mask,
|
224 |
+
inputs_embeds=inputs_embeds,
|
225 |
+
**kwargs
|
226 |
+
)
|
227 |
+
|
228 |
+
def prepare_inputs_for_generation(
|
229 |
+
self,
|
230 |
+
input_ids,
|
231 |
+
past_key_values=None,
|
232 |
+
inputs_embeds=None,
|
233 |
+
attention_mask=None,
|
234 |
+
**kwargs,
|
235 |
+
):
|
236 |
+
images = kwargs.pop("images", None)
|
237 |
+
audios = kwargs.pop("audios", None)
|
238 |
+
sf_masks = kwargs.pop("sf_masks", None)
|
239 |
+
|
240 |
+
_inputs = super().prepare_inputs_for_generation(
|
241 |
+
input_ids,
|
242 |
+
past_key_values=past_key_values,
|
243 |
+
inputs_embeds=inputs_embeds,
|
244 |
+
attention_mask=attention_mask,
|
245 |
+
**kwargs,
|
246 |
+
)
|
247 |
+
|
248 |
+
# import pdb; pdb.set_trace()
|
249 |
+
position_ids = _inputs["position_ids"]
|
250 |
+
cache_position = _inputs["cache_position"]
|
251 |
+
if cache_position.shape[-1] == 1 and position_ids.shape[-1] > 1:
|
252 |
+
new_position_ids = torch.zeros((position_ids.shape[0],1), dtype=position_ids.dtype, device=position_ids.device)
|
253 |
+
new_position_ids[:, 0] = position_ids[0,-1] + cache_position[-1] + 1 - position_ids.shape[-1]
|
254 |
+
position_ids = new_position_ids
|
255 |
+
_inputs["position_ids"] = position_ids
|
256 |
+
# import pdb; pdb.set_trace()
|
257 |
+
|
258 |
+
if images is not None:
|
259 |
+
_inputs["images"] = images
|
260 |
+
if audios is not None:
|
261 |
+
_inputs["audios"] = audios
|
262 |
+
if sf_masks is not None:
|
263 |
+
_inputs["sf_masks"] = sf_masks
|
264 |
+
return _inputs
|
265 |
+
|
266 |
+
def expand2square(self, pil_img, background_color):
|
267 |
+
width, height = pil_img.size
|
268 |
+
if width == height:
|
269 |
+
return pil_img
|
270 |
+
elif width > height:
|
271 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
272 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
273 |
+
return result
|
274 |
+
else:
|
275 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
276 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
277 |
+
return result
|
278 |
+
|
279 |
+
def process_images(self, images, model_cfg):
|
280 |
+
vision_tower = self.get_vision_tower()
|
281 |
+
if not vision_tower.is_loaded:
|
282 |
+
vision_tower.load_model()
|
283 |
+
image_processor = vision_tower.image_processor
|
284 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
285 |
+
new_images = []
|
286 |
+
if image_aspect_ratio == "pad":
|
287 |
+
for image in images:
|
288 |
+
image = self.expand2square(
|
289 |
+
image, tuple(int(x * 255) for x in image_processor.image_mean)
|
290 |
+
)
|
291 |
+
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
292 |
+
new_images.append(image)
|
293 |
+
else:
|
294 |
+
return image_processor(images, return_tensors="pt")["pixel_values"]
|
295 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
296 |
+
new_images = torch.stack(new_images, dim=0)
|
297 |
+
return new_images
|
298 |
+
|
299 |
+
|
300 |
+
AutoConfig.register("vita-Qwen2", VITAQwen2Config)
|
301 |
+
AutoModelForCausalLM.register(VITAQwen2Config, VITAQwen2ForCausalLM)
|
302 |
+
|
303 |
+
|
304 |
+
|
vita/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import yaml
|
4 |
+
import torch
|
5 |
+
from transformers.utils.hub import get_file_from_repo
|
6 |
+
|
7 |
+
from .clip.clip_encoder import CLIPVisionTower
|
8 |
+
from .eva_clip.eva_clip_encoder import EvaClipVisionTower
|
9 |
+
from .internvit.internvit_encoder import InternViTVisionTower
|
10 |
+
from .siglip.siglip_encoder import SiglipVisionTower, SiglipVisionTowerS2
|
11 |
+
from .whale.init_model import init_model
|
12 |
+
|
13 |
+
|
14 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
15 |
+
vision_tower = getattr(
|
16 |
+
vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)
|
17 |
+
)
|
18 |
+
use_s2 = getattr(vision_tower_cfg, "use_s2", False)
|
19 |
+
|
20 |
+
if "sig" in vision_tower.lower():
|
21 |
+
if use_s2:
|
22 |
+
return SiglipVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
|
23 |
+
else:
|
24 |
+
return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
25 |
+
elif "eva" in vision_tower.lower():
|
26 |
+
if use_s2:
|
27 |
+
raise ValueError(f"Currently not supporting S2 for EVA-CLIP")
|
28 |
+
else:
|
29 |
+
return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
30 |
+
|
31 |
+
elif "clip" in vision_tower.lower():
|
32 |
+
if use_s2:
|
33 |
+
raise ValueError(f"Currently not supporting S2 for CLIP")
|
34 |
+
else:
|
35 |
+
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
36 |
+
elif "internvit" in vision_tower.lower():
|
37 |
+
if use_s2:
|
38 |
+
raise ValueError(f"Currently not supporting S2 for InternViT")
|
39 |
+
else:
|
40 |
+
return InternViTVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
41 |
+
|
42 |
+
else:
|
43 |
+
raise ValueError(f"Unknown vision tower: {vision_tower}")
|
44 |
+
|
45 |
+
|
46 |
+
def build_audio_encoder(audio_encoder_config, **kwargs):
|
47 |
+
with open(get_file_from_repo(audio_encoder_config.mm_audio_encoder, "train.yaml"), "r") as fin:
|
48 |
+
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
49 |
+
|
50 |
+
configs["cmvn_file"] = get_file_from_repo(audio_encoder_config.mm_audio_encoder, "global_cmvn")
|
51 |
+
|
52 |
+
configs["model_conf"]["freeze_encoder"] = getattr(
|
53 |
+
audio_encoder_config, "freeze_audio_encoder", True
|
54 |
+
)
|
55 |
+
configs["model_conf"]["freeze_adpter"] = getattr(
|
56 |
+
audio_encoder_config, "freeze_audio_encoder_adapter", True
|
57 |
+
)
|
58 |
+
configs["model_conf"]["audio_prompt_finetune"] = getattr(
|
59 |
+
audio_encoder_config, "audio_prompt_finetune", False
|
60 |
+
)
|
61 |
+
configs["model_conf"]["audio_prompt_num"] = getattr(
|
62 |
+
audio_encoder_config, "audio_prompt_num", 0
|
63 |
+
)
|
64 |
+
|
65 |
+
audio_encoder = init_model(configs)
|
66 |
+
|
67 |
+
checkpoint = torch.load(get_file_from_repo(audio_encoder_config.mm_audio_encoder, "final.pt"), map_location="cpu")
|
68 |
+
model_dict = audio_encoder.state_dict()
|
69 |
+
for key in model_dict.keys():
|
70 |
+
if key in checkpoint.keys():
|
71 |
+
if model_dict[key].shape == checkpoint[key].shape:
|
72 |
+
model_dict[key] = checkpoint[key]
|
73 |
+
else:
|
74 |
+
print(
|
75 |
+
"Key {} has different shape, {} VS {}".format(
|
76 |
+
key, model_dict[key].shape, checkpoint[key].shape
|
77 |
+
)
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
print("Key {} has not in resume model".format(key))
|
81 |
+
audio_encoder.load_state_dict(model_dict)
|
82 |
+
|
83 |
+
return audio_encoder
|
vita/model/multimodal_encoder/clip/clip_encoder.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
|
4 |
+
|
5 |
+
|
6 |
+
class CLIPVisionTower(nn.Module):
|
7 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
self.is_loaded = False
|
11 |
+
|
12 |
+
self.vision_tower_name = vision_tower
|
13 |
+
self.select_layer = -2
|
14 |
+
|
15 |
+
if not delay_load:
|
16 |
+
self.load_model()
|
17 |
+
else:
|
18 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
19 |
+
|
20 |
+
def load_model(self):
|
21 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
22 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
|
23 |
+
self.vision_tower.requires_grad_(False)
|
24 |
+
|
25 |
+
self.is_loaded = True
|
26 |
+
|
27 |
+
def feature_select(self, image_forward_outs):
|
28 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
29 |
+
|
30 |
+
image_features = image_features[:, 1:]
|
31 |
+
|
32 |
+
return image_features
|
33 |
+
|
34 |
+
@torch.no_grad()
|
35 |
+
def forward(self, images):
|
36 |
+
if type(images) is list:
|
37 |
+
image_features = []
|
38 |
+
for image in images:
|
39 |
+
image_forward_out = self.vision_tower(
|
40 |
+
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
|
41 |
+
output_hidden_states=True,
|
42 |
+
)
|
43 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
44 |
+
image_features.append(image_feature)
|
45 |
+
else:
|
46 |
+
image_forward_outs = self.vision_tower(
|
47 |
+
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
|
48 |
+
)
|
49 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
50 |
+
|
51 |
+
return image_features
|
52 |
+
|
53 |
+
@property
|
54 |
+
def dummy_feature(self):
|
55 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
56 |
+
|
57 |
+
@property
|
58 |
+
def dtype(self):
|
59 |
+
return self.vision_tower.dtype
|
60 |
+
|
61 |
+
@property
|
62 |
+
def device(self):
|
63 |
+
return self.vision_tower.device
|
64 |
+
|
65 |
+
@property
|
66 |
+
def config(self):
|
67 |
+
if self.is_loaded:
|
68 |
+
return self.vision_tower.config
|
69 |
+
else:
|
70 |
+
return self.cfg_only
|
71 |
+
|
72 |
+
@property
|
73 |
+
def hidden_size(self):
|
74 |
+
return self.config.hidden_size
|
75 |
+
|
76 |
+
@property
|
77 |
+
def num_patches(self):
|
78 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
vita/model/multimodal_encoder/eva_clip/eva_clip_encoder.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .eva_clip_processors import EvaClipImageTrainProcessor
|
5 |
+
from .eva_vit import Eva2LargePlusEncoder
|
6 |
+
|
7 |
+
|
8 |
+
class EvaClipVisionTower(nn.Module):
|
9 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
self.is_loaded = False
|
13 |
+
|
14 |
+
self.vision_tower_path = vision_tower
|
15 |
+
self.config = VisionTowerConfig()
|
16 |
+
|
17 |
+
if not delay_load:
|
18 |
+
self.load_model()
|
19 |
+
else:
|
20 |
+
self.cfg_only = self.config
|
21 |
+
|
22 |
+
def load_model(self):
|
23 |
+
self.image_processor = EvaClipImageTrainProcessor(self.config.image_size)
|
24 |
+
self.vision_tower = Eva2LargePlusEncoder(self.vision_tower_path)
|
25 |
+
self.vision_tower.requires_grad_(False)
|
26 |
+
|
27 |
+
self.is_loaded = True
|
28 |
+
|
29 |
+
@torch.no_grad()
|
30 |
+
def forward(self, images):
|
31 |
+
if type(images) is list:
|
32 |
+
image_features = []
|
33 |
+
for image in images:
|
34 |
+
image_feature = self.vision_tower(
|
35 |
+
image.to(device=self.device, dtype=self.dtype).unsqueeze(0)
|
36 |
+
).to(image.dtype)
|
37 |
+
image_features.append(image_feature)
|
38 |
+
else:
|
39 |
+
image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(
|
40 |
+
images.dtype
|
41 |
+
)
|
42 |
+
|
43 |
+
return image_features
|
44 |
+
|
45 |
+
@property
|
46 |
+
def dtype(self):
|
47 |
+
return self.vision_tower.dtype
|
48 |
+
|
49 |
+
@property
|
50 |
+
def device(self):
|
51 |
+
return self.vision_tower.device
|
52 |
+
|
53 |
+
@property
|
54 |
+
def hidden_size(self):
|
55 |
+
return self.config.hidden_size
|
56 |
+
|
57 |
+
@property
|
58 |
+
def num_patches(self):
|
59 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
60 |
+
|
61 |
+
|
62 |
+
class VisionTowerConfig:
|
63 |
+
def __init__(self):
|
64 |
+
self.image_size = 336
|
65 |
+
self.patch_size = 14
|
66 |
+
self.hidden_size = 1024
|
vita/model/multimodal_encoder/eva_clip/eva_clip_processors.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
|
3 |
+
"""
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
from transformers.image_processing_utils import BatchFeature
|
7 |
+
from transformers.image_transforms import convert_to_rgb
|
8 |
+
|
9 |
+
from torchvision import transforms
|
10 |
+
from torchvision.transforms.functional import InterpolationMode
|
11 |
+
|
12 |
+
|
13 |
+
class BaseProcessor:
|
14 |
+
def __init__(self):
|
15 |
+
self.transform = lambda x: x
|
16 |
+
return
|
17 |
+
|
18 |
+
def __call__(self, item):
|
19 |
+
return self.transform(item)
|
20 |
+
|
21 |
+
|
22 |
+
class EvaClipImageBaseProcessor(BaseProcessor):
|
23 |
+
def __init__(self, mean=None, std=None):
|
24 |
+
self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
|
25 |
+
self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
|
26 |
+
|
27 |
+
self.normalize = transforms.Normalize(self.mean, self.std)
|
28 |
+
|
29 |
+
@property
|
30 |
+
def image_mean(self):
|
31 |
+
return self.mean
|
32 |
+
|
33 |
+
|
34 |
+
class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
|
35 |
+
def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
|
36 |
+
super().__init__(mean=mean, std=std)
|
37 |
+
|
38 |
+
self.transform = transforms.Compose(
|
39 |
+
[
|
40 |
+
convert_to_rgb,
|
41 |
+
transforms.Resize(
|
42 |
+
image_size,
|
43 |
+
interpolation=InterpolationMode.BICUBIC,
|
44 |
+
),
|
45 |
+
transforms.CenterCrop(image_size),
|
46 |
+
transforms.ToTensor(),
|
47 |
+
self.normalize,
|
48 |
+
]
|
49 |
+
)
|
50 |
+
|
51 |
+
self.image_size = image_size
|
52 |
+
|
53 |
+
def preprocess(self, images, return_tensors):
|
54 |
+
if isinstance(images, Image.Image):
|
55 |
+
images = [images]
|
56 |
+
else:
|
57 |
+
assert isinstance(images, list)
|
58 |
+
|
59 |
+
transformed_images = [self.transform(image).numpy() for image in images]
|
60 |
+
data = {"pixel_values": transformed_images}
|
61 |
+
|
62 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
63 |
+
|
64 |
+
def __call__(self, item):
|
65 |
+
return self.transform(item)
|
66 |
+
|
67 |
+
@property
|
68 |
+
def crop_size(self):
|
69 |
+
return {"height": self.image_size, "width": self.image_size}
|
vita/model/multimodal_encoder/eva_clip/eva_vit.py
ADDED
@@ -0,0 +1,982 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
|
3 |
+
"""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
|
7 |
+
# --------------------------------------------------------
|
8 |
+
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
|
9 |
+
# --------------------------------------------------------
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from functools import partial
|
14 |
+
from math import pi
|
15 |
+
from typing import Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from einops import rearrange, repeat
|
21 |
+
from torch import nn as nn
|
22 |
+
|
23 |
+
import xformers.ops as xops
|
24 |
+
|
25 |
+
|
26 |
+
def broadcat(tensors, dim=-1):
|
27 |
+
num_tensors = len(tensors)
|
28 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
29 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
30 |
+
shape_len = list(shape_lens)[0]
|
31 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
32 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
33 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
34 |
+
assert all(
|
35 |
+
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
36 |
+
), "invalid dimensions for broadcastable concatentation"
|
37 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
38 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
39 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
40 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
41 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
42 |
+
return torch.cat(tensors, dim=dim)
|
43 |
+
|
44 |
+
|
45 |
+
def rotate_half(x):
|
46 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
47 |
+
x1, x2 = x.unbind(dim=-1)
|
48 |
+
x = torch.stack((-x2, x1), dim=-1)
|
49 |
+
return rearrange(x, "... d r -> ... (d r)")
|
50 |
+
|
51 |
+
|
52 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
dim,
|
56 |
+
pt_seq_len,
|
57 |
+
ft_seq_len=None,
|
58 |
+
custom_freqs=None,
|
59 |
+
freqs_for="lang",
|
60 |
+
theta=10000,
|
61 |
+
max_freq=10,
|
62 |
+
num_freqs=1,
|
63 |
+
patch_dropout=0.0,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
if custom_freqs:
|
67 |
+
freqs = custom_freqs
|
68 |
+
elif freqs_for == "lang":
|
69 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
70 |
+
elif freqs_for == "pixel":
|
71 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
72 |
+
elif freqs_for == "constant":
|
73 |
+
freqs = torch.ones(num_freqs).float()
|
74 |
+
else:
|
75 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
76 |
+
|
77 |
+
if ft_seq_len is None:
|
78 |
+
ft_seq_len = pt_seq_len
|
79 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
80 |
+
|
81 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
82 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
83 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
84 |
+
|
85 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
86 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
87 |
+
|
88 |
+
self.patch_dropout = patch_dropout
|
89 |
+
|
90 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
91 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
92 |
+
|
93 |
+
logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
|
94 |
+
|
95 |
+
def forward(self, t, patch_indices_keep=None):
|
96 |
+
if patch_indices_keep is not None:
|
97 |
+
batch = t.size()[0]
|
98 |
+
batch_indices = torch.arange(batch)
|
99 |
+
batch_indices = batch_indices[..., None]
|
100 |
+
|
101 |
+
freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
|
102 |
+
freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
|
103 |
+
|
104 |
+
freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
|
105 |
+
freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
|
106 |
+
freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
|
107 |
+
freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
|
108 |
+
|
109 |
+
return t * freqs_cos + rotate_half(t) * freqs_sin
|
110 |
+
|
111 |
+
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
112 |
+
|
113 |
+
|
114 |
+
class LayerNorm(nn.LayerNorm):
|
115 |
+
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
116 |
+
|
117 |
+
def forward(self, x: torch.Tensor):
|
118 |
+
orig_type = x.dtype
|
119 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
120 |
+
return x.to(orig_type)
|
121 |
+
|
122 |
+
|
123 |
+
class PatchDropout(nn.Module):
|
124 |
+
"""
|
125 |
+
https://arxiv.org/abs/2212.00794
|
126 |
+
"""
|
127 |
+
|
128 |
+
def __init__(self, prob, exclude_first_token=True):
|
129 |
+
super().__init__()
|
130 |
+
assert 0 <= prob < 1.0
|
131 |
+
self.prob = prob
|
132 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
133 |
+
logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
|
134 |
+
|
135 |
+
def forward(self, x):
|
136 |
+
if not self.training or self.prob == 0.0:
|
137 |
+
return x
|
138 |
+
|
139 |
+
if self.exclude_first_token:
|
140 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
141 |
+
else:
|
142 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
143 |
+
|
144 |
+
batch = x.size()[0]
|
145 |
+
num_tokens = x.size()[1]
|
146 |
+
|
147 |
+
batch_indices = torch.arange(batch)
|
148 |
+
batch_indices = batch_indices[..., None]
|
149 |
+
|
150 |
+
keep_prob = 1 - self.prob
|
151 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
152 |
+
|
153 |
+
rand = torch.randn(batch, num_tokens)
|
154 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
155 |
+
|
156 |
+
x = x[batch_indices, patch_indices_keep]
|
157 |
+
|
158 |
+
if self.exclude_first_token:
|
159 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
160 |
+
|
161 |
+
if self.training and os.getenv("RoPE") == "1":
|
162 |
+
return x, patch_indices_keep
|
163 |
+
|
164 |
+
return x
|
165 |
+
|
166 |
+
|
167 |
+
try:
|
168 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
169 |
+
except:
|
170 |
+
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
171 |
+
|
172 |
+
if os.getenv("ENV_TYPE") == "deepspeed":
|
173 |
+
try:
|
174 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
175 |
+
except:
|
176 |
+
from torch.utils.checkpoint import checkpoint
|
177 |
+
else:
|
178 |
+
from torch.utils.checkpoint import checkpoint
|
179 |
+
|
180 |
+
|
181 |
+
class DropPath(nn.Module):
|
182 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
183 |
+
|
184 |
+
def __init__(self, drop_prob=None):
|
185 |
+
super(DropPath, self).__init__()
|
186 |
+
self.drop_prob = drop_prob
|
187 |
+
|
188 |
+
def forward(self, x):
|
189 |
+
return drop_path(x, self.drop_prob, self.training)
|
190 |
+
|
191 |
+
def extra_repr(self) -> str:
|
192 |
+
return "p={}".format(self.drop_prob)
|
193 |
+
|
194 |
+
|
195 |
+
class Mlp(nn.Module):
|
196 |
+
def __init__(
|
197 |
+
self,
|
198 |
+
in_features,
|
199 |
+
hidden_features=None,
|
200 |
+
out_features=None,
|
201 |
+
act_layer=nn.GELU,
|
202 |
+
norm_layer=nn.LayerNorm,
|
203 |
+
drop=0.0,
|
204 |
+
subln=False,
|
205 |
+
):
|
206 |
+
super().__init__()
|
207 |
+
out_features = out_features or in_features
|
208 |
+
hidden_features = hidden_features or in_features
|
209 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
210 |
+
self.act = act_layer()
|
211 |
+
|
212 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
213 |
+
|
214 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
215 |
+
self.drop = nn.Dropout(drop)
|
216 |
+
|
217 |
+
def forward(self, x):
|
218 |
+
x = self.fc1(x)
|
219 |
+
x = self.act(x)
|
220 |
+
# x = self.drop(x)
|
221 |
+
# commit this for the orignal BERT implement
|
222 |
+
x = self.ffn_ln(x)
|
223 |
+
|
224 |
+
x = self.fc2(x)
|
225 |
+
x = self.drop(x)
|
226 |
+
return x
|
227 |
+
|
228 |
+
|
229 |
+
class SwiGLU(nn.Module):
|
230 |
+
def __init__(
|
231 |
+
self,
|
232 |
+
in_features,
|
233 |
+
hidden_features=None,
|
234 |
+
out_features=None,
|
235 |
+
act_layer=nn.SiLU,
|
236 |
+
drop=0.0,
|
237 |
+
norm_layer=nn.LayerNorm,
|
238 |
+
subln=False,
|
239 |
+
):
|
240 |
+
super().__init__()
|
241 |
+
out_features = out_features or in_features
|
242 |
+
hidden_features = hidden_features or in_features
|
243 |
+
|
244 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
245 |
+
self.w2 = nn.Linear(in_features, hidden_features)
|
246 |
+
|
247 |
+
self.act = act_layer()
|
248 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
249 |
+
self.w3 = nn.Linear(hidden_features, out_features)
|
250 |
+
|
251 |
+
self.drop = nn.Dropout(drop)
|
252 |
+
|
253 |
+
def forward(self, x):
|
254 |
+
x1 = self.w1(x)
|
255 |
+
x2 = self.w2(x)
|
256 |
+
hidden = self.act(x1) * x2
|
257 |
+
x = self.ffn_ln(hidden)
|
258 |
+
x = self.w3(x)
|
259 |
+
x = self.drop(x)
|
260 |
+
return x
|
261 |
+
|
262 |
+
|
263 |
+
class Attention(nn.Module):
|
264 |
+
def __init__(
|
265 |
+
self,
|
266 |
+
dim,
|
267 |
+
num_heads=8,
|
268 |
+
qkv_bias=False,
|
269 |
+
qk_scale=None,
|
270 |
+
attn_drop=0.0,
|
271 |
+
proj_drop=0.0,
|
272 |
+
window_size=None,
|
273 |
+
attn_head_dim=None,
|
274 |
+
xattn=False,
|
275 |
+
rope=None,
|
276 |
+
subln=False,
|
277 |
+
norm_layer=nn.LayerNorm,
|
278 |
+
):
|
279 |
+
super().__init__()
|
280 |
+
self.num_heads = num_heads
|
281 |
+
head_dim = dim // num_heads
|
282 |
+
if attn_head_dim is not None:
|
283 |
+
head_dim = attn_head_dim
|
284 |
+
all_head_dim = head_dim * self.num_heads
|
285 |
+
self.scale = qk_scale or head_dim**-0.5
|
286 |
+
|
287 |
+
self.subln = subln
|
288 |
+
if self.subln:
|
289 |
+
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
|
290 |
+
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
|
291 |
+
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
|
292 |
+
else:
|
293 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
294 |
+
|
295 |
+
if qkv_bias:
|
296 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
297 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
298 |
+
else:
|
299 |
+
self.q_bias = None
|
300 |
+
self.v_bias = None
|
301 |
+
|
302 |
+
if window_size:
|
303 |
+
self.window_size = window_size
|
304 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
305 |
+
self.relative_position_bias_table = nn.Parameter(
|
306 |
+
torch.zeros(self.num_relative_distance, num_heads)
|
307 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
308 |
+
# cls to token & token 2 cls & cls to cls
|
309 |
+
|
310 |
+
# get pair-wise relative position index for each token inside the window
|
311 |
+
coords_h = torch.arange(window_size[0])
|
312 |
+
coords_w = torch.arange(window_size[1])
|
313 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
314 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
315 |
+
relative_coords = (
|
316 |
+
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
317 |
+
) # 2, Wh*Ww, Wh*Ww
|
318 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
319 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
320 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
321 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
322 |
+
relative_position_index = torch.zeros(
|
323 |
+
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
|
324 |
+
)
|
325 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
326 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
327 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
328 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
329 |
+
|
330 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
331 |
+
else:
|
332 |
+
self.window_size = None
|
333 |
+
self.relative_position_bias_table = None
|
334 |
+
self.relative_position_index = None
|
335 |
+
|
336 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
337 |
+
self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
|
338 |
+
# self.proj = nn.Linear(all_head_dim, all_head_dim)
|
339 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
340 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
341 |
+
self.xattn = xattn
|
342 |
+
self.xattn_drop = attn_drop
|
343 |
+
|
344 |
+
self.rope = rope
|
345 |
+
|
346 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
347 |
+
B, N, C = x.shape
|
348 |
+
if self.subln:
|
349 |
+
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
|
350 |
+
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
|
351 |
+
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
|
352 |
+
|
353 |
+
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
|
354 |
+
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
355 |
+
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
356 |
+
else:
|
357 |
+
|
358 |
+
qkv_bias = None
|
359 |
+
if self.q_bias is not None:
|
360 |
+
qkv_bias = torch.cat(
|
361 |
+
(self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)
|
362 |
+
)
|
363 |
+
|
364 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
365 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
|
366 |
+
2, 0, 3, 1, 4
|
367 |
+
) # 3, B, num_heads, N, C
|
368 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
369 |
+
|
370 |
+
if self.rope:
|
371 |
+
# slightly fast impl
|
372 |
+
q_t = q[:, :, 1:, :]
|
373 |
+
ro_q_t = self.rope(q_t)
|
374 |
+
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
|
375 |
+
|
376 |
+
k_t = k[:, :, 1:, :]
|
377 |
+
ro_k_t = self.rope(k_t)
|
378 |
+
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
|
379 |
+
|
380 |
+
if self.xattn:
|
381 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
382 |
+
k = k.permute(0, 2, 1, 3)
|
383 |
+
v = v.permute(0, 2, 1, 3)
|
384 |
+
|
385 |
+
x = xops.memory_efficient_attention(
|
386 |
+
q,
|
387 |
+
k,
|
388 |
+
v,
|
389 |
+
p=self.xattn_drop,
|
390 |
+
scale=self.scale,
|
391 |
+
)
|
392 |
+
x = x.reshape(B, N, -1)
|
393 |
+
x = self.inner_attn_ln(x)
|
394 |
+
x = self.proj(x)
|
395 |
+
x = self.proj_drop(x)
|
396 |
+
else:
|
397 |
+
q = q * self.scale
|
398 |
+
attn = q @ k.transpose(-2, -1)
|
399 |
+
|
400 |
+
if self.relative_position_bias_table is not None:
|
401 |
+
relative_position_bias = self.relative_position_bias_table[
|
402 |
+
self.relative_position_index.view(-1)
|
403 |
+
].view(
|
404 |
+
self.window_size[0] * self.window_size[1] + 1,
|
405 |
+
self.window_size[0] * self.window_size[1] + 1,
|
406 |
+
-1,
|
407 |
+
) # Wh*Ww,Wh*Ww,nH
|
408 |
+
relative_position_bias = relative_position_bias.permute(
|
409 |
+
2, 0, 1
|
410 |
+
).contiguous() # nH, Wh*Ww, Wh*Ww
|
411 |
+
attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
|
412 |
+
|
413 |
+
if rel_pos_bias is not None:
|
414 |
+
attn = attn + rel_pos_bias.type_as(attn)
|
415 |
+
|
416 |
+
if attn_mask is not None:
|
417 |
+
attn_mask = attn_mask.bool()
|
418 |
+
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
|
419 |
+
|
420 |
+
attn = attn.softmax(dim=-1)
|
421 |
+
attn = self.attn_drop(attn)
|
422 |
+
|
423 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
424 |
+
x = self.inner_attn_ln(x)
|
425 |
+
x = self.proj(x)
|
426 |
+
x = self.proj_drop(x)
|
427 |
+
return x
|
428 |
+
|
429 |
+
|
430 |
+
class Block(nn.Module):
|
431 |
+
def __init__(
|
432 |
+
self,
|
433 |
+
dim,
|
434 |
+
num_heads,
|
435 |
+
mlp_ratio=4.0,
|
436 |
+
qkv_bias=False,
|
437 |
+
qk_scale=None,
|
438 |
+
drop=0.0,
|
439 |
+
attn_drop=0.0,
|
440 |
+
drop_path=0.0,
|
441 |
+
init_values=None,
|
442 |
+
act_layer=nn.GELU,
|
443 |
+
norm_layer=nn.LayerNorm,
|
444 |
+
window_size=None,
|
445 |
+
attn_head_dim=None,
|
446 |
+
xattn=False,
|
447 |
+
rope=None,
|
448 |
+
postnorm=False,
|
449 |
+
subln=False,
|
450 |
+
naiveswiglu=False,
|
451 |
+
):
|
452 |
+
super().__init__()
|
453 |
+
self.norm1 = norm_layer(dim)
|
454 |
+
self.attn = Attention(
|
455 |
+
dim,
|
456 |
+
num_heads=num_heads,
|
457 |
+
qkv_bias=qkv_bias,
|
458 |
+
qk_scale=qk_scale,
|
459 |
+
attn_drop=attn_drop,
|
460 |
+
proj_drop=drop,
|
461 |
+
window_size=window_size,
|
462 |
+
attn_head_dim=attn_head_dim,
|
463 |
+
xattn=xattn,
|
464 |
+
rope=rope,
|
465 |
+
subln=subln,
|
466 |
+
norm_layer=norm_layer,
|
467 |
+
)
|
468 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
469 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
470 |
+
self.norm2 = norm_layer(dim)
|
471 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
472 |
+
|
473 |
+
if naiveswiglu:
|
474 |
+
self.mlp = SwiGLU(
|
475 |
+
in_features=dim,
|
476 |
+
hidden_features=mlp_hidden_dim,
|
477 |
+
subln=subln,
|
478 |
+
norm_layer=norm_layer,
|
479 |
+
)
|
480 |
+
else:
|
481 |
+
self.mlp = Mlp(
|
482 |
+
in_features=dim,
|
483 |
+
hidden_features=mlp_hidden_dim,
|
484 |
+
act_layer=act_layer,
|
485 |
+
subln=subln,
|
486 |
+
drop=drop,
|
487 |
+
)
|
488 |
+
|
489 |
+
if init_values is not None and init_values > 0:
|
490 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
491 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
492 |
+
else:
|
493 |
+
self.gamma_1, self.gamma_2 = None, None
|
494 |
+
|
495 |
+
self.postnorm = postnorm
|
496 |
+
|
497 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
498 |
+
if self.gamma_1 is None:
|
499 |
+
if self.postnorm:
|
500 |
+
x = x + self.drop_path(
|
501 |
+
self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
502 |
+
)
|
503 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
504 |
+
else:
|
505 |
+
x = x + self.drop_path(
|
506 |
+
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
|
507 |
+
)
|
508 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
509 |
+
else:
|
510 |
+
if self.postnorm:
|
511 |
+
x = x + self.drop_path(
|
512 |
+
self.gamma_1
|
513 |
+
* self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
514 |
+
)
|
515 |
+
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
|
516 |
+
else:
|
517 |
+
x = x + self.drop_path(
|
518 |
+
self.gamma_1
|
519 |
+
* self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
|
520 |
+
)
|
521 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
522 |
+
return x
|
523 |
+
|
524 |
+
|
525 |
+
class PatchEmbed(nn.Module):
|
526 |
+
"""Image to Patch Embedding"""
|
527 |
+
|
528 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
529 |
+
super().__init__()
|
530 |
+
img_size = to_2tuple(img_size)
|
531 |
+
patch_size = to_2tuple(patch_size)
|
532 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
533 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
534 |
+
self.img_size = img_size
|
535 |
+
self.patch_size = patch_size
|
536 |
+
self.num_patches = num_patches
|
537 |
+
|
538 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
539 |
+
|
540 |
+
def forward(self, x, **kwargs):
|
541 |
+
B, C, H, W = x.shape
|
542 |
+
# FIXME look at relaxing size constraints
|
543 |
+
assert (
|
544 |
+
H == self.img_size[0] and W == self.img_size[1]
|
545 |
+
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
546 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
547 |
+
return x
|
548 |
+
|
549 |
+
|
550 |
+
class RelativePositionBias(nn.Module):
|
551 |
+
def __init__(self, window_size, num_heads):
|
552 |
+
super().__init__()
|
553 |
+
self.window_size = window_size
|
554 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
555 |
+
self.relative_position_bias_table = nn.Parameter(
|
556 |
+
torch.zeros(self.num_relative_distance, num_heads)
|
557 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
558 |
+
# cls to token & token 2 cls & cls to cls
|
559 |
+
|
560 |
+
# get pair-wise relative position index for each token inside the window
|
561 |
+
coords_h = torch.arange(window_size[0])
|
562 |
+
coords_w = torch.arange(window_size[1])
|
563 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
564 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
565 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
566 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
567 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
568 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
569 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
570 |
+
relative_position_index = torch.zeros(
|
571 |
+
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
|
572 |
+
)
|
573 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
574 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
575 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
576 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
577 |
+
|
578 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
579 |
+
|
580 |
+
def forward(self):
|
581 |
+
relative_position_bias = self.relative_position_bias_table[
|
582 |
+
self.relative_position_index.view(-1)
|
583 |
+
].view(
|
584 |
+
self.window_size[0] * self.window_size[1] + 1,
|
585 |
+
self.window_size[0] * self.window_size[1] + 1,
|
586 |
+
-1,
|
587 |
+
) # Wh*Ww,Wh*Ww,nH
|
588 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
589 |
+
|
590 |
+
|
591 |
+
class EVAVisionTransformer(nn.Module):
|
592 |
+
"""Vision Transformer with support for patch or hybrid CNN input stage"""
|
593 |
+
|
594 |
+
def __init__(
|
595 |
+
self,
|
596 |
+
img_size=224,
|
597 |
+
patch_size=16,
|
598 |
+
in_chans=3,
|
599 |
+
num_classes=1000,
|
600 |
+
embed_dim=768,
|
601 |
+
depth=12,
|
602 |
+
num_heads=12,
|
603 |
+
mlp_ratio=4.0,
|
604 |
+
qkv_bias=False,
|
605 |
+
qk_scale=None,
|
606 |
+
drop_rate=0.0,
|
607 |
+
attn_drop_rate=0.0,
|
608 |
+
drop_path_rate=0.0,
|
609 |
+
norm_layer=nn.LayerNorm,
|
610 |
+
init_values=None,
|
611 |
+
patch_dropout=0.0,
|
612 |
+
use_abs_pos_emb=True,
|
613 |
+
use_rel_pos_bias=False,
|
614 |
+
use_shared_rel_pos_bias=False,
|
615 |
+
rope=False,
|
616 |
+
use_mean_pooling=True,
|
617 |
+
init_scale=0.001,
|
618 |
+
grad_checkpointing=False,
|
619 |
+
xattn=False,
|
620 |
+
postnorm=False,
|
621 |
+
pt_hw_seq_len=16,
|
622 |
+
intp_freq=False,
|
623 |
+
naiveswiglu=False,
|
624 |
+
subln=False,
|
625 |
+
):
|
626 |
+
super().__init__()
|
627 |
+
self.image_size = img_size
|
628 |
+
self.num_classes = num_classes
|
629 |
+
self.num_features = (
|
630 |
+
self.embed_dim
|
631 |
+
) = embed_dim # num_features for consistency with other models
|
632 |
+
|
633 |
+
self.patch_embed = PatchEmbed(
|
634 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
|
635 |
+
)
|
636 |
+
num_patches = self.patch_embed.num_patches
|
637 |
+
|
638 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
639 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
640 |
+
if use_abs_pos_emb:
|
641 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
642 |
+
else:
|
643 |
+
self.pos_embed = None
|
644 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
645 |
+
|
646 |
+
if use_shared_rel_pos_bias:
|
647 |
+
self.rel_pos_bias = RelativePositionBias(
|
648 |
+
window_size=self.patch_embed.patch_shape, num_heads=num_heads
|
649 |
+
)
|
650 |
+
else:
|
651 |
+
self.rel_pos_bias = None
|
652 |
+
|
653 |
+
if rope:
|
654 |
+
half_head_dim = embed_dim // num_heads // 2
|
655 |
+
hw_seq_len = img_size // patch_size
|
656 |
+
self.rope = VisionRotaryEmbeddingFast(
|
657 |
+
dim=half_head_dim,
|
658 |
+
pt_seq_len=pt_hw_seq_len,
|
659 |
+
ft_seq_len=hw_seq_len if intp_freq else None,
|
660 |
+
# patch_dropout=patch_dropout
|
661 |
+
)
|
662 |
+
else:
|
663 |
+
self.rope = None
|
664 |
+
|
665 |
+
self.naiveswiglu = naiveswiglu
|
666 |
+
|
667 |
+
dpr = [
|
668 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
669 |
+
] # stochastic depth decay rule
|
670 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
671 |
+
self.blocks = nn.ModuleList(
|
672 |
+
[
|
673 |
+
Block(
|
674 |
+
dim=embed_dim,
|
675 |
+
num_heads=num_heads,
|
676 |
+
mlp_ratio=mlp_ratio,
|
677 |
+
qkv_bias=qkv_bias,
|
678 |
+
qk_scale=qk_scale,
|
679 |
+
drop=drop_rate,
|
680 |
+
attn_drop=attn_drop_rate,
|
681 |
+
drop_path=dpr[i],
|
682 |
+
norm_layer=norm_layer,
|
683 |
+
init_values=init_values,
|
684 |
+
window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
|
685 |
+
xattn=xattn,
|
686 |
+
rope=self.rope,
|
687 |
+
postnorm=postnorm,
|
688 |
+
subln=subln,
|
689 |
+
naiveswiglu=naiveswiglu,
|
690 |
+
)
|
691 |
+
for i in range(depth)
|
692 |
+
]
|
693 |
+
)
|
694 |
+
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
695 |
+
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
696 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
697 |
+
|
698 |
+
if self.pos_embed is not None:
|
699 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
700 |
+
|
701 |
+
trunc_normal_(self.cls_token, std=0.02)
|
702 |
+
# trunc_normal_(self.mask_token, std=.02)
|
703 |
+
|
704 |
+
self.apply(self._init_weights)
|
705 |
+
self.fix_init_weight()
|
706 |
+
|
707 |
+
if isinstance(self.head, nn.Linear):
|
708 |
+
trunc_normal_(self.head.weight, std=0.02)
|
709 |
+
self.head.weight.data.mul_(init_scale)
|
710 |
+
self.head.bias.data.mul_(init_scale)
|
711 |
+
|
712 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
713 |
+
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
|
714 |
+
|
715 |
+
self.grad_checkpointing = grad_checkpointing
|
716 |
+
|
717 |
+
def fix_init_weight(self):
|
718 |
+
def rescale(param, layer_id):
|
719 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
720 |
+
|
721 |
+
for layer_id, layer in enumerate(self.blocks):
|
722 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
723 |
+
if self.naiveswiglu:
|
724 |
+
rescale(layer.mlp.w3.weight.data, layer_id + 1)
|
725 |
+
else:
|
726 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
727 |
+
|
728 |
+
def get_cast_dtype(self) -> torch.dtype:
|
729 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
730 |
+
|
731 |
+
def _init_weights(self, m):
|
732 |
+
if isinstance(m, nn.Linear):
|
733 |
+
trunc_normal_(m.weight, std=0.02)
|
734 |
+
if m.bias is not None:
|
735 |
+
nn.init.constant_(m.bias, 0)
|
736 |
+
elif isinstance(m, nn.LayerNorm):
|
737 |
+
nn.init.constant_(m.bias, 0)
|
738 |
+
nn.init.constant_(m.weight, 1.0)
|
739 |
+
|
740 |
+
def get_num_layers(self):
|
741 |
+
return len(self.blocks)
|
742 |
+
|
743 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
744 |
+
assert unlocked_groups == 0, "partial locking not currently supported for this model"
|
745 |
+
for param in self.parameters():
|
746 |
+
param.requires_grad = False
|
747 |
+
|
748 |
+
@torch.jit.ignore
|
749 |
+
def set_grad_checkpointing(self, enable=True):
|
750 |
+
self.grad_checkpointing = enable
|
751 |
+
|
752 |
+
@torch.jit.ignore
|
753 |
+
def no_weight_decay(self):
|
754 |
+
return {"pos_embed", "cls_token"}
|
755 |
+
|
756 |
+
def get_classifier(self):
|
757 |
+
return self.head
|
758 |
+
|
759 |
+
def reset_classifier(self, num_classes, global_pool=""):
|
760 |
+
self.num_classes = num_classes
|
761 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
762 |
+
|
763 |
+
def forward_features(self, x, return_all_features=False):
|
764 |
+
|
765 |
+
x = self.patch_embed(x)
|
766 |
+
batch_size, seq_len, _ = x.size()
|
767 |
+
|
768 |
+
cls_tokens = self.cls_token.expand(
|
769 |
+
batch_size, -1, -1
|
770 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
771 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
772 |
+
if self.pos_embed is not None:
|
773 |
+
x = x + self.pos_embed
|
774 |
+
x = self.pos_drop(x)
|
775 |
+
|
776 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
777 |
+
if os.getenv("RoPE") == "1":
|
778 |
+
if self.training and not isinstance(self.patch_dropout, nn.Identity):
|
779 |
+
x, patch_indices_keep = self.patch_dropout(x)
|
780 |
+
self.rope.forward = partial(
|
781 |
+
self.rope.forward, patch_indices_keep=patch_indices_keep
|
782 |
+
)
|
783 |
+
else:
|
784 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
|
785 |
+
x = self.patch_dropout(x)
|
786 |
+
else:
|
787 |
+
x = self.patch_dropout(x)
|
788 |
+
|
789 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
790 |
+
for i, blk in enumerate(self.blocks):
|
791 |
+
if i == len(self.blocks) - 1:
|
792 |
+
continue
|
793 |
+
if self.grad_checkpointing:
|
794 |
+
x = checkpoint(blk, x, (rel_pos_bias,))
|
795 |
+
else:
|
796 |
+
x = blk(x, rel_pos_bias=rel_pos_bias)
|
797 |
+
|
798 |
+
if not return_all_features:
|
799 |
+
x = self.norm(x)
|
800 |
+
if self.fc_norm is not None:
|
801 |
+
return self.fc_norm(x.mean(1))
|
802 |
+
else:
|
803 |
+
return x[:, 0]
|
804 |
+
return x
|
805 |
+
|
806 |
+
def forward(self, x, return_all_features=False):
|
807 |
+
if return_all_features:
|
808 |
+
return self.forward_features(x, return_all_features)
|
809 |
+
x = self.forward_features(x)
|
810 |
+
x = self.head(x)
|
811 |
+
return x
|
812 |
+
|
813 |
+
|
814 |
+
def load_state_dict(
|
815 |
+
checkpoint_path: str,
|
816 |
+
map_location: str = "cpu",
|
817 |
+
model_key: str = "model|module|state_dict",
|
818 |
+
is_openai: bool = False,
|
819 |
+
skip_list: list = [],
|
820 |
+
):
|
821 |
+
if is_openai:
|
822 |
+
model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
|
823 |
+
state_dict = model.state_dict()
|
824 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
825 |
+
state_dict.pop(key, None)
|
826 |
+
else:
|
827 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
828 |
+
for mk in model_key.split("|"):
|
829 |
+
if isinstance(checkpoint, dict) and mk in checkpoint:
|
830 |
+
state_dict = checkpoint[mk]
|
831 |
+
break
|
832 |
+
else:
|
833 |
+
state_dict = checkpoint
|
834 |
+
if next(iter(state_dict.items()))[0].startswith("module"):
|
835 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
836 |
+
|
837 |
+
for k in skip_list:
|
838 |
+
if k in list(state_dict.keys()):
|
839 |
+
logging.info(f"Removing key {k} from pretrained checkpoint")
|
840 |
+
del state_dict[k]
|
841 |
+
|
842 |
+
if os.getenv("RoPE") == "1":
|
843 |
+
for k in list(state_dict.keys()):
|
844 |
+
if "freqs_cos" in k or "freqs_sin" in k:
|
845 |
+
del state_dict[k]
|
846 |
+
return state_dict
|
847 |
+
|
848 |
+
|
849 |
+
def load_clip_visual_state_dict(
|
850 |
+
checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []
|
851 |
+
):
|
852 |
+
state_dict = load_state_dict(
|
853 |
+
checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list
|
854 |
+
)
|
855 |
+
|
856 |
+
for k in list(state_dict.keys()):
|
857 |
+
if not k.startswith("visual."):
|
858 |
+
del state_dict[k]
|
859 |
+
for k in list(state_dict.keys()):
|
860 |
+
if k.startswith("visual."):
|
861 |
+
new_k = k[7:]
|
862 |
+
state_dict[new_k] = state_dict[k]
|
863 |
+
del state_dict[k]
|
864 |
+
return state_dict
|
865 |
+
|
866 |
+
|
867 |
+
try:
|
868 |
+
from apex.normalization import FusedLayerNorm
|
869 |
+
except:
|
870 |
+
FusedLayerNorm = LayerNorm
|
871 |
+
print(
|
872 |
+
"Please build and install Nvidia apex package with option '--cuda_ext' according to https://github.com/NVIDIA/apex#from-source ."
|
873 |
+
)
|
874 |
+
|
875 |
+
|
876 |
+
@dataclass
|
877 |
+
class CLIPVisionCfg:
|
878 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
879 |
+
width: int = 768
|
880 |
+
head_width: int = 64
|
881 |
+
mlp_ratio: float = 4.0
|
882 |
+
patch_size: int = 16
|
883 |
+
image_size: Union[Tuple[int, int], int] = 224
|
884 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
885 |
+
patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
886 |
+
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
887 |
+
drop_path_rate: Optional[float] = None # drop path rate
|
888 |
+
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
889 |
+
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
890 |
+
timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
891 |
+
timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '')
|
892 |
+
timm_proj_bias: bool = False # enable bias final projection
|
893 |
+
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
|
894 |
+
qkv_bias: bool = True
|
895 |
+
fusedLN: bool = False
|
896 |
+
xattn: bool = False
|
897 |
+
postnorm: bool = False
|
898 |
+
rope: bool = False
|
899 |
+
pt_hw_seq_len: int = 16 # 224/14
|
900 |
+
intp_freq: bool = False
|
901 |
+
naiveswiglu: bool = False
|
902 |
+
subln: bool = False
|
903 |
+
|
904 |
+
|
905 |
+
def _build_vision_tower(vision_tower_path: str, embed_dim: int, vision_cfg: CLIPVisionCfg):
|
906 |
+
if isinstance(vision_cfg, dict):
|
907 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
908 |
+
|
909 |
+
if vision_cfg.eva_model_name:
|
910 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
911 |
+
norm_layer = LayerNorm
|
912 |
+
|
913 |
+
visual = EVAVisionTransformer(
|
914 |
+
img_size=vision_cfg.image_size,
|
915 |
+
patch_size=vision_cfg.patch_size,
|
916 |
+
num_classes=embed_dim,
|
917 |
+
use_mean_pooling=vision_cfg.global_average_pool, # False
|
918 |
+
init_values=vision_cfg.ls_init_value,
|
919 |
+
patch_dropout=vision_cfg.patch_dropout,
|
920 |
+
embed_dim=vision_cfg.width,
|
921 |
+
depth=vision_cfg.layers,
|
922 |
+
num_heads=vision_heads,
|
923 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
924 |
+
qkv_bias=vision_cfg.qkv_bias,
|
925 |
+
drop_path_rate=vision_cfg.drop_path_rate,
|
926 |
+
norm_layer=partial(FusedLayerNorm, eps=1e-6)
|
927 |
+
if vision_cfg.fusedLN
|
928 |
+
else partial(norm_layer, eps=1e-6),
|
929 |
+
xattn=vision_cfg.xattn,
|
930 |
+
rope=vision_cfg.rope,
|
931 |
+
postnorm=vision_cfg.postnorm,
|
932 |
+
pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
|
933 |
+
intp_freq=vision_cfg.intp_freq,
|
934 |
+
naiveswiglu=vision_cfg.naiveswiglu,
|
935 |
+
subln=vision_cfg.subln,
|
936 |
+
)
|
937 |
+
|
938 |
+
state_dict = load_clip_visual_state_dict(vision_tower_path)
|
939 |
+
incompatible_keys = visual.load_state_dict(state_dict, strict=False)
|
940 |
+
print("EVA-CLIP incompatible_keys:", incompatible_keys)
|
941 |
+
|
942 |
+
return visual
|
943 |
+
|
944 |
+
|
945 |
+
class Eva2LargePlusEncoder(nn.Module):
|
946 |
+
def __init__(self, vision_tower_path):
|
947 |
+
super(Eva2LargePlusEncoder, self).__init__()
|
948 |
+
self.config = {
|
949 |
+
"embed_dim": 768,
|
950 |
+
"vision_cfg": {
|
951 |
+
"image_size": 336,
|
952 |
+
"layers": 24,
|
953 |
+
"width": 1024,
|
954 |
+
"drop_path_rate": 0,
|
955 |
+
"head_width": 64,
|
956 |
+
"mlp_ratio": 2.6667,
|
957 |
+
"patch_size": 14,
|
958 |
+
"eva_model_name": "eva-clip-l-14-336",
|
959 |
+
"xattn": True,
|
960 |
+
"fusedLN": True,
|
961 |
+
"rope": True,
|
962 |
+
"pt_hw_seq_len": 16,
|
963 |
+
"intp_freq": True,
|
964 |
+
"naiveswiglu": True,
|
965 |
+
"subln": True,
|
966 |
+
},
|
967 |
+
}
|
968 |
+
|
969 |
+
self.config["vision_tower_path"] = vision_tower_path
|
970 |
+
self.model = _build_vision_tower(**self.config)
|
971 |
+
|
972 |
+
def forward(self, image, **kwargs):
|
973 |
+
encode = self.model(image, return_all_features=True)[:, 1:, :]
|
974 |
+
return encode
|
975 |
+
|
976 |
+
@property
|
977 |
+
def dtype(self):
|
978 |
+
return list(self.parameters())[-1].dtype
|
979 |
+
|
980 |
+
@property
|
981 |
+
def device(self):
|
982 |
+
return list(self.parameters())[-1].device
|
vita/model/multimodal_encoder/internvit/configuration_intern_vit.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternVL
|
3 |
+
# Copyright (c) 2023 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
import os
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
from transformers.configuration_utils import PretrainedConfig
|
10 |
+
from transformers.utils import logging
|
11 |
+
|
12 |
+
logger = logging.get_logger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
class InternVisionConfig(PretrainedConfig):
|
16 |
+
r"""
|
17 |
+
This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
|
18 |
+
instantiate a vision encoder according to the specified arguments, defining the model architecture.
|
19 |
+
|
20 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
21 |
+
documentation from [`PretrainedConfig`] for more information.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
num_channels (`int`, *optional*, defaults to 3):
|
25 |
+
Number of color channels in the input images (e.g., 3 for RGB).
|
26 |
+
patch_size (`int`, *optional*, defaults to 14):
|
27 |
+
The size (resolution) of each patch.
|
28 |
+
image_size (`int`, *optional*, defaults to 224):
|
29 |
+
The size (resolution) of each image.
|
30 |
+
qkv_bias (`bool`, *optional*, defaults to `False`):
|
31 |
+
Whether to add a bias to the queries and values in the self-attention layers.
|
32 |
+
hidden_size (`int`, *optional*, defaults to 3200):
|
33 |
+
Dimensionality of the encoder layers and the pooler layer.
|
34 |
+
num_attention_heads (`int`, *optional*, defaults to 25):
|
35 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
36 |
+
intermediate_size (`int`, *optional*, defaults to 12800):
|
37 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
38 |
+
qk_normalization (`bool`, *optional*, defaults to `True`):
|
39 |
+
Whether to normalize the queries and keys in the self-attention layers.
|
40 |
+
num_hidden_layers (`int`, *optional*, defaults to 48):
|
41 |
+
Number of hidden layers in the Transformer encoder.
|
42 |
+
use_flash_attn (`bool`, *optional*, defaults to `True`):
|
43 |
+
Whether to use flash attention mechanism.
|
44 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
45 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
46 |
+
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
|
47 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
48 |
+
The epsilon used by the layer normalization layers.
|
49 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
50 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
51 |
+
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
52 |
+
Dropout rate for stochastic depth.
|
53 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
54 |
+
The dropout ratio for the attention probabilities.
|
55 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
56 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
57 |
+
initializer_factor (`float`, *optional*, defaults to 0.1):
|
58 |
+
A factor for layer scale.
|
59 |
+
"""
|
60 |
+
|
61 |
+
model_type = "intern_vit_6b"
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
num_channels=3,
|
66 |
+
patch_size=14,
|
67 |
+
image_size=224,
|
68 |
+
qkv_bias=False,
|
69 |
+
hidden_size=3200,
|
70 |
+
num_attention_heads=25,
|
71 |
+
intermediate_size=12800,
|
72 |
+
qk_normalization=True,
|
73 |
+
num_hidden_layers=48,
|
74 |
+
use_flash_attn=True,
|
75 |
+
hidden_act="gelu",
|
76 |
+
norm_type="rms_norm",
|
77 |
+
layer_norm_eps=1e-6,
|
78 |
+
dropout=0.0,
|
79 |
+
drop_path_rate=0.0,
|
80 |
+
attention_dropout=0.0,
|
81 |
+
initializer_range=0.02,
|
82 |
+
initializer_factor=0.1,
|
83 |
+
**kwargs,
|
84 |
+
):
|
85 |
+
super().__init__(**kwargs)
|
86 |
+
|
87 |
+
self.hidden_size = hidden_size
|
88 |
+
self.intermediate_size = intermediate_size
|
89 |
+
self.dropout = dropout
|
90 |
+
self.drop_path_rate = drop_path_rate
|
91 |
+
self.num_hidden_layers = num_hidden_layers
|
92 |
+
self.num_attention_heads = num_attention_heads
|
93 |
+
self.num_channels = num_channels
|
94 |
+
self.patch_size = patch_size
|
95 |
+
self.image_size = image_size
|
96 |
+
self.initializer_range = initializer_range
|
97 |
+
self.initializer_factor = initializer_factor
|
98 |
+
self.attention_dropout = attention_dropout
|
99 |
+
self.layer_norm_eps = layer_norm_eps
|
100 |
+
self.hidden_act = hidden_act
|
101 |
+
self.norm_type = norm_type
|
102 |
+
self.qkv_bias = qkv_bias
|
103 |
+
self.qk_normalization = qk_normalization
|
104 |
+
self.use_flash_attn = use_flash_attn
|
105 |
+
|
106 |
+
@classmethod
|
107 |
+
def from_pretrained(
|
108 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
109 |
+
) -> "PretrainedConfig":
|
110 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
111 |
+
|
112 |
+
if "vision_config" in config_dict:
|
113 |
+
config_dict = config_dict["vision_config"]
|
114 |
+
|
115 |
+
if (
|
116 |
+
"model_type" in config_dict
|
117 |
+
and hasattr(cls, "model_type")
|
118 |
+
and config_dict["model_type"] != cls.model_type
|
119 |
+
):
|
120 |
+
logger.warning(
|
121 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
122 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
123 |
+
)
|
124 |
+
|
125 |
+
return cls.from_dict(config_dict, **kwargs)
|
vita/model/multimodal_encoder/internvit/flash_attention.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
7 |
+
|
8 |
+
try: # v1
|
9 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
10 |
+
except: # v2
|
11 |
+
from flash_attn.flash_attn_interface import (
|
12 |
+
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
class FlashAttention(nn.Module):
|
17 |
+
"""Implement the scaled dot product attention with softmax.
|
18 |
+
Arguments
|
19 |
+
---------
|
20 |
+
softmax_scale: The temperature to use for the softmax attention.
|
21 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
22 |
+
runtime)
|
23 |
+
attention_dropout: The dropout rate to apply to the attention
|
24 |
+
(default: 0.0)
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
|
28 |
+
super().__init__()
|
29 |
+
self.softmax_scale = softmax_scale
|
30 |
+
self.dropout_p = attention_dropout
|
31 |
+
|
32 |
+
def forward(
|
33 |
+
self,
|
34 |
+
qkv,
|
35 |
+
key_padding_mask=None,
|
36 |
+
causal=False,
|
37 |
+
cu_seqlens=None,
|
38 |
+
max_s=None,
|
39 |
+
need_weights=False,
|
40 |
+
):
|
41 |
+
"""Implements the multihead softmax attention.
|
42 |
+
Arguments
|
43 |
+
---------
|
44 |
+
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
|
45 |
+
if unpadded: (nnz, 3, h, d)
|
46 |
+
key_padding_mask: a bool tensor of shape (B, S)
|
47 |
+
"""
|
48 |
+
assert not need_weights
|
49 |
+
assert qkv.dtype in [torch.float16, torch.bfloat16]
|
50 |
+
assert qkv.is_cuda
|
51 |
+
|
52 |
+
if cu_seqlens is None:
|
53 |
+
batch_size = qkv.shape[0]
|
54 |
+
seqlen = qkv.shape[1]
|
55 |
+
if key_padding_mask is None:
|
56 |
+
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
57 |
+
max_s = seqlen
|
58 |
+
cu_seqlens = torch.arange(
|
59 |
+
0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
|
60 |
+
)
|
61 |
+
output = flash_attn_unpadded_qkvpacked_func(
|
62 |
+
qkv,
|
63 |
+
cu_seqlens,
|
64 |
+
max_s,
|
65 |
+
self.dropout_p if self.training else 0.0,
|
66 |
+
softmax_scale=self.softmax_scale,
|
67 |
+
causal=causal,
|
68 |
+
)
|
69 |
+
output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
|
70 |
+
else:
|
71 |
+
nheads = qkv.shape[-2]
|
72 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
73 |
+
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
|
74 |
+
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
|
75 |
+
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
76 |
+
x_unpad,
|
77 |
+
cu_seqlens,
|
78 |
+
max_s,
|
79 |
+
self.dropout_p if self.training else 0.0,
|
80 |
+
softmax_scale=self.softmax_scale,
|
81 |
+
causal=causal,
|
82 |
+
)
|
83 |
+
output = rearrange(
|
84 |
+
pad_input(
|
85 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
|
86 |
+
),
|
87 |
+
"b s (h d) -> b s h d",
|
88 |
+
h=nheads,
|
89 |
+
)
|
90 |
+
else:
|
91 |
+
assert max_s is not None
|
92 |
+
output = flash_attn_unpadded_qkvpacked_func(
|
93 |
+
qkv,
|
94 |
+
cu_seqlens,
|
95 |
+
max_s,
|
96 |
+
self.dropout_p if self.training else 0.0,
|
97 |
+
softmax_scale=self.softmax_scale,
|
98 |
+
causal=causal,
|
99 |
+
)
|
100 |
+
|
101 |
+
return output, None
|
vita/model/multimodal_encoder/internvit/internvit_encoder.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
|
4 |
+
|
5 |
+
from .modeling_intern_vit import InternVisionModel
|
6 |
+
|
7 |
+
|
8 |
+
class InternViTVisionTower(nn.Module):
|
9 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
self.is_loaded = False
|
13 |
+
|
14 |
+
self.vision_tower_name = vision_tower
|
15 |
+
self.select_layer = -1
|
16 |
+
self.scale_pix_shuffle = 0.5
|
17 |
+
|
18 |
+
if not delay_load:
|
19 |
+
self.load_model()
|
20 |
+
else:
|
21 |
+
self.cfg_only = AutoConfig.from_pretrained(
|
22 |
+
self.vision_tower_name, trust_remote_code=True
|
23 |
+
)
|
24 |
+
|
25 |
+
def load_model(self):
|
26 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
27 |
+
self.vision_tower = InternVisionModel.from_pretrained(
|
28 |
+
self.vision_tower_name, trust_remote_code=True
|
29 |
+
)
|
30 |
+
self.vision_tower.requires_grad_(False)
|
31 |
+
|
32 |
+
self.is_loaded = True
|
33 |
+
|
34 |
+
def feature_select(self, image_forward_outs):
|
35 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
36 |
+
|
37 |
+
image_features = image_features[:, 1:]
|
38 |
+
|
39 |
+
return image_features
|
40 |
+
|
41 |
+
def pixel_shuffle(self, x, scale_factor=0.5):
|
42 |
+
n, w, h, c = x.size()
|
43 |
+
# N, W, H, C --> N, W, H * scale, C // scale
|
44 |
+
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
45 |
+
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
46 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
47 |
+
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
|
48 |
+
x = x.view(
|
49 |
+
n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor))
|
50 |
+
)
|
51 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
52 |
+
return x
|
53 |
+
|
54 |
+
#@torch.no_grad()
|
55 |
+
def forward(self, images):
|
56 |
+
if type(images) is list:
|
57 |
+
image_features = []
|
58 |
+
for image in images:
|
59 |
+
image_forward_out = self.vision_tower(
|
60 |
+
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
|
61 |
+
output_hidden_states=True,
|
62 |
+
)
|
63 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
64 |
+
image_features.append(image_feature)
|
65 |
+
else:
|
66 |
+
image_forward_outs = self.vision_tower(
|
67 |
+
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
|
68 |
+
)
|
69 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
70 |
+
h = w = int(image_features.shape[1] ** 0.5)
|
71 |
+
assert image_features.shape[1] == h * w
|
72 |
+
image_features = image_features.reshape(image_features.shape[0], h, w, -1)
|
73 |
+
image_features = self.pixel_shuffle(image_features * self.scale_pix_shuffle)
|
74 |
+
image_features = image_features.reshape(
|
75 |
+
image_features.shape[0], -1, image_features.shape[-1]
|
76 |
+
)
|
77 |
+
|
78 |
+
return image_features
|
79 |
+
|
80 |
+
@property
|
81 |
+
def dummy_feature(self):
|
82 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
83 |
+
|
84 |
+
@property
|
85 |
+
def dtype(self):
|
86 |
+
return self.vision_tower.dtype
|
87 |
+
|
88 |
+
@property
|
89 |
+
def device(self):
|
90 |
+
return self.vision_tower.device
|
91 |
+
|
92 |
+
@property
|
93 |
+
def config(self):
|
94 |
+
if self.is_loaded:
|
95 |
+
return self.vision_tower.config
|
96 |
+
else:
|
97 |
+
return self.cfg_only
|
98 |
+
|
99 |
+
@property
|
100 |
+
def hidden_size(self):
|
101 |
+
return self.config.hidden_size * (int(1 / self.scale_pix_shuffle) ** 2)
|
102 |
+
|
103 |
+
@property
|
104 |
+
def num_patches(self):
|
105 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
vita/model/multimodal_encoder/internvit/modeling_intern_vit.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternVL
|
3 |
+
# Copyright (c) 2023 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
from typing import Optional, Tuple, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint
|
11 |
+
from einops import rearrange
|
12 |
+
from torch import nn
|
13 |
+
from transformers.activations import ACT2FN
|
14 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
15 |
+
from transformers.modeling_utils import PreTrainedModel
|
16 |
+
from transformers.utils import logging
|
17 |
+
|
18 |
+
from timm.models.layers import DropPath
|
19 |
+
|
20 |
+
from .configuration_intern_vit import InternVisionConfig
|
21 |
+
|
22 |
+
try:
|
23 |
+
from .flash_attention import FlashAttention
|
24 |
+
|
25 |
+
has_flash_attn = True
|
26 |
+
except:
|
27 |
+
print("FlashAttention is not installed.")
|
28 |
+
has_flash_attn = False
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
class InternRMSNorm(nn.Module):
|
35 |
+
def __init__(self, hidden_size, eps=1e-6):
|
36 |
+
super().__init__()
|
37 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
38 |
+
self.variance_epsilon = eps
|
39 |
+
|
40 |
+
def forward(self, hidden_states):
|
41 |
+
input_dtype = hidden_states.dtype
|
42 |
+
hidden_states = hidden_states.to(torch.float32)
|
43 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
44 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
45 |
+
return self.weight * hidden_states.to(input_dtype)
|
46 |
+
|
47 |
+
|
48 |
+
try:
|
49 |
+
from apex.normalization import FusedRMSNorm
|
50 |
+
|
51 |
+
InternRMSNorm = FusedRMSNorm # noqa
|
52 |
+
|
53 |
+
logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm")
|
54 |
+
except ImportError:
|
55 |
+
# using the normal InternRMSNorm
|
56 |
+
pass
|
57 |
+
except Exception:
|
58 |
+
logger.warning("discovered apex but it failed to load, falling back to InternRMSNorm")
|
59 |
+
pass
|
60 |
+
|
61 |
+
|
62 |
+
NORM2FN = {
|
63 |
+
"rms_norm": InternRMSNorm,
|
64 |
+
"layer_norm": nn.LayerNorm,
|
65 |
+
}
|
66 |
+
|
67 |
+
|
68 |
+
class InternVisionEmbeddings(nn.Module):
|
69 |
+
def __init__(self, config: InternVisionConfig):
|
70 |
+
super().__init__()
|
71 |
+
self.config = config
|
72 |
+
self.embed_dim = config.hidden_size
|
73 |
+
self.image_size = config.image_size
|
74 |
+
self.patch_size = config.patch_size
|
75 |
+
|
76 |
+
self.class_embedding = nn.Parameter(
|
77 |
+
torch.randn(1, 1, self.embed_dim),
|
78 |
+
)
|
79 |
+
|
80 |
+
self.patch_embedding = nn.Conv2d(
|
81 |
+
in_channels=3,
|
82 |
+
out_channels=self.embed_dim,
|
83 |
+
kernel_size=self.patch_size,
|
84 |
+
stride=self.patch_size,
|
85 |
+
)
|
86 |
+
|
87 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
88 |
+
self.num_positions = self.num_patches + 1
|
89 |
+
|
90 |
+
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
91 |
+
|
92 |
+
def _get_pos_embed(self, pos_embed, H, W):
|
93 |
+
target_dtype = pos_embed.dtype
|
94 |
+
pos_embed = (
|
95 |
+
pos_embed.float()
|
96 |
+
.reshape(1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1)
|
97 |
+
.permute(0, 3, 1, 2)
|
98 |
+
)
|
99 |
+
pos_embed = (
|
100 |
+
F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
|
101 |
+
.reshape(1, -1, H * W)
|
102 |
+
.permute(0, 2, 1)
|
103 |
+
.to(target_dtype)
|
104 |
+
)
|
105 |
+
return pos_embed
|
106 |
+
|
107 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
108 |
+
target_dtype = self.patch_embedding.weight.dtype
|
109 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
|
110 |
+
batch_size, _, height, width = patch_embeds.shape
|
111 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
112 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
113 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
114 |
+
position_embedding = torch.cat(
|
115 |
+
[
|
116 |
+
self.position_embedding[:, :1, :],
|
117 |
+
self._get_pos_embed(self.position_embedding[:, 1:, :], height, width),
|
118 |
+
],
|
119 |
+
dim=1,
|
120 |
+
)
|
121 |
+
embeddings = embeddings + position_embedding.to(target_dtype)
|
122 |
+
return embeddings
|
123 |
+
|
124 |
+
|
125 |
+
class InternAttention(nn.Module):
|
126 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
127 |
+
|
128 |
+
def __init__(self, config: InternVisionConfig):
|
129 |
+
super().__init__()
|
130 |
+
self.config = config
|
131 |
+
self.embed_dim = config.hidden_size
|
132 |
+
self.num_heads = config.num_attention_heads
|
133 |
+
self.use_flash_attn = config.use_flash_attn and has_flash_attn
|
134 |
+
if config.use_flash_attn and not has_flash_attn:
|
135 |
+
print("Warning: Flash Attention is not available, use_flash_attn is set to False.")
|
136 |
+
self.head_dim = self.embed_dim // self.num_heads
|
137 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
138 |
+
raise ValueError(
|
139 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
140 |
+
f" {self.num_heads})."
|
141 |
+
)
|
142 |
+
|
143 |
+
self.scale = self.head_dim**-0.5
|
144 |
+
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
|
145 |
+
self.attn_drop = nn.Dropout(config.attention_dropout)
|
146 |
+
self.proj_drop = nn.Dropout(config.dropout)
|
147 |
+
|
148 |
+
self.qk_normalization = config.qk_normalization
|
149 |
+
|
150 |
+
if self.qk_normalization:
|
151 |
+
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
152 |
+
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
153 |
+
|
154 |
+
if self.use_flash_attn:
|
155 |
+
self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
|
156 |
+
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
|
157 |
+
|
158 |
+
def _naive_attn(self, x):
|
159 |
+
B, N, C = x.shape
|
160 |
+
qkv = (
|
161 |
+
self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
162 |
+
)
|
163 |
+
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
164 |
+
|
165 |
+
if self.qk_normalization:
|
166 |
+
B_, H_, N_, D_ = q.shape
|
167 |
+
q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
168 |
+
k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
169 |
+
|
170 |
+
attn = (q * self.scale) @ k.transpose(-2, -1)
|
171 |
+
attn = attn.softmax(dim=-1)
|
172 |
+
attn = self.attn_drop(attn)
|
173 |
+
|
174 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
175 |
+
x = self.proj(x)
|
176 |
+
x = self.proj_drop(x)
|
177 |
+
return x
|
178 |
+
|
179 |
+
def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
|
180 |
+
qkv = self.qkv(x)
|
181 |
+
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
|
182 |
+
|
183 |
+
if self.qk_normalization:
|
184 |
+
q, k, v = qkv.unbind(2)
|
185 |
+
q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
|
186 |
+
k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
|
187 |
+
qkv = torch.stack([q, k, v], dim=2)
|
188 |
+
|
189 |
+
context, _ = self.inner_attn(
|
190 |
+
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
|
191 |
+
)
|
192 |
+
outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
|
193 |
+
outs = self.proj_drop(outs)
|
194 |
+
return outs
|
195 |
+
|
196 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
197 |
+
x = (
|
198 |
+
self._naive_attn(hidden_states)
|
199 |
+
if not self.use_flash_attn
|
200 |
+
else self._flash_attn(hidden_states)
|
201 |
+
)
|
202 |
+
return x
|
203 |
+
|
204 |
+
|
205 |
+
class InternMLP(nn.Module):
|
206 |
+
def __init__(self, config: InternVisionConfig):
|
207 |
+
super().__init__()
|
208 |
+
self.config = config
|
209 |
+
self.act = ACT2FN[config.hidden_act]
|
210 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
211 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
212 |
+
|
213 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
214 |
+
hidden_states = self.fc1(hidden_states)
|
215 |
+
hidden_states = self.act(hidden_states)
|
216 |
+
hidden_states = self.fc2(hidden_states)
|
217 |
+
return hidden_states
|
218 |
+
|
219 |
+
|
220 |
+
class InternVisionEncoderLayer(nn.Module):
|
221 |
+
def __init__(self, config: InternVisionConfig, drop_path_rate: float):
|
222 |
+
super().__init__()
|
223 |
+
self.embed_dim = config.hidden_size
|
224 |
+
self.intermediate_size = config.intermediate_size
|
225 |
+
self.norm_type = config.norm_type
|
226 |
+
|
227 |
+
self.attn = InternAttention(config)
|
228 |
+
self.mlp = InternMLP(config)
|
229 |
+
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
230 |
+
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
231 |
+
|
232 |
+
self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
233 |
+
self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
234 |
+
self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
235 |
+
self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
236 |
+
|
237 |
+
def forward(
|
238 |
+
self,
|
239 |
+
hidden_states: torch.Tensor,
|
240 |
+
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
|
241 |
+
"""
|
242 |
+
Args:
|
243 |
+
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
244 |
+
"""
|
245 |
+
hidden_states = hidden_states + self.drop_path1(
|
246 |
+
self.attn(self.norm1(hidden_states)) * self.ls1
|
247 |
+
)
|
248 |
+
|
249 |
+
hidden_states = hidden_states + self.drop_path2(
|
250 |
+
self.mlp(self.norm2(hidden_states)) * self.ls2
|
251 |
+
)
|
252 |
+
|
253 |
+
return hidden_states
|
254 |
+
|
255 |
+
|
256 |
+
class InternVisionEncoder(nn.Module):
|
257 |
+
"""
|
258 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
259 |
+
[`InternEncoderLayer`].
|
260 |
+
|
261 |
+
Args:
|
262 |
+
config (`InternConfig`):
|
263 |
+
The corresponding vision configuration for the `InternEncoder`.
|
264 |
+
"""
|
265 |
+
|
266 |
+
def __init__(self, config: InternVisionConfig):
|
267 |
+
super().__init__()
|
268 |
+
self.config = config
|
269 |
+
# stochastic depth decay rule
|
270 |
+
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
|
271 |
+
self.layers = nn.ModuleList(
|
272 |
+
[InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]
|
273 |
+
)
|
274 |
+
self.gradient_checkpointing = True
|
275 |
+
|
276 |
+
def forward(
|
277 |
+
self,
|
278 |
+
inputs_embeds,
|
279 |
+
output_hidden_states: Optional[bool] = None,
|
280 |
+
return_dict: Optional[bool] = None,
|
281 |
+
) -> Union[Tuple, BaseModelOutput]:
|
282 |
+
r"""
|
283 |
+
Args:
|
284 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
285 |
+
Embedded representation of the inputs. Should be float, not int tokens.
|
286 |
+
output_hidden_states (`bool`, *optional*):
|
287 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
288 |
+
for more detail.
|
289 |
+
return_dict (`bool`, *optional*):
|
290 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
291 |
+
"""
|
292 |
+
output_hidden_states = (
|
293 |
+
output_hidden_states
|
294 |
+
if output_hidden_states is not None
|
295 |
+
else self.config.output_hidden_states
|
296 |
+
)
|
297 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
298 |
+
|
299 |
+
encoder_states = () if output_hidden_states else None
|
300 |
+
hidden_states = inputs_embeds
|
301 |
+
|
302 |
+
for idx, encoder_layer in enumerate(self.layers):
|
303 |
+
if output_hidden_states:
|
304 |
+
encoder_states = encoder_states + (hidden_states,)
|
305 |
+
if self.gradient_checkpointing and self.training:
|
306 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(encoder_layer, hidden_states)
|
307 |
+
else:
|
308 |
+
layer_outputs = encoder_layer(
|
309 |
+
hidden_states,
|
310 |
+
)
|
311 |
+
hidden_states = layer_outputs
|
312 |
+
|
313 |
+
if output_hidden_states:
|
314 |
+
encoder_states = encoder_states + (hidden_states,)
|
315 |
+
|
316 |
+
if not return_dict:
|
317 |
+
return tuple(v for v in [hidden_states, encoder_states] if v is not None)
|
318 |
+
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)
|
319 |
+
|
320 |
+
|
321 |
+
class InternVisionModel(PreTrainedModel):
|
322 |
+
main_input_name = "pixel_values"
|
323 |
+
config_class = InternVisionConfig
|
324 |
+
_no_split_modules = ["InternVisionEncoderLayer"]
|
325 |
+
|
326 |
+
def __init__(self, config: InternVisionConfig):
|
327 |
+
super().__init__(config)
|
328 |
+
self.config = config
|
329 |
+
|
330 |
+
self.embeddings = InternVisionEmbeddings(config)
|
331 |
+
self.encoder = InternVisionEncoder(config)
|
332 |
+
|
333 |
+
def resize_pos_embeddings(self, old_size, new_size, patch_size):
|
334 |
+
pos_emb = self.embeddings.position_embedding
|
335 |
+
_, num_positions, embed_dim = pos_emb.shape
|
336 |
+
cls_emb = pos_emb[:, :1, :]
|
337 |
+
pos_emb = (
|
338 |
+
pos_emb[:, 1:, :]
|
339 |
+
.reshape(1, old_size // patch_size, old_size // patch_size, -1)
|
340 |
+
.permute(0, 3, 1, 2)
|
341 |
+
)
|
342 |
+
pos_emb = F.interpolate(
|
343 |
+
pos_emb.float(), size=new_size // patch_size, mode="bicubic", align_corners=False
|
344 |
+
)
|
345 |
+
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
|
346 |
+
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
|
347 |
+
self.embeddings.position_embedding = nn.Parameter(pos_emb)
|
348 |
+
self.embeddings.image_size = new_size
|
349 |
+
logger.info("Resized position embeddings from {} to {}".format(old_size, new_size))
|
350 |
+
|
351 |
+
def get_input_embeddings(self):
|
352 |
+
return self.embeddings
|
353 |
+
|
354 |
+
def forward(
|
355 |
+
self,
|
356 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
357 |
+
output_hidden_states: Optional[bool] = None,
|
358 |
+
return_dict: Optional[bool] = None,
|
359 |
+
pixel_embeds: Optional[torch.FloatTensor] = None,
|
360 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
361 |
+
output_hidden_states = (
|
362 |
+
output_hidden_states
|
363 |
+
if output_hidden_states is not None
|
364 |
+
else self.config.output_hidden_states
|
365 |
+
)
|
366 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
367 |
+
|
368 |
+
if pixel_values is None and pixel_embeds is None:
|
369 |
+
raise ValueError("You have to specify pixel_values or pixel_embeds")
|
370 |
+
|
371 |
+
if pixel_embeds is not None:
|
372 |
+
hidden_states = pixel_embeds
|
373 |
+
else:
|
374 |
+
if len(pixel_values.shape) == 4:
|
375 |
+
hidden_states = self.embeddings(pixel_values)
|
376 |
+
else:
|
377 |
+
raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
|
378 |
+
encoder_outputs = self.encoder(
|
379 |
+
inputs_embeds=hidden_states,
|
380 |
+
output_hidden_states=output_hidden_states,
|
381 |
+
return_dict=return_dict,
|
382 |
+
)
|
383 |
+
last_hidden_state = encoder_outputs.last_hidden_state
|
384 |
+
pooled_output = last_hidden_state[:, 0, :]
|
385 |
+
|
386 |
+
if not return_dict:
|
387 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
388 |
+
|
389 |
+
return BaseModelOutputWithPooling(
|
390 |
+
last_hidden_state=last_hidden_state,
|
391 |
+
pooler_output=pooled_output,
|
392 |
+
hidden_states=encoder_outputs.hidden_states,
|
393 |
+
attentions=encoder_outputs.attentions,
|
394 |
+
)
|
vita/model/multimodal_encoder/siglip/siglip_encoder.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
|
4 |
+
|
5 |
+
from vita.util.s2wrapper import forward as multiscale_forward
|
6 |
+
|
7 |
+
|
8 |
+
class SiglipVisionTower(nn.Module):
|
9 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
self.is_loaded = False
|
13 |
+
|
14 |
+
self.vision_tower_name = vision_tower
|
15 |
+
self.select_layer = -2
|
16 |
+
|
17 |
+
if not delay_load:
|
18 |
+
self.load_model()
|
19 |
+
else:
|
20 |
+
self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)
|
21 |
+
|
22 |
+
def load_model(self):
|
23 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
|
24 |
+
self.image_processor.crop_size = self.image_processor.size
|
25 |
+
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
|
26 |
+
self.vision_tower.requires_grad_(False)
|
27 |
+
|
28 |
+
self.is_loaded = True
|
29 |
+
|
30 |
+
def feature_select(self, image_forward_outs):
|
31 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
32 |
+
|
33 |
+
return image_features
|
34 |
+
|
35 |
+
@torch.no_grad()
|
36 |
+
def forward(self, images):
|
37 |
+
if type(images) is list:
|
38 |
+
image_features = []
|
39 |
+
for image in images:
|
40 |
+
image_forward_out = self.vision_tower(
|
41 |
+
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
|
42 |
+
output_hidden_states=True,
|
43 |
+
)
|
44 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
45 |
+
image_features.append(image_feature)
|
46 |
+
else:
|
47 |
+
image_forward_outs = self.vision_tower(
|
48 |
+
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
|
49 |
+
)
|
50 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
51 |
+
|
52 |
+
return image_features
|
53 |
+
|
54 |
+
@property
|
55 |
+
def dummy_feature(self):
|
56 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
57 |
+
|
58 |
+
@property
|
59 |
+
def dtype(self):
|
60 |
+
return self.vision_tower.dtype
|
61 |
+
|
62 |
+
@property
|
63 |
+
def device(self):
|
64 |
+
return self.vision_tower.device
|
65 |
+
|
66 |
+
@property
|
67 |
+
def config(self):
|
68 |
+
if self.is_loaded:
|
69 |
+
return self.vision_tower.config
|
70 |
+
else:
|
71 |
+
return self.cfg_only
|
72 |
+
|
73 |
+
@property
|
74 |
+
def hidden_size(self):
|
75 |
+
return self.config.hidden_size
|
76 |
+
|
77 |
+
@property
|
78 |
+
def num_patches(self):
|
79 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
80 |
+
|
81 |
+
|
82 |
+
class SiglipVisionTowerS2(SiglipVisionTower):
|
83 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
84 |
+
self.s2_scales = getattr(args, "s2_scales", "384,768,1152")
|
85 |
+
self.s2_scales = list(map(int, self.s2_scales.split(",")))
|
86 |
+
self.s2_scales.sort()
|
87 |
+
self.s2_split_size = self.s2_scales[0]
|
88 |
+
self.s2_image_size = self.s2_scales[-1]
|
89 |
+
|
90 |
+
super().__init__(vision_tower, args, delay_load)
|
91 |
+
|
92 |
+
self.multiscale_forward = multiscale_forward
|
93 |
+
|
94 |
+
if not delay_load:
|
95 |
+
self.image_processor.size["height"] = self.image_processor.size[
|
96 |
+
"width"
|
97 |
+
] = self.s2_image_size
|
98 |
+
self.image_processor.crop_size["height"] = self.image_processor.crop_size[
|
99 |
+
"width"
|
100 |
+
] = self.s2_image_size
|
101 |
+
|
102 |
+
def load_model(self):
|
103 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
|
104 |
+
self.image_processor.crop_size = self.image_processor.size
|
105 |
+
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
|
106 |
+
self.vision_tower.requires_grad_(False)
|
107 |
+
|
108 |
+
self.image_processor.size["height"] = self.image_processor.size[
|
109 |
+
"width"
|
110 |
+
] = self.s2_image_size
|
111 |
+
self.image_processor.crop_size["height"] = self.image_processor.crop_size[
|
112 |
+
"width"
|
113 |
+
] = self.s2_image_size
|
114 |
+
|
115 |
+
self.is_loaded = True
|
116 |
+
|
117 |
+
@torch.no_grad()
|
118 |
+
def forward_feature(self, images):
|
119 |
+
image_forward_outs = self.vision_tower(
|
120 |
+
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
|
121 |
+
)
|
122 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
123 |
+
return image_features
|
124 |
+
|
125 |
+
@torch.no_grad()
|
126 |
+
def forward(self, images):
|
127 |
+
if type(images) is list:
|
128 |
+
image_features = []
|
129 |
+
for image in images:
|
130 |
+
image_feature = self.multiscale_forward(
|
131 |
+
self.forward_feature,
|
132 |
+
image.unsqueeze(0),
|
133 |
+
img_sizes=self.s2_scales,
|
134 |
+
max_split_size=self.s2_split_size,
|
135 |
+
)
|
136 |
+
image_features.append(image_feature)
|
137 |
+
else:
|
138 |
+
image_features = self.multiscale_forward(
|
139 |
+
self.forward_feature,
|
140 |
+
images,
|
141 |
+
img_sizes=self.s2_scales,
|
142 |
+
max_split_size=self.s2_split_size,
|
143 |
+
)
|
144 |
+
|
145 |
+
return image_features
|
146 |
+
|
147 |
+
@property
|
148 |
+
def hidden_size(self):
|
149 |
+
return self.config.hidden_size * len(self.s2_scales)
|
vita/model/multimodal_encoder/whale/adapter.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn.utils.rnn import pad_sequence
|
4 |
+
|
5 |
+
|
6 |
+
class CNNAdapter(torch.nn.Module):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
enc_out_dim: int = 512,
|
10 |
+
llm_embed_dim: int = 4096,
|
11 |
+
kernel_size: int = 5,
|
12 |
+
):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
|
16 |
+
self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
|
17 |
+
self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
|
18 |
+
self.relu1 = nn.ReLU()
|
19 |
+
|
20 |
+
self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
|
21 |
+
self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 1, 0)
|
22 |
+
self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
|
23 |
+
self.relu2 = nn.ReLU()
|
24 |
+
|
25 |
+
self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
|
26 |
+
|
27 |
+
def forward(self, x, mask_pad):
|
28 |
+
"""
|
29 |
+
x: B, T, enc_out_dim
|
30 |
+
mask: (B, T) or (B, 1, T)
|
31 |
+
"""
|
32 |
+
x = x.transpose(1, 2) # B, channels, T
|
33 |
+
|
34 |
+
# mask batch padding
|
35 |
+
if mask_pad.size(2) > 0: # time > 0
|
36 |
+
x.masked_fill_(~mask_pad, 0.0)
|
37 |
+
|
38 |
+
x = self.left_padding1(x)
|
39 |
+
x = self.conv1d1(x)
|
40 |
+
x = self.bn1(x)
|
41 |
+
x = self.relu1(x)
|
42 |
+
|
43 |
+
x = self.left_padding2(x)
|
44 |
+
x = self.conv1d2(x)
|
45 |
+
x = self.bn2(x)
|
46 |
+
x = self.relu2(x)
|
47 |
+
|
48 |
+
x = x.transpose(1, 2)
|
49 |
+
x = self.project(x)
|
50 |
+
|
51 |
+
return x, mask_pad
|
52 |
+
|
53 |
+
|
54 |
+
class LinearAdapter(torch.nn.Module):
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
enc_out_dim: int = 512,
|
58 |
+
llm_embed_dim: int = 4096,
|
59 |
+
):
|
60 |
+
super().__init__()
|
61 |
+
|
62 |
+
self.adpter = torch.nn.Linear(enc_out_dim, llm_embed_dim)
|
63 |
+
|
64 |
+
def forward(self, x, mask_pad):
|
65 |
+
return self.adpter(x), mask_pad
|
66 |
+
|
67 |
+
|
68 |
+
class CNNSubsampling(torch.nn.Module):
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
enc_out_dim: int = 512,
|
72 |
+
llm_embed_dim: int = 4096,
|
73 |
+
kernel_size: int = 5,
|
74 |
+
activation_func: str = "relu",
|
75 |
+
norm: str = "batch",
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
|
79 |
+
#if enc_out_dim * 4 < llm_embed_dim:
|
80 |
+
if enc_out_dim * 4 < 0:
|
81 |
+
self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
|
82 |
+
self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
|
83 |
+
self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
|
84 |
+
self.relu1 = nn.ReLU()
|
85 |
+
|
86 |
+
self.left_padding2 = nn.ConstantPad1d((0, kernel_size - 1), 0.0)
|
87 |
+
self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 2, 0)
|
88 |
+
self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
|
89 |
+
self.relu2 = nn.ReLU()
|
90 |
+
|
91 |
+
self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
|
92 |
+
self.cnn_num = 2
|
93 |
+
else:
|
94 |
+
self.left_padding2 = nn.ConstantPad1d((0, kernel_size - 1), 0.0)
|
95 |
+
self.conv1d2 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 2, 0)
|
96 |
+
if norm == "batch":
|
97 |
+
self.bn2 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
|
98 |
+
elif norm == "layer":
|
99 |
+
self.bn2 = nn.LayerNorm(2 * enc_out_dim, eps=1e-3)
|
100 |
+
if activation_func == "gelu":
|
101 |
+
self.relu2 = nn.GELU()
|
102 |
+
else:
|
103 |
+
self.relu2 = nn.ReLU()
|
104 |
+
|
105 |
+
self.project = nn.Linear(2 * enc_out_dim, llm_embed_dim)
|
106 |
+
self.cnn_num = 1
|
107 |
+
|
108 |
+
def forward(self, x, mask_pad):
|
109 |
+
"""
|
110 |
+
x: B, T, enc_out_dim
|
111 |
+
mask: (B, T) or (B, 1, T)
|
112 |
+
"""
|
113 |
+
x = x.transpose(1, 2) # B, channels, T
|
114 |
+
|
115 |
+
# mask batch padding
|
116 |
+
if mask_pad.size(2) > 0: # time > 0
|
117 |
+
x.masked_fill_(~mask_pad, 0.0)
|
118 |
+
|
119 |
+
if self.cnn_num == 2:
|
120 |
+
x = self.left_padding1(x)
|
121 |
+
x = self.conv1d1(x)
|
122 |
+
x = self.bn1(x)
|
123 |
+
x = self.relu1(x)
|
124 |
+
|
125 |
+
x = self.left_padding2(x)
|
126 |
+
x = self.conv1d2(x)
|
127 |
+
if isinstance(self.bn2, nn.LayerNorm):
|
128 |
+
x = x.transpose(1, 2)
|
129 |
+
x = self.bn2(x)
|
130 |
+
if isinstance(self.bn2, nn.LayerNorm):
|
131 |
+
x = x.transpose(1, 2)
|
132 |
+
x = self.relu2(x)
|
133 |
+
|
134 |
+
x = x.transpose(1, 2)
|
135 |
+
x = self.project(x)
|
136 |
+
|
137 |
+
return x, mask_pad[:, :, 0::2]
|
vita/model/multimodal_encoder/whale/cmvn.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
|
6 |
+
|
7 |
+
class GlobalCMVN(torch.nn.Module):
|
8 |
+
def __init__(self, mean: torch.Tensor, istd: torch.Tensor, norm_var: bool = True):
|
9 |
+
"""
|
10 |
+
Args:
|
11 |
+
mean (torch.Tensor): mean stats
|
12 |
+
istd (torch.Tensor): inverse std, std which is 1.0 / std
|
13 |
+
"""
|
14 |
+
super().__init__()
|
15 |
+
assert mean.shape == istd.shape
|
16 |
+
self.norm_var = norm_var
|
17 |
+
# The buffer can be accessed from this module using self.mean
|
18 |
+
self.register_buffer("mean", mean)
|
19 |
+
self.register_buffer("istd", istd)
|
20 |
+
|
21 |
+
def forward(self, x: torch.Tensor):
|
22 |
+
"""
|
23 |
+
Args:
|
24 |
+
x (torch.Tensor): (batch, max_len, feat_dim)
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
(torch.Tensor): normalized feature
|
28 |
+
"""
|
29 |
+
x = x - self.mean
|
30 |
+
if self.norm_var:
|
31 |
+
x = x * self.istd
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
def load_cmvn_json(json_cmvn_file):
|
36 |
+
with open(json_cmvn_file) as f:
|
37 |
+
cmvn_json = json.load(f)
|
38 |
+
|
39 |
+
avg = cmvn_json["mean_stat"]
|
40 |
+
var = cmvn_json["var_stat"]
|
41 |
+
count = cmvn_json["frame_num"]
|
42 |
+
for i in range(len(avg)):
|
43 |
+
avg[i] /= count
|
44 |
+
var[i] = var[i] / count - avg[i] * avg[i]
|
45 |
+
if var[i] < 1.0e-20:
|
46 |
+
var[i] = 1.0e-20
|
47 |
+
var[i] = 1.0 / math.sqrt(var[i])
|
48 |
+
cmvn = np.array([avg, var])
|
49 |
+
return cmvn
|
50 |
+
|
51 |
+
|
52 |
+
def load_cmvn_kaldi(kaldi_cmvn_file):
|
53 |
+
avg = []
|
54 |
+
var = []
|
55 |
+
with open(kaldi_cmvn_file, "r") as file:
|
56 |
+
# kaldi binary file start with '\0B'
|
57 |
+
if file.read(2) == "\0B":
|
58 |
+
logging.error(
|
59 |
+
"kaldi cmvn binary file is not supported, please "
|
60 |
+
)
|
61 |
+
sys.exit(1)
|
62 |
+
file.seek(0)
|
63 |
+
arr = file.read().split()
|
64 |
+
assert arr[0] == "["
|
65 |
+
assert arr[-2] == "0"
|
66 |
+
assert arr[-1] == "]"
|
67 |
+
feat_dim = int((len(arr) - 2 - 2) / 2)
|
68 |
+
for i in range(1, feat_dim + 1):
|
69 |
+
avg.append(float(arr[i]))
|
70 |
+
count = float(arr[feat_dim + 1])
|
71 |
+
for i in range(feat_dim + 2, 2 * feat_dim + 2):
|
72 |
+
var.append(float(arr[i]))
|
73 |
+
|
74 |
+
for i in range(len(avg)):
|
75 |
+
avg[i] /= count
|
76 |
+
var[i] = var[i] / count - avg[i] * avg[i]
|
77 |
+
if var[i] < 1.0e-20:
|
78 |
+
var[i] = 1.0e-20
|
79 |
+
var[i] = 1.0 / math.sqrt(var[i])
|
80 |
+
cmvn = np.array([avg, var])
|
81 |
+
return cmvn
|
82 |
+
|
83 |
+
|
84 |
+
def load_cmvn(filename, is_json):
|
85 |
+
if is_json:
|
86 |
+
file = load_cmvn_json(filename)
|
87 |
+
else:
|
88 |
+
file = load_cmvn_kaldi(filename)
|
89 |
+
return file[0], file[1]
|
vita/model/multimodal_encoder/whale/init_model.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 Binbin Zhang ([email protected])
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Dict, List, Optional, Tuple
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
import torchaudio
|
21 |
+
import torchaudio.compliance.kaldi as kaldi
|
22 |
+
|
23 |
+
from .adapter import CNNAdapter, CNNSubsampling, LinearAdapter
|
24 |
+
from .cmvn import GlobalCMVN, load_cmvn
|
25 |
+
from .module.encoder.encoder import whaleEncoder
|
26 |
+
|
27 |
+
|
28 |
+
class audioEncoderProcessor:
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
dataset_conf: dict = None,
|
32 |
+
):
|
33 |
+
self.dataset_conf = dataset_conf
|
34 |
+
|
35 |
+
def process(self, wav_path):
|
36 |
+
try:
|
37 |
+
waveform, sample_rate = torchaudio.load(wav_path)
|
38 |
+
except Exception as e:
|
39 |
+
print(f"cannot open {wav_path}!!!!!!!!!!!!!!!!")
|
40 |
+
if sample_rate != self.dataset_conf["resample_conf"]["resample_rate"]:
|
41 |
+
waveform = torchaudio.transforms.Resample(
|
42 |
+
orig_freq=sample_rate, new_freq=self.dataset_conf["resample_conf"]["resample_rate"]
|
43 |
+
)(waveform)
|
44 |
+
sample_rate = self.dataset_conf['resample_conf']['resample_rate']
|
45 |
+
|
46 |
+
waveform = waveform * (1 << 15)
|
47 |
+
# Only keep key, feat, label
|
48 |
+
mat = kaldi.fbank(
|
49 |
+
waveform,
|
50 |
+
num_mel_bins=self.dataset_conf["fbank_conf"]["num_mel_bins"],
|
51 |
+
frame_length=self.dataset_conf["fbank_conf"]["frame_length"],
|
52 |
+
frame_shift=self.dataset_conf["fbank_conf"]["frame_shift"],
|
53 |
+
dither=self.dataset_conf["fbank_conf"]["dither"],
|
54 |
+
energy_floor=0.0,
|
55 |
+
sample_frequency=sample_rate,
|
56 |
+
)
|
57 |
+
attn_mask = torch.ones(mat.shape[0])
|
58 |
+
attn_mask = attn_mask[2::2][2::2][0::2]
|
59 |
+
|
60 |
+
return mat, attn_mask.shape[0]
|
61 |
+
|
62 |
+
|
63 |
+
class audioEncoder(torch.nn.Module):
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
encoder: torch.nn.Module,
|
67 |
+
llm_path: str,
|
68 |
+
freeze_llm: bool = True,
|
69 |
+
enc_out_dim: int = 512,
|
70 |
+
llm_embed_dim: int = 4096,
|
71 |
+
kernel_size: int = 3,
|
72 |
+
IGNORE_ID: int = -100,
|
73 |
+
adpter_type: str = "cnn",
|
74 |
+
add_audio_bos_eos: bool = False,
|
75 |
+
task_num: int = 10,
|
76 |
+
task_before_audio: bool = False,
|
77 |
+
task_type: str = "prompt",
|
78 |
+
freeze_encoder: bool = False,
|
79 |
+
freeze_adpter: bool = False,
|
80 |
+
audio_prompt_finetune: bool = False,
|
81 |
+
audio_prompt_num: int = 25,
|
82 |
+
activation_func: str = "relu",
|
83 |
+
norm: str = "batch",
|
84 |
+
chat_template=None,
|
85 |
+
):
|
86 |
+
super().__init__()
|
87 |
+
self.encoder = encoder
|
88 |
+
|
89 |
+
self.enc_out_dim = enc_out_dim
|
90 |
+
self.llm_embed_dim = llm_embed_dim
|
91 |
+
self.IGNORE_ID = IGNORE_ID
|
92 |
+
self.add_audio_bos_eos = add_audio_bos_eos
|
93 |
+
self.task_before_audio = task_before_audio
|
94 |
+
self.task_type = task_type
|
95 |
+
self.freeze_encoder = freeze_encoder
|
96 |
+
self.freeze_adpter = freeze_adpter
|
97 |
+
self.audio_prompt_finetune = audio_prompt_finetune
|
98 |
+
self.audio_prompt_num = audio_prompt_num
|
99 |
+
|
100 |
+
if adpter_type == "cnn":
|
101 |
+
self.adpter = CNNAdapter(enc_out_dim, llm_embed_dim, kernel_size)
|
102 |
+
elif adpter_type == "linear":
|
103 |
+
self.adpter = LinearAdapter(enc_out_dim, llm_embed_dim)
|
104 |
+
elif adpter_type == "subsampling":
|
105 |
+
self.adpter = CNNSubsampling(
|
106 |
+
enc_out_dim, llm_embed_dim, kernel_size, activation_func, norm
|
107 |
+
)
|
108 |
+
|
109 |
+
if self.freeze_encoder:
|
110 |
+
self.encoder.eval()
|
111 |
+
for (name, param) in self.encoder.named_parameters():
|
112 |
+
param.requires_grad = False
|
113 |
+
if self.freeze_adpter:
|
114 |
+
self.adpter.eval()
|
115 |
+
for (name, param) in self.adpter.named_parameters():
|
116 |
+
param.requires_grad = False
|
117 |
+
|
118 |
+
if self.audio_prompt_finetune:
|
119 |
+
self.prompt_embeddings = nn.Embedding(audio_prompt_num, llm_embed_dim)
|
120 |
+
self.prompt_ids = torch.tensor([i for i in range(audio_prompt_num)]).long()
|
121 |
+
|
122 |
+
def forward(
|
123 |
+
self,
|
124 |
+
speech: torch.Tensor,
|
125 |
+
speech_lengths: torch.Tensor,
|
126 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
127 |
+
|
128 |
+
speech = speech.to(next(self.parameters()).dtype)
|
129 |
+
|
130 |
+
# 1. Encoder
|
131 |
+
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
|
132 |
+
inputs_embeds, encoder_mask = self.adpter(encoder_out, encoder_mask) # B, T, D
|
133 |
+
attention_mask = encoder_mask.squeeze(1) # B, T
|
134 |
+
assert inputs_embeds.size(1) == attention_mask.size(1)
|
135 |
+
|
136 |
+
# audio bos/eos
|
137 |
+
if self.add_audio_bos_eos:
|
138 |
+
inputs_embeds, attention_mask, target = self._add_bos_eos(
|
139 |
+
"audio", "/audio", inputs_embeds, attention_mask, target
|
140 |
+
)
|
141 |
+
|
142 |
+
B, _, _ = inputs_embeds.shape
|
143 |
+
if self.audio_prompt_finetune:
|
144 |
+
prompt_ids = self.prompt_ids.repeat(B, 1).to(inputs_embeds.device)
|
145 |
+
prompt_embeds = self.prompt_embeddings(
|
146 |
+
prompt_ids.to(inputs_embeds.device)) # B, 5, D
|
147 |
+
inputs_embeds = torch.cat((prompt_embeds, inputs_embeds), 1) # B, (T+5), D
|
148 |
+
|
149 |
+
outputs = {
|
150 |
+
"inputs_embeds": inputs_embeds,
|
151 |
+
"attention_mask": attention_mask,
|
152 |
+
}
|
153 |
+
|
154 |
+
return outputs
|
155 |
+
|
156 |
+
def _add_bos_eos(self, bos, eos, inputs_embeds, attention_mask, target=None):
|
157 |
+
B = len(inputs_embeds)
|
158 |
+
bos_embed = self.task_embeddings(
|
159 |
+
torch.full([B, 1], self.task_ids[bos]).to(inputs_embeds.device)
|
160 |
+
) # B, 1, D
|
161 |
+
eos_embed = self.task_embeddings(
|
162 |
+
torch.full([B, 1], self.task_ids[eos]).to(inputs_embeds.device)
|
163 |
+
) # B, 1, D
|
164 |
+
bos_eos_target = torch.full([B, 2], self.IGNORE_ID).to(inputs_embeds.device) # B, 2
|
165 |
+
bos_eos_mask = torch.full([B, 1], True).to(inputs_embeds.device) # B, 1
|
166 |
+
|
167 |
+
inputs_embeds = torch.cat((bos_embed, inputs_embeds), 1) # B, (1+T), D
|
168 |
+
inputs_embeds = torch.cat((inputs_embeds, eos_embed), 1) # B, (1+T+1), D
|
169 |
+
attention_mask = torch.cat((bos_eos_mask, attention_mask), 1) # B, (1+T)
|
170 |
+
attention_mask = torch.cat((attention_mask, bos_eos_mask), 1) # B, (1+T+1)
|
171 |
+
if target is not None:
|
172 |
+
target = torch.cat((target, bos_eos_target), 1) # B, (T+2), D
|
173 |
+
|
174 |
+
return inputs_embeds, attention_mask, target
|
175 |
+
|
176 |
+
|
177 |
+
def init_model(configs):
|
178 |
+
if configs["cmvn_file"] is not None:
|
179 |
+
mean, istd = load_cmvn(configs["cmvn_file"], configs["is_json_cmvn"])
|
180 |
+
global_cmvn = GlobalCMVN(torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
|
181 |
+
else:
|
182 |
+
global_cmvn = None
|
183 |
+
|
184 |
+
input_dim = configs["input_dim"]
|
185 |
+
|
186 |
+
encoder = whaleEncoder(input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"])
|
187 |
+
model = audioEncoder(encoder=encoder, **configs["model_conf"])
|
188 |
+
processor = audioEncoderProcessor(dataset_conf=configs["dataset_conf"])
|
189 |
+
|
190 |
+
model.audio_processor = processor
|
191 |
+
|
192 |
+
return model
|
vita/model/multimodal_encoder/whale/module/component/mamba.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Encoder self-attention layer definition."""
|
2 |
+
|
3 |
+
import math
|
4 |
+
import pdb
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, strtobool
|
13 |
+
|
14 |
+
try:
|
15 |
+
from mamba_ssm.modules.mamba_simple import Mamba, Block
|
16 |
+
from mamba_ssm.models.mixer_seq_simple import _init_weights
|
17 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm
|
18 |
+
except ImportError:
|
19 |
+
print("Please install mamba_ssm to use MambaSSM component.")
|
20 |
+
|
21 |
+
|
22 |
+
class MambaBlock(nn.Module):
|
23 |
+
def __init__(self, in_channels, n_layer=1, d_state=16, d_conv=4, expand=4, bidirectional=False):
|
24 |
+
super(MambaBlock, self).__init__()
|
25 |
+
self.forward_blocks = nn.ModuleList([])
|
26 |
+
self.forward_norm_f = RMSNorm(in_channels, eps=1e-5)
|
27 |
+
for i in range(n_layer):
|
28 |
+
self.forward_blocks.append(
|
29 |
+
Block(
|
30 |
+
in_channels,
|
31 |
+
mixer_cls=partial(
|
32 |
+
Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
|
33 |
+
),
|
34 |
+
norm_cls=partial(RMSNorm, eps=1e-5),
|
35 |
+
fused_add_norm=True,
|
36 |
+
residual_in_fp32=True,
|
37 |
+
)
|
38 |
+
)
|
39 |
+
if bidirectional:
|
40 |
+
self.backward_blocks = nn.ModuleList([])
|
41 |
+
for i in range(n_layer):
|
42 |
+
self.backward_blocks.append(
|
43 |
+
Block(
|
44 |
+
in_channels,
|
45 |
+
mixer_cls=partial(
|
46 |
+
Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
|
47 |
+
),
|
48 |
+
norm_cls=partial(RMSNorm, eps=1e-5),
|
49 |
+
fused_add_norm=True,
|
50 |
+
residual_in_fp32=True,
|
51 |
+
)
|
52 |
+
)
|
53 |
+
self.backward_norm_f = RMSNorm(in_channels, eps=1e-5)
|
54 |
+
else:
|
55 |
+
self.backward_blocks = None
|
56 |
+
|
57 |
+
self.apply(partial(_init_weights, n_layer=n_layer))
|
58 |
+
|
59 |
+
def forward(self, input):
|
60 |
+
for_residual = None
|
61 |
+
forward_f = input.clone()
|
62 |
+
for block in self.forward_blocks:
|
63 |
+
forward_f, for_residual = block(forward_f, for_residual, inference_params=None)
|
64 |
+
residual = (forward_f + for_residual) if for_residual is not None else forward_f
|
65 |
+
residual = self.forward_norm_f(residual)
|
66 |
+
|
67 |
+
if self.backward_blocks is not None:
|
68 |
+
back_residual = None
|
69 |
+
backward_f = torch.flip(input, [1])
|
70 |
+
for block in self.backward_blocks:
|
71 |
+
backward_f, back_residual = block(backward_f, back_residual, inference_params=None)
|
72 |
+
back_residual = (
|
73 |
+
(backward_f + back_residual) if back_residual is not None else backward_f
|
74 |
+
)
|
75 |
+
|
76 |
+
back_residual = torch.flip(back_residual, [1])
|
77 |
+
back_residual = self.backward_norm_f(back_residual)
|
78 |
+
residual = torch.cat([residual, back_residual], -1)
|
79 |
+
|
80 |
+
return residual
|
81 |
+
|
82 |
+
|
83 |
+
class MambaSSM(torch.nn.Module):
|
84 |
+
@staticmethod
|
85 |
+
def add_arguments(group):
|
86 |
+
"""Add TDNN common arguments."""
|
87 |
+
group.add_argument(
|
88 |
+
"--mamba-num-layers", default=4, type=int, help="Output dim of MambaSSM."
|
89 |
+
)
|
90 |
+
group.add_argument(
|
91 |
+
"--mamba-input-dim", default=256, type=int, help="Input dim of MambaSSM."
|
92 |
+
)
|
93 |
+
group.add_argument(
|
94 |
+
"--mamba-output-dim", default=256, type=int, help="Output dim of MambaSSM."
|
95 |
+
)
|
96 |
+
group.add_argument("--mamba-d-state", default=16, type=int, help="d-state of MambaSSM.")
|
97 |
+
group.add_argument("--mamba-d-conv", default=4, type=int, help="d-conv of MambaSSM.")
|
98 |
+
group.add_argument("--mamba-expand", default=4, type=int, help="expand of MambaSSM.")
|
99 |
+
return group
|
100 |
+
|
101 |
+
def __init__(self, args):
|
102 |
+
"""Construct an Encoder object."""
|
103 |
+
super(MambaSSM, self).__init__()
|
104 |
+
self.mamb_num_layers = args.mamba_num_layers
|
105 |
+
self.mamba_input_dim = args.mamba_input_dim
|
106 |
+
self.mamba_output_dim = args.mamba_output_dim
|
107 |
+
self.mamba_d_state = args.mamba_d_state
|
108 |
+
self.mamba_d_conv = args.mamba_d_conv
|
109 |
+
self.mamba_expand = args.mamba_expand
|
110 |
+
|
111 |
+
self.mamba = MambaBlock(
|
112 |
+
self.mamba_input_dim,
|
113 |
+
self.mamb_num_layers,
|
114 |
+
self.mamba_d_state,
|
115 |
+
self.mamba_d_conv,
|
116 |
+
self.mamba_expand,
|
117 |
+
)
|
118 |
+
|
119 |
+
@torch.jit.unused
|
120 |
+
def forward(self, xs, ilens=None, masks=None):
|
121 |
+
"""Embed positions in tensor.
|
122 |
+
|
123 |
+
:param torch.Tensor xs: input tensor
|
124 |
+
:param torch.Tensor masks: input mask
|
125 |
+
:return: position embedded tensor and mask
|
126 |
+
:rtype Tuple[torch.Tensor, torch.Tensor]:
|
127 |
+
"""
|
128 |
+
|
129 |
+
xs_out = self.mamba(xs)
|
130 |
+
|
131 |
+
return xs_out.to(xs.dtype), ilens, masks
|
vita/model/multimodal_encoder/whale/module/component/subsampling.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
|
5 |
+
class BaseSubsampling(torch.nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
super().__init__()
|
8 |
+
self.subsampling_rate = 1
|
9 |
+
self.right_context = 0
|
10 |
+
|
11 |
+
def position_encoding(self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
|
12 |
+
return self.pos_enc.position_encoding(offset, size)
|
13 |
+
|
14 |
+
|
15 |
+
class Conv2dSubsampling4(BaseSubsampling):
|
16 |
+
"""Convolutional 2D subsampling (to 1/4 length).
|
17 |
+
|
18 |
+
Args:
|
19 |
+
idim (int): Input dimension.
|
20 |
+
odim (int): Output dimension.
|
21 |
+
dropout_rate (float): Dropout rate.
|
22 |
+
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float):
|
26 |
+
"""Construct an Conv2dSubsampling4 object."""
|
27 |
+
super().__init__()
|
28 |
+
self.conv = torch.nn.Sequential(
|
29 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
30 |
+
torch.nn.ReLU(),
|
31 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
32 |
+
torch.nn.ReLU(),
|
33 |
+
)
|
34 |
+
self.out = torch.nn.Sequential(torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
|
35 |
+
self.right_context = 6
|
36 |
+
self.subsampling_rate = 4
|
37 |
+
|
38 |
+
def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
39 |
+
x = x.unsqueeze(1) # (b, c=1, t, f)
|
40 |
+
x = self.conv(x)
|
41 |
+
b, c, t, f = x.size()
|
42 |
+
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
43 |
+
return x, x_mask[:, :, 2::2][:, :, 2::2]
|
44 |
+
|
45 |
+
|
46 |
+
class Subsampling(torch.nn.Module):
|
47 |
+
@staticmethod
|
48 |
+
def add_arguments(group):
|
49 |
+
"""Add Subsampling common arguments."""
|
50 |
+
group.add_argument("--subsampling-rate", default=4, type=int)
|
51 |
+
group.add_argument("--subsampling-input-dim", default=256, type=int)
|
52 |
+
group.add_argument("--subsampling-output-dim", default=256, type=int)
|
53 |
+
group.add_argument("--subsampling-dropout-rate", default=0.1, type=float)
|
54 |
+
|
55 |
+
return group
|
56 |
+
|
57 |
+
def __init__(self, args):
|
58 |
+
super().__init__()
|
59 |
+
self.subsampling_rate = args.subsampling_rate
|
60 |
+
self.subsampling_input_dim = args.subsampling_input_dim
|
61 |
+
self.subsampling_output_dim = args.subsampling_output_dim
|
62 |
+
self.subsampling_dropout_rate = args.subsampling_dropout_rate
|
63 |
+
|
64 |
+
if self.subsampling_rate == 4:
|
65 |
+
self.core = Conv2dSubsampling4(
|
66 |
+
self.subsampling_input_dim,
|
67 |
+
self.subsampling_output_dim,
|
68 |
+
self.subsampling_dropout_rate,
|
69 |
+
)
|
70 |
+
|
71 |
+
def forward(self, xs, ilens, masks):
|
72 |
+
xs, masks = self.core(xs, masks)
|
73 |
+
ilens = masks.squeeze(1).sum(1)
|
74 |
+
return xs, ilens, masks
|
vita/model/multimodal_encoder/whale/module/component/transformer.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Encoder self-attention layer definition."""
|
2 |
+
|
3 |
+
import math
|
4 |
+
import pdb
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from vita.model.multimodal_encoder.whale.module.layer.attention import (
|
12 |
+
Conv1dLinear,
|
13 |
+
MultiHeadedAttention,
|
14 |
+
MultiLayeredConv1d,
|
15 |
+
PositionalEncoding,
|
16 |
+
PositionwiseFeedForward,
|
17 |
+
RelPositionalEncoding,
|
18 |
+
)
|
19 |
+
|
20 |
+
# from vita.model.multimodal_encoder.whale.module.component.utils import *
|
21 |
+
from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, add_optional_chunk_mask, strtobool
|
22 |
+
|
23 |
+
|
24 |
+
def repeat(N, fn):
|
25 |
+
"""Repeat module N times.
|
26 |
+
|
27 |
+
:param int N: repeat time
|
28 |
+
:param function fn: function to generate module
|
29 |
+
:return: repeated modules
|
30 |
+
:rtype: MultiSequential
|
31 |
+
"""
|
32 |
+
return MultiSequential(*[fn(n) for n in range(N)])
|
33 |
+
|
34 |
+
|
35 |
+
class MultiSequential(torch.nn.Sequential):
|
36 |
+
"""Multi-input multi-output torch.nn.Sequential."""
|
37 |
+
|
38 |
+
def forward(self, x, masks, pos_emb):
|
39 |
+
|
40 |
+
"""Repeat."""
|
41 |
+
for m in self:
|
42 |
+
x, masks, pos_emb = m(x, masks, pos_emb)
|
43 |
+
return x, masks, pos_emb
|
44 |
+
|
45 |
+
@torch.jit.export
|
46 |
+
def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
|
47 |
+
# type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
|
48 |
+
"""Repeat."""
|
49 |
+
for m in self:
|
50 |
+
x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
|
51 |
+
x, pos_emb, buffer, buffer_index, buffer_out
|
52 |
+
)
|
53 |
+
return x, pos_emb, buffer, buffer_index, buffer_out
|
54 |
+
|
55 |
+
@torch.jit.export
|
56 |
+
def infer_hidden(self, x, pos_emb, buffer, buffer_index, buffer_out, hidden_out):
|
57 |
+
# type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
|
58 |
+
"""Repeat."""
|
59 |
+
for m in self:
|
60 |
+
x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
|
61 |
+
x, pos_emb, buffer, buffer_index, buffer_out
|
62 |
+
)
|
63 |
+
hidden_out.append(x)
|
64 |
+
return x, pos_emb, buffer, buffer_index, buffer_out, hidden_out
|
65 |
+
|
66 |
+
|
67 |
+
class TransformerLayer(nn.Module):
|
68 |
+
"""Transformer layer module.
|
69 |
+
|
70 |
+
:param int size: input dim
|
71 |
+
:param self_attn: self attention module
|
72 |
+
:param feed_forward: feed forward module
|
73 |
+
:param float dropout_rate: dropout rate
|
74 |
+
:param bool normalize_before: whether to use layer_norm before the first block
|
75 |
+
:param bool concat_after: whether to concat attention layer's input and output
|
76 |
+
if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
|
77 |
+
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
78 |
+
|
79 |
+
"""
|
80 |
+
|
81 |
+
def __init__(
|
82 |
+
self, size, self_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False
|
83 |
+
):
|
84 |
+
"""Construct an TransformerLayer object."""
|
85 |
+
super(TransformerLayer, self).__init__()
|
86 |
+
self.self_attn = self_attn
|
87 |
+
self.feed_forward = feed_forward
|
88 |
+
self.norm1 = torch.nn.LayerNorm(size)
|
89 |
+
self.norm2 = torch.nn.LayerNorm(size)
|
90 |
+
self.dropout = nn.Dropout(dropout_rate)
|
91 |
+
self.size = size
|
92 |
+
self.normalize_before = normalize_before
|
93 |
+
self.concat_after = concat_after
|
94 |
+
if self.concat_after:
|
95 |
+
self.concat_linear = nn.Linear(size + size, size)
|
96 |
+
else:
|
97 |
+
self.concat_linear = nn.Identity()
|
98 |
+
|
99 |
+
@torch.jit.unused
|
100 |
+
def forward(self, x, mask, pos_emb):
|
101 |
+
"""Compute encoded features.
|
102 |
+
|
103 |
+
:param torch.Tensor x: encoded source features (batch, max_time_in, size)
|
104 |
+
:param torch.Tensor mask: mask for x (batch, max_time_in)
|
105 |
+
:rtype: Tuple[torch.Tensor, torch.Tensor]
|
106 |
+
"""
|
107 |
+
residual = x
|
108 |
+
if self.normalize_before:
|
109 |
+
x = self.norm1(x)
|
110 |
+
if self.concat_after:
|
111 |
+
x_concat = torch.cat((x, self.self_attn(x, x, x, mask, pos_emb)), dim=-1)
|
112 |
+
x = residual + self.concat_linear(x_concat)
|
113 |
+
else:
|
114 |
+
x = residual + self.dropout(self.self_attn(x, x, x, mask, pos_emb))
|
115 |
+
if not self.normalize_before:
|
116 |
+
x = self.norm1(x)
|
117 |
+
|
118 |
+
residual = x
|
119 |
+
if self.normalize_before:
|
120 |
+
x = self.norm2(x)
|
121 |
+
x = residual + self.dropout(self.feed_forward(x))
|
122 |
+
if not self.normalize_before:
|
123 |
+
x = self.norm2(x)
|
124 |
+
|
125 |
+
return x, mask, pos_emb
|
126 |
+
|
127 |
+
@torch.jit.export
|
128 |
+
def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
|
129 |
+
# type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
|
130 |
+
residual = x.clone()
|
131 |
+
if self.normalize_before:
|
132 |
+
x = self.norm1(x)
|
133 |
+
if self.concat_after:
|
134 |
+
x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
|
135 |
+
x, x, x, pos_emb, buffer, buffer_index, buffer_out
|
136 |
+
)
|
137 |
+
x_concat = torch.cat((x, x_att), dim=-1)
|
138 |
+
x = residual + self.concat_linear(x_concat)
|
139 |
+
else:
|
140 |
+
x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
|
141 |
+
x, x, x, pos_emb, buffer, buffer_index, buffer_out
|
142 |
+
)
|
143 |
+
x = residual + x_att
|
144 |
+
if not self.normalize_before:
|
145 |
+
x = self.norm1(x)
|
146 |
+
|
147 |
+
residual = x.clone()
|
148 |
+
if self.normalize_before:
|
149 |
+
x = self.norm2(x)
|
150 |
+
x_feed, buffer, buffer_index, buffer_out = self.feed_forward.infer(
|
151 |
+
x, buffer, buffer_index, buffer_out
|
152 |
+
)
|
153 |
+
x = residual + x_feed
|
154 |
+
if not self.normalize_before:
|
155 |
+
x = self.norm2(x)
|
156 |
+
|
157 |
+
return x, pos_emb, buffer, buffer_index, buffer_out
|
158 |
+
|
159 |
+
|
160 |
+
class Transformer(torch.nn.Module):
|
161 |
+
@staticmethod
|
162 |
+
def add_arguments(group):
|
163 |
+
"""Add TDNN common arguments."""
|
164 |
+
group.add_argument(
|
165 |
+
"--transformer-input-dim", default=256, type=int, help="Input dim of Transformer."
|
166 |
+
)
|
167 |
+
group.add_argument(
|
168 |
+
"--transformer-output-dim", default=4, type=int, help="Output dim of Transformer."
|
169 |
+
)
|
170 |
+
group.add_argument(
|
171 |
+
"--transformer-attention-dim", default=256, type=int, help="Dimention of attention."
|
172 |
+
)
|
173 |
+
group.add_argument(
|
174 |
+
"--transformer-attention-heads",
|
175 |
+
default=4,
|
176 |
+
type=int,
|
177 |
+
help="The number of heads of multi head attention.",
|
178 |
+
)
|
179 |
+
group.add_argument(
|
180 |
+
"--transformer-linear-units",
|
181 |
+
default=1024,
|
182 |
+
type=int,
|
183 |
+
help="The number of units of position-wise feed forward.",
|
184 |
+
)
|
185 |
+
group.add_argument(
|
186 |
+
"--transformer-num-blocks", default=6, type=int, help="The number of attention blocks."
|
187 |
+
)
|
188 |
+
group.add_argument(
|
189 |
+
"--transformer-dropout-rate",
|
190 |
+
default=0.1,
|
191 |
+
type=float,
|
192 |
+
help="Dropout rate in Transformer.",
|
193 |
+
)
|
194 |
+
group.add_argument(
|
195 |
+
"--transformer-attention-dropout-rate",
|
196 |
+
default=0.0,
|
197 |
+
type=float,
|
198 |
+
help="Dropout rate in attention.",
|
199 |
+
)
|
200 |
+
group.add_argument(
|
201 |
+
"--transformer-positional-dropout-rate",
|
202 |
+
default=0.1,
|
203 |
+
type=float,
|
204 |
+
help="Dropout rate after adding positional encoding.",
|
205 |
+
)
|
206 |
+
group.add_argument(
|
207 |
+
"--transformer-input-layer", default="linear", type=str, help="Type of input layer"
|
208 |
+
)
|
209 |
+
group.add_argument("--transformer-pos-enc-class", default="abs-enc", type=str, help="")
|
210 |
+
group.add_argument(
|
211 |
+
"--transformer-normalize-before",
|
212 |
+
default=True,
|
213 |
+
type=strtobool,
|
214 |
+
help="Whether to use layer-norm before the first block.",
|
215 |
+
)
|
216 |
+
group.add_argument(
|
217 |
+
"--transformer-concat-after",
|
218 |
+
default=False,
|
219 |
+
type=strtobool,
|
220 |
+
help="Whether to concat attention layer's input and output.",
|
221 |
+
)
|
222 |
+
group.add_argument(
|
223 |
+
"--transformer-positionwise-layer-type",
|
224 |
+
default="linear",
|
225 |
+
type=str,
|
226 |
+
help="Linear of conv1d.",
|
227 |
+
)
|
228 |
+
group.add_argument(
|
229 |
+
"--transformer-positionwise-conv-kernel_size",
|
230 |
+
default=1,
|
231 |
+
type=int,
|
232 |
+
help="Kernel size of positionwise conv1d layer.",
|
233 |
+
)
|
234 |
+
group.add_argument("--transformer-chunk_size", default=-1, type=int, help="")
|
235 |
+
group.add_argument("--transformer-left_chunks", default=-1, type=int, help="")
|
236 |
+
group.add_argument("--transformer-dynamic-chunks", default=True, type=strtobool, help="")
|
237 |
+
return group
|
238 |
+
|
239 |
+
def __init__(
|
240 |
+
self,
|
241 |
+
args,
|
242 |
+
input_dim=None,
|
243 |
+
output_dim=None,
|
244 |
+
attention_dim=None,
|
245 |
+
attention_heads=None,
|
246 |
+
linear_units=None,
|
247 |
+
num_blocks=None,
|
248 |
+
dropout_rate=None,
|
249 |
+
positional_dropout_rate=None,
|
250 |
+
attention_dropout_rate=None,
|
251 |
+
input_layer=None,
|
252 |
+
pos_enc_class=None,
|
253 |
+
normalize_before=None,
|
254 |
+
concat_after=None,
|
255 |
+
positionwise_layer_type=None,
|
256 |
+
positionwise_conv_kernel_size=None,
|
257 |
+
chunk_size=None,
|
258 |
+
left_chunks=None,
|
259 |
+
):
|
260 |
+
"""Construct an Encoder object."""
|
261 |
+
super(Transformer, self).__init__()
|
262 |
+
if args is None:
|
263 |
+
self.input_dim = input_dim
|
264 |
+
self.output_dim = output_dim
|
265 |
+
self.attention_dim = attention_dim
|
266 |
+
self.attention_heads = attention_heads
|
267 |
+
self.linear_units = linear_units
|
268 |
+
self.num_blocks = num_blocks
|
269 |
+
self.dropout_rate = dropout_rate
|
270 |
+
self.positional_dropout_rate = positional_dropout_rate
|
271 |
+
self.attention_dropout_rate = attention_dropout_rate
|
272 |
+
self.input_layer = input_layer
|
273 |
+
self.pos_enc_class = pos_enc_class
|
274 |
+
self.normalize_before = normalize_before
|
275 |
+
self.concat_after = concat_after
|
276 |
+
self.positionwise_layer_type = positionwise_layer_type
|
277 |
+
self.positionwise_conv_kernel_size = positionwise_conv_kernel_size
|
278 |
+
self.chunk_size = chunk_size
|
279 |
+
self.left_chunks = left_chunks
|
280 |
+
else:
|
281 |
+
self.input_dim = args.transformer_input_dim
|
282 |
+
self.output_dim = args.transformer_output_dim
|
283 |
+
self.attention_dim = args.transformer_attention_dim
|
284 |
+
self.attention_heads = args.transformer_attention_heads
|
285 |
+
self.linear_units = args.transformer_linear_units
|
286 |
+
self.num_blocks = args.transformer_num_blocks
|
287 |
+
self.dropout_rate = args.transformer_dropout_rate
|
288 |
+
self.positional_dropout_rate = args.transformer_positional_dropout_rate
|
289 |
+
self.attention_dropout_rate = args.transformer_attention_dropout_rate
|
290 |
+
self.input_layer = args.transformer_input_layer
|
291 |
+
self.pos_enc_class = args.transformer_pos_enc_class
|
292 |
+
self.normalize_before = args.transformer_normalize_before
|
293 |
+
self.concat_after = args.transformer_concat_after
|
294 |
+
self.positionwise_layer_type = args.transformer_positionwise_layer_type
|
295 |
+
self.positionwise_conv_kernel_size = args.transformer_positionwise_conv_kernel_size
|
296 |
+
self.chunk_size = args.transformer_chunk_size
|
297 |
+
self.left_chunks = args.transformer_left_chunks
|
298 |
+
self.transformer_dynamic_chunks = args.transformer_dynamic_chunks
|
299 |
+
|
300 |
+
if self.pos_enc_class == "abs-enc":
|
301 |
+
pos_enc_args = (self.attention_dim, self.positional_dropout_rate)
|
302 |
+
pos_enc_class = PositionalEncoding
|
303 |
+
elif self.pos_enc_class == "rel-enc":
|
304 |
+
pos_enc_args = (
|
305 |
+
self.attention_dim,
|
306 |
+
self.positional_dropout_rate,
|
307 |
+
self.chunk_size,
|
308 |
+
self.left_chunks,
|
309 |
+
)
|
310 |
+
pos_enc_class = RelPositionalEncoding
|
311 |
+
|
312 |
+
if self.input_layer == "linear":
|
313 |
+
self.embed = torch.nn.Sequential(
|
314 |
+
torch.nn.Linear(self.input_dim, self.attention_dim),
|
315 |
+
torch.nn.LayerNorm(self.attention_dim),
|
316 |
+
torch.nn.Dropout(self.dropout_rate),
|
317 |
+
torch.nn.ReLU(),
|
318 |
+
)
|
319 |
+
elif self.input_layer == "embed":
|
320 |
+
self.embed = torch.nn.Sequential(
|
321 |
+
torch.nn.Embedding(self.input_dim, self.attention_dim, padding_idx=IGNORE_ID)
|
322 |
+
)
|
323 |
+
elif self.input_layer == "none":
|
324 |
+
self.embed = torch.nn.Sequential(torch.nn.Identity())
|
325 |
+
else:
|
326 |
+
raise ValueError("unknown input_layer: " + self.input_layer)
|
327 |
+
self.pe = pos_enc_class(*pos_enc_args)
|
328 |
+
self.embed_layer_num = len(self.embed)
|
329 |
+
|
330 |
+
if self.positionwise_layer_type == "linear":
|
331 |
+
positionwise_layer = PositionwiseFeedForward
|
332 |
+
positionwise_layer_args = (self.attention_dim, self.linear_units, self.dropout_rate)
|
333 |
+
elif self.positionwise_layer_type == "conv1d":
|
334 |
+
positionwise_layer = MultiLayeredConv1d
|
335 |
+
positionwise_layer_args = (
|
336 |
+
self.attention_dim,
|
337 |
+
self.linear_units,
|
338 |
+
self.positionwise_conv_kernel_size,
|
339 |
+
self.dropout_rate,
|
340 |
+
)
|
341 |
+
elif self.positionwise_layer_type == "conv1d-linear":
|
342 |
+
positionwise_layer = Conv1dLinear
|
343 |
+
positionwise_layer_args = (
|
344 |
+
self.attention_dim,
|
345 |
+
self.linear_units,
|
346 |
+
self.positionwise_conv_kernel_size,
|
347 |
+
self.dropout_rate,
|
348 |
+
)
|
349 |
+
else:
|
350 |
+
raise NotImplementedError("Support only linear or conv1d.")
|
351 |
+
|
352 |
+
self.encoders = repeat(
|
353 |
+
self.num_blocks,
|
354 |
+
lambda lnum: TransformerLayer(
|
355 |
+
self.attention_dim,
|
356 |
+
MultiHeadedAttention(
|
357 |
+
self.attention_heads,
|
358 |
+
self.attention_dim,
|
359 |
+
self.attention_dropout_rate,
|
360 |
+
self.chunk_size,
|
361 |
+
self.left_chunks,
|
362 |
+
self.pos_enc_class,
|
363 |
+
),
|
364 |
+
positionwise_layer(*positionwise_layer_args),
|
365 |
+
self.dropout_rate,
|
366 |
+
self.normalize_before,
|
367 |
+
self.concat_after,
|
368 |
+
),
|
369 |
+
)
|
370 |
+
if self.normalize_before:
|
371 |
+
self.after_norm = torch.nn.LayerNorm(self.attention_dim)
|
372 |
+
|
373 |
+
@torch.jit.unused
|
374 |
+
def forward(self, xs, ilens=None, masks=None):
|
375 |
+
"""Embed positions in tensor.
|
376 |
+
|
377 |
+
:param torch.Tensor xs: input tensor
|
378 |
+
:param torch.Tensor masks: input mask
|
379 |
+
:return: position embedded tensor and mask
|
380 |
+
:rtype Tuple[torch.Tensor, torch.Tensor]:
|
381 |
+
"""
|
382 |
+
|
383 |
+
if self.transformer_dynamic_chunks == True: # and self.training:
|
384 |
+
chunk_masks = add_optional_chunk_mask(xs, masks, True, True, 0, 0, -1)
|
385 |
+
else:
|
386 |
+
chunk_masks = add_optional_chunk_mask(
|
387 |
+
xs, masks, False, False, self.chunk_size, self.chunk_size, self.left_chunks
|
388 |
+
).to(xs.device)
|
389 |
+
xs = self.embed(xs)
|
390 |
+
xs, pos_emb = self.pe(xs)
|
391 |
+
xs, chunk_masks, pos_emb = self.encoders(xs, chunk_masks, pos_emb)
|
392 |
+
if self.normalize_before:
|
393 |
+
xs = self.after_norm(xs)
|
394 |
+
return xs, ilens, masks
|
395 |
+
|
396 |
+
@torch.jit.export
|
397 |
+
def infer(self, xs, buffer, buffer_index, buffer_out):
|
398 |
+
xs = self.embed(xs)
|
399 |
+
|
400 |
+
# pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
|
401 |
+
# xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
|
402 |
+
# buffer_out.append(pe_index.reshape(-1).to(torch.float32))
|
403 |
+
# buffer_index = buffer_index + 1
|
404 |
+
xs, pos_emb, _ = self.pe.infer(xs, 0)
|
405 |
+
xs, pos_emb, buffer, buffer_index, buffer_out = self.encoders.infer(
|
406 |
+
xs, pos_emb, buffer, buffer_index, buffer_out
|
407 |
+
)
|
408 |
+
|
409 |
+
if self.normalize_before:
|
410 |
+
xs = self.after_norm(xs)
|
411 |
+
return xs, buffer, buffer_index, buffer_out
|
412 |
+
|
413 |
+
@torch.jit.export
|
414 |
+
def infer_hidden(self, xs, buffer, buffer_index, buffer_out, hidden_out):
|
415 |
+
xs = self.embed(xs)
|
416 |
+
|
417 |
+
# pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
|
418 |
+
# xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
|
419 |
+
# buffer_out.append(pe_index.reshape(-1).to(torch.float32))
|
420 |
+
# buffer_index = buffer_index + 1
|
421 |
+
xs, pos_emb, _ = self.pe.infer(xs, 0)
|
422 |
+
xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out = self.encoders.infer_hidden(
|
423 |
+
xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out
|
424 |
+
)
|
425 |
+
|
426 |
+
if self.normalize_before:
|
427 |
+
xs = self.after_norm(xs)
|
428 |
+
return xs, buffer, buffer_index, buffer_out, hidden_out
|
vita/model/multimodal_encoder/whale/module/encoder/encoder.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
from typing import Dict, Optional, Tuple
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import six
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from vita.model.multimodal_encoder.whale.module.component.mamba import MambaSSM
|
12 |
+
from vita.model.multimodal_encoder.whale.module.component.subsampling import Subsampling
|
13 |
+
from vita.model.multimodal_encoder.whale.module.component.transformer import Transformer
|
14 |
+
from vita.model.multimodal_encoder.whale.utils import make_pad_mask
|
15 |
+
|
16 |
+
|
17 |
+
def add_encoder_args(group):
|
18 |
+
"""Add Encoder common arguments."""
|
19 |
+
group.add_argument(
|
20 |
+
"--encoder-layer-config",
|
21 |
+
type=str,
|
22 |
+
default="tdnn-dtc",
|
23 |
+
help="Layer config of encoder. Format layername-layername-..., default(conv1d-fsmn-rnn)",
|
24 |
+
)
|
25 |
+
group.add_argument(
|
26 |
+
"--encoder-input-dim",
|
27 |
+
type=int,
|
28 |
+
default=256,
|
29 |
+
help="Input dim of encoder. Must equal to the input dim of the first Component (default=40)",
|
30 |
+
)
|
31 |
+
group.add_argument(
|
32 |
+
"--encoder-output-dim",
|
33 |
+
type=int,
|
34 |
+
default=256,
|
35 |
+
help="Output dim of encoder. Must enqual to the output dim of the last Component ! (default=256)",
|
36 |
+
)
|
37 |
+
# Add args of all kinds of components.
|
38 |
+
# If you add a new component, DO NOT forget to add args to add_component_args func.
|
39 |
+
group = Transformer.add_arguments(group)
|
40 |
+
group = Subsampling.add_arguments(group)
|
41 |
+
group = MambaSSM.add_arguments(group)
|
42 |
+
return group
|
43 |
+
|
44 |
+
|
45 |
+
def assign_args_from_dict(args, dict, prefix_key=None):
|
46 |
+
if prefix_key is not None:
|
47 |
+
dict = dict[prefix_key]
|
48 |
+
for k, v in dict.items():
|
49 |
+
k_args = k.replace("-", "_")
|
50 |
+
if hasattr(args, k_args):
|
51 |
+
setattr(args, k_args, dict[k])
|
52 |
+
return args
|
53 |
+
|
54 |
+
|
55 |
+
class whaleEncoder(torch.nn.Module):
|
56 |
+
def __init__(self, input_dim, overview_conf=None, para_conf=None, global_cmvn=None):
|
57 |
+
super(whaleEncoder, self).__init__()
|
58 |
+
|
59 |
+
parser = argparse.ArgumentParser()
|
60 |
+
add_encoder_args(parser)
|
61 |
+
args, _ = parser.parse_known_args()
|
62 |
+
|
63 |
+
assign_args_from_dict(args, overview_conf)
|
64 |
+
# assign_args_from_dict(args, para_conf)
|
65 |
+
|
66 |
+
self.config = args.encoder_layer_config.split("-")
|
67 |
+
encoder_input_dim = args.encoder_input_dim
|
68 |
+
encoder_output_dim = args.encoder_output_dim
|
69 |
+
prev_output_dim = encoder_input_dim
|
70 |
+
prev_component_name = "encoder"
|
71 |
+
self.enc = torch.nn.ModuleList([])
|
72 |
+
for name in self.config:
|
73 |
+
assign_args_from_dict(args, para_conf[name])
|
74 |
+
if len(name.split("_")) == 2:
|
75 |
+
name = name.split("_")[0]
|
76 |
+
elif len(name.split("_")) == 1:
|
77 |
+
name = name
|
78 |
+
else:
|
79 |
+
logging.error("WRONG CONFIG! {} is not valid".format("encoder", name))
|
80 |
+
sys.exit()
|
81 |
+
|
82 |
+
if name == "transformer":
|
83 |
+
self.enc.append(Transformer(args))
|
84 |
+
elif name == "subsampling":
|
85 |
+
self.enc.append(Subsampling(args))
|
86 |
+
elif name == "mamba":
|
87 |
+
self.enc.append(MambaSSM(args))
|
88 |
+
else:
|
89 |
+
print("{} is not supported now!".format(name))
|
90 |
+
return NotImplemented
|
91 |
+
component_input_dim = getattr(args, name + "_input_dim")
|
92 |
+
if component_input_dim != prev_output_dim:
|
93 |
+
# This is the first layer
|
94 |
+
logging.error(
|
95 |
+
"WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-input-dim ({})".format(
|
96 |
+
prev_component_name, prev_output_dim, name, component_input_dim
|
97 |
+
)
|
98 |
+
)
|
99 |
+
sys.exit()
|
100 |
+
prev_output_dim = getattr(args, name + "_output_dim")
|
101 |
+
prev_component_name = name
|
102 |
+
|
103 |
+
self.global_cmvn = global_cmvn
|
104 |
+
if prev_output_dim != encoder_output_dim:
|
105 |
+
logging.error(
|
106 |
+
"WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-output-dim ({}, the last component)".format(
|
107 |
+
"encoder", encoder_output_dim, name, prev_output_dim
|
108 |
+
)
|
109 |
+
)
|
110 |
+
sys.exit()
|
111 |
+
|
112 |
+
self._output_size = encoder_output_dim
|
113 |
+
|
114 |
+
num_params = sum(p.numel() for p in self.parameters())
|
115 |
+
print("the number of whale encoder params: {}M".format(num_params / 1024 / 1024))
|
116 |
+
|
117 |
+
def output_size(self) -> int:
|
118 |
+
return self._output_size
|
119 |
+
|
120 |
+
@torch.jit.unused
|
121 |
+
def forward(self, xs, ilens, decoding_chunk_size=None, num_decoding_left_chunks=None):
|
122 |
+
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[List[int]], Optional[Tensor]]
|
123 |
+
"""Encoder forward
|
124 |
+
|
125 |
+
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
|
126 |
+
:param torch.Tensor ilens: batch of lengths of input sequences (B)
|
127 |
+
:return: batch of hidden state sequences (B, Tmax, eprojs)
|
128 |
+
:rtype: torch.Tensor
|
129 |
+
"""
|
130 |
+
|
131 |
+
if decoding_chunk_size is not None and num_decoding_left_chunks is not None:
|
132 |
+
for layer in self.enc:
|
133 |
+
if hasattr(layer, "chunk_size"):
|
134 |
+
layer.chunk_size = decoding_chunk_size
|
135 |
+
if hasattr(layer, "left_chunks"):
|
136 |
+
layer.left_chunks = num_decoding_left_chunks
|
137 |
+
if hasattr(layer, "transformer_dynamic_chunks"):
|
138 |
+
layer.transformer_dynamic_chunks = False
|
139 |
+
|
140 |
+
assert (len(xs.shape)) == 3
|
141 |
+
T = xs.size(1)
|
142 |
+
masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T)
|
143 |
+
if self.global_cmvn is not None:
|
144 |
+
xs = self.global_cmvn(xs)
|
145 |
+
for module in self.enc:
|
146 |
+
xs, ilens, masks = module(xs, ilens, masks)
|
147 |
+
return xs, masks
|
148 |
+
|
149 |
+
@torch.jit.export
|
150 |
+
def infer(self, xs_pad, buffer, buffer_index, buffer_out):
|
151 |
+
if self.global_cmvn is not None:
|
152 |
+
xs = self.global_cmvn(xs)
|
153 |
+
for module in self.enc:
|
154 |
+
xs_pad, buffer, buffer_index, buffer_out = module.infer(
|
155 |
+
xs_pad, buffer, buffer_index, buffer_out
|
156 |
+
)
|
157 |
+
return xs_pad, buffer, buffer_index, buffer_out
|
158 |
+
|
159 |
+
@torch.jit.export
|
160 |
+
def infer_hidden(self, xs_pad, buffer, buffer_index, buffer_out, hidden_out):
|
161 |
+
if self.global_cmvn is not None:
|
162 |
+
xs = self.global_cmvn(xs)
|
163 |
+
for module in self.enc:
|
164 |
+
xs_pad, buffer, buffer_index, buffer_out, hidden_out = module.infer_hidden(
|
165 |
+
xs_pad, buffer, buffer_index, buffer_out, hidden_out
|
166 |
+
)
|
167 |
+
return xs_pad, buffer, buffer_index, buffer_out, hidden_out
|
168 |
+
|
169 |
+
@torch.jit.ignore(drop=True)
|
170 |
+
def get_extra_loss(self) -> Dict[str, torch.Tensor]:
|
171 |
+
return None
|
vita/model/multimodal_encoder/whale/module/layer/attention.py
ADDED
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import pdb
|
3 |
+
|
4 |
+
import numpy
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
|
9 |
+
class PositionalEncoding(torch.nn.Module):
|
10 |
+
"""Positional encoding.
|
11 |
+
:param int d_model: embedding dim
|
12 |
+
:param float dropout_rate: dropout rate
|
13 |
+
:param int max_len: maximum input length
|
14 |
+
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
|
15 |
+
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self, d_model: int, dropout_rate: float, max_len: int = 1500, reverse: bool = False
|
20 |
+
):
|
21 |
+
"""Construct an PositionalEncoding object."""
|
22 |
+
super().__init__()
|
23 |
+
self.d_model = d_model
|
24 |
+
self.xscale = math.sqrt(self.d_model)
|
25 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
26 |
+
self.max_len = max_len
|
27 |
+
|
28 |
+
self.pe = torch.zeros(self.max_len, self.d_model)
|
29 |
+
position = torch.arange(0, self.max_len, dtype=torch.float32).unsqueeze(1)
|
30 |
+
div_term = torch.exp(
|
31 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
32 |
+
* -(math.log(10000.0) / self.d_model)
|
33 |
+
)
|
34 |
+
self.pe[:, 0::2] = torch.sin(position * div_term)
|
35 |
+
self.pe[:, 1::2] = torch.cos(position * div_term)
|
36 |
+
self.pe = self.pe.unsqueeze(0)
|
37 |
+
|
38 |
+
def forward(self, x: torch.Tensor, offset: int = 0):
|
39 |
+
"""Add positional encoding.
|
40 |
+
Args:
|
41 |
+
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
42 |
+
offset (int): position offset
|
43 |
+
Returns:
|
44 |
+
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
45 |
+
torch.Tensor: for compatibility to RelPositionalEncoding
|
46 |
+
"""
|
47 |
+
assert offset + x.size(1) < self.max_len
|
48 |
+
self.pe = self.pe.to(x.device)
|
49 |
+
pos_emb = self.pe[:, offset : offset + x.size(1)]
|
50 |
+
x = x * self.xscale + pos_emb
|
51 |
+
return self.dropout(x), self.dropout(pos_emb)
|
52 |
+
|
53 |
+
def position_encoding(self, offset: int, size: int):
|
54 |
+
"""For getting encoding in a streaming fashion
|
55 |
+
Attention!!!!!
|
56 |
+
we apply dropout only once at the whole utterance level in a none
|
57 |
+
streaming way, but will call this function several times with
|
58 |
+
increasing input size in a streaming scenario, so the dropout will
|
59 |
+
be applied several times.
|
60 |
+
Args:
|
61 |
+
offset (int): start offset
|
62 |
+
size (int): requried size of position encoding
|
63 |
+
Returns:
|
64 |
+
torch.Tensor: Corresponding encoding
|
65 |
+
"""
|
66 |
+
assert offset + size < self.max_len
|
67 |
+
return self.dropout(self.pe[:, offset : offset + size])
|
68 |
+
|
69 |
+
|
70 |
+
class RelPositionalEncoding(PositionalEncoding):
|
71 |
+
"""Relative positional encoding module.
|
72 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
73 |
+
Args:
|
74 |
+
d_model (int): Embedding dimension.
|
75 |
+
dropout_rate (float): Dropout rate.
|
76 |
+
max_len (int): Maximum input length.
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
d_model: int,
|
82 |
+
dropout_rate: float,
|
83 |
+
chunk_size: int,
|
84 |
+
left_chunks: int,
|
85 |
+
max_len: int = 5000,
|
86 |
+
):
|
87 |
+
"""Initialize class."""
|
88 |
+
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
89 |
+
self.chunk_size = chunk_size
|
90 |
+
self.left_chunks = left_chunks
|
91 |
+
self.full_chunk_size = (self.left_chunks + 1) * self.chunk_size
|
92 |
+
|
93 |
+
self.div_term = torch.exp(
|
94 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
95 |
+
* -(math.log(10000.0) / self.d_model)
|
96 |
+
)
|
97 |
+
self.max_len = self.chunk_size * (max_len // self.chunk_size) - self.full_chunk_size
|
98 |
+
|
99 |
+
@torch.jit.export
|
100 |
+
def forward(self, x: torch.Tensor, offset: int = 0):
|
101 |
+
"""Compute positional encoding.
|
102 |
+
Args:
|
103 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
104 |
+
Returns:
|
105 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
106 |
+
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
107 |
+
"""
|
108 |
+
self.pe = self.pe.to(x.device)
|
109 |
+
x = x * self.xscale
|
110 |
+
pos_emb = self.pe[:, offset : offset + x.size(1)]
|
111 |
+
return self.dropout(x), self.dropout(pos_emb)
|
112 |
+
|
113 |
+
@torch.jit.export
|
114 |
+
def infer(self, xs, pe_index):
|
115 |
+
# type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
|
116 |
+
pe_index = pe_index % self.max_len
|
117 |
+
xs = xs * self.xscale
|
118 |
+
|
119 |
+
pe = torch.zeros(self.full_chunk_size, self.d_model)
|
120 |
+
position = torch.arange(
|
121 |
+
pe_index, pe_index + self.full_chunk_size, dtype=torch.float32
|
122 |
+
).unsqueeze(1)
|
123 |
+
pe[:, 0::2] = torch.sin(position * self.div_term)
|
124 |
+
pe[:, 1::2] = torch.cos(position * self.div_term)
|
125 |
+
pos_emb = pe.unsqueeze(0)
|
126 |
+
|
127 |
+
pe_index = pe_index + self.chunk_size
|
128 |
+
return xs, pos_emb, pe_index
|
129 |
+
|
130 |
+
|
131 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
132 |
+
"""Positionwise feed forward layer.
|
133 |
+
:param int idim: input dimenstion
|
134 |
+
:param int hidden_units: number of hidden units
|
135 |
+
:param float dropout_rate: dropout rate
|
136 |
+
"""
|
137 |
+
|
138 |
+
def __init__(self, idim, hidden_units, dropout_rate):
|
139 |
+
"""Construct an PositionwiseFeedForward object."""
|
140 |
+
super(PositionwiseFeedForward, self).__init__()
|
141 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
142 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
143 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
"""Forward funciton."""
|
147 |
+
return self.w_2(self.dropout(torch.relu(self.w_1(x))))
|
148 |
+
|
149 |
+
@torch.jit.export
|
150 |
+
def infer(self, xs, buffer, buffer_index, buffer_out):
|
151 |
+
# type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
|
152 |
+
return self.w_2(torch.relu(self.w_1(xs))), buffer, buffer_index, buffer_out
|
153 |
+
|
154 |
+
|
155 |
+
class MultiLayeredConv1d(torch.nn.Module):
|
156 |
+
"""Multi-layered conv1d for Transformer block.
|
157 |
+
|
158 |
+
This is a module of multi-leyered conv1d designed
|
159 |
+
to replace positionwise feed-forward network
|
160 |
+
in Transformer block, which is introduced in
|
161 |
+
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
162 |
+
|
163 |
+
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
164 |
+
https://arxiv.org/pdf/1905.09263.pdf
|
165 |
+
|
166 |
+
"""
|
167 |
+
|
168 |
+
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
169 |
+
"""Initialize MultiLayeredConv1d module.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
in_chans (int): Number of input channels.
|
173 |
+
hidden_chans (int): Number of hidden channels.
|
174 |
+
kernel_size (int): Kernel size of conv1d.
|
175 |
+
dropout_rate (float): Dropout rate.
|
176 |
+
|
177 |
+
"""
|
178 |
+
super(MultiLayeredConv1d, self).__init__()
|
179 |
+
self.w_1 = torch.nn.Conv1d(
|
180 |
+
in_chans,
|
181 |
+
hidden_chans,
|
182 |
+
kernel_size,
|
183 |
+
stride=1,
|
184 |
+
padding=(kernel_size - 1) // 2,
|
185 |
+
)
|
186 |
+
self.w_2 = torch.nn.Conv1d(
|
187 |
+
hidden_chans,
|
188 |
+
in_chans,
|
189 |
+
kernel_size,
|
190 |
+
stride=1,
|
191 |
+
padding=(kernel_size - 1) // 2,
|
192 |
+
)
|
193 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
194 |
+
|
195 |
+
@torch.jit.unused
|
196 |
+
def forward(self, x):
|
197 |
+
"""Calculate forward propagation.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
Tensor: Batch of output tensors (B, ..., hidden_chans).
|
204 |
+
|
205 |
+
"""
|
206 |
+
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
207 |
+
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
|
208 |
+
|
209 |
+
|
210 |
+
class Conv1dLinear(torch.nn.Module):
|
211 |
+
"""Conv1D + Linear for Transformer block.
|
212 |
+
|
213 |
+
A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
|
214 |
+
|
215 |
+
"""
|
216 |
+
|
217 |
+
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
218 |
+
"""Initialize Conv1dLinear module.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
in_chans (int): Number of input channels.
|
222 |
+
hidden_chans (int): Number of hidden channels.
|
223 |
+
kernel_size (int): Kernel size of conv1d.
|
224 |
+
dropout_rate (float): Dropout rate.
|
225 |
+
|
226 |
+
"""
|
227 |
+
super(Conv1dLinear, self).__init__()
|
228 |
+
self.lorder = kernel_size - 1
|
229 |
+
self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
|
230 |
+
self.w_1 = torch.nn.Sequential(
|
231 |
+
torch.nn.Conv1d(in_chans, in_chans, kernel_size, stride=1, padding=0, groups=in_chans),
|
232 |
+
torch.nn.Conv1d(in_chans, hidden_chans, 1, padding=0),
|
233 |
+
)
|
234 |
+
self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
|
235 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
236 |
+
self.in_chans = in_chans
|
237 |
+
|
238 |
+
# cnn_buffer = 1, in_chans, self.lorder
|
239 |
+
self.buffer_size = 1 * self.in_chans * self.lorder
|
240 |
+
|
241 |
+
@torch.jit.unused
|
242 |
+
def forward(self, x):
|
243 |
+
"""Calculate forward propagation.
|
244 |
+
|
245 |
+
Args:
|
246 |
+
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
247 |
+
|
248 |
+
Returns:
|
249 |
+
Tensor: Batch of output tensors (B, ..., hidden_chans).
|
250 |
+
|
251 |
+
"""
|
252 |
+
x = torch.relu(self.w_1(self.left_padding(x.transpose(-1, 1)))).transpose(-1, 1)
|
253 |
+
return self.w_2(self.dropout(x))
|
254 |
+
|
255 |
+
@torch.jit.export
|
256 |
+
def infer(self, x, buffer, buffer_index, buffer_out):
|
257 |
+
# type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
|
258 |
+
x = x.transpose(-1, 1)
|
259 |
+
|
260 |
+
cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
|
261 |
+
[1, self.in_chans, self.lorder]
|
262 |
+
)
|
263 |
+
x = torch.cat([cnn_buffer, x], dim=2)
|
264 |
+
buffer_out.append(x[:, :, -self.lorder :].reshape(-1))
|
265 |
+
buffer_index = buffer_index + self.buffer_size
|
266 |
+
|
267 |
+
x = self.w_1(x)
|
268 |
+
x = torch.relu(x).transpose(-1, 1)
|
269 |
+
x = self.w_2(x)
|
270 |
+
return x, buffer, buffer_index, buffer_out
|
271 |
+
|
272 |
+
|
273 |
+
class MultiHeadedAttention(nn.Module):
|
274 |
+
"""Multi-Head Attention layer.
|
275 |
+
|
276 |
+
:param int n_head: the number of head s
|
277 |
+
:param int n_feat: the number of features
|
278 |
+
:param float dropout_rate: dropout rate
|
279 |
+
|
280 |
+
"""
|
281 |
+
|
282 |
+
def __init__(self, n_head, n_feat, dropout_rate, chunk_size, left_chunks, pos_enc_class):
|
283 |
+
"""Construct an MultiHeadedAttention object."""
|
284 |
+
super(MultiHeadedAttention, self).__init__()
|
285 |
+
assert n_feat % n_head == 0
|
286 |
+
# We assume d_v always equals d_k
|
287 |
+
self.d_k = n_feat // n_head
|
288 |
+
self.h = n_head
|
289 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
290 |
+
self.linear_k = nn.Linear(n_feat, n_feat)
|
291 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
292 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
293 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
294 |
+
# self.min_value = float(numpy.finfo(torch.tensor(0, dtype=torch.float16).numpy().dtype).min)
|
295 |
+
self.min_value = float(torch.finfo(torch.float16).min)
|
296 |
+
# chunk par
|
297 |
+
if chunk_size > 0 and left_chunks > 0: # for streaming mode
|
298 |
+
self.buffersize = chunk_size * (left_chunks)
|
299 |
+
self.left_chunk_size = chunk_size * left_chunks
|
300 |
+
else: # for non-streaming mode
|
301 |
+
self.buffersize = 1
|
302 |
+
self.left_chunk_size = 1
|
303 |
+
self.chunk_size = chunk_size
|
304 |
+
|
305 |
+
# encoding setup
|
306 |
+
if pos_enc_class == "rel-enc":
|
307 |
+
self.rel_enc = True
|
308 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
309 |
+
# these two learnable bias are used in matrix c and matrix d
|
310 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
311 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
312 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
313 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
314 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
315 |
+
else:
|
316 |
+
self.rel_enc = False
|
317 |
+
self.linear_pos = nn.Identity()
|
318 |
+
self.pos_bias_u = torch.tensor([0])
|
319 |
+
self.pos_bias_v = torch.tensor([0])
|
320 |
+
|
321 |
+
# buffer
|
322 |
+
# key_buffer = 1, self.h, self.buffersize, self.d_k
|
323 |
+
self.key_buffer_size = 1 * self.h * self.buffersize * self.d_k
|
324 |
+
# value_buffer = 1, self.h, self.buffersize, self.d_k
|
325 |
+
self.value_buffer_size = 1 * self.h * self.buffersize * self.d_k
|
326 |
+
if self.chunk_size > 0:
|
327 |
+
# buffer_mask_size = 1, self.h, self.chunk_size, self.buffersize
|
328 |
+
self.buffer_mask_size = 1 * self.h * self.chunk_size * self.buffersize
|
329 |
+
# self.buffer_mask = torch.ones([1, self.h, self.chunk_size, self.buffersize], dtype=torch.bool)
|
330 |
+
else:
|
331 |
+
self.buffer_mask = torch.ones([1, self.h, 1, 1], dtype=torch.bool)
|
332 |
+
|
333 |
+
@torch.jit.unused
|
334 |
+
def rel_shift(self, x, zero_triu: bool = False):
|
335 |
+
"""Compute relative positinal encoding.
|
336 |
+
Args:
|
337 |
+
x (torch.Tensor): Input tensor (batch, time, size).
|
338 |
+
zero_triu (bool): If true, return the lower triangular part of
|
339 |
+
the matrix.
|
340 |
+
Returns:
|
341 |
+
torch.Tensor: Output tensor.
|
342 |
+
"""
|
343 |
+
|
344 |
+
zero_pad = torch.zeros(
|
345 |
+
(x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
|
346 |
+
)
|
347 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
348 |
+
|
349 |
+
x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
|
350 |
+
x = x_padded[:, :, 1:].view_as(x)
|
351 |
+
|
352 |
+
if zero_triu:
|
353 |
+
ones = torch.ones((x.size(2), x.size(3)))
|
354 |
+
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
355 |
+
return x
|
356 |
+
|
357 |
+
@torch.jit.export
|
358 |
+
def forward(self, query, key, value, mask=None, pos_emb=torch.tensor(1.0)):
|
359 |
+
# type: (Tensor, Tensor, Tensor, Optional[Tensor], Tensor) -> Tensor
|
360 |
+
"""Compute 'Scaled Dot Product Attention'.
|
361 |
+
|
362 |
+
:param torch.Tensor query: (batch, time1, size)
|
363 |
+
:param torch.Tensor key: (batch, time2, size)
|
364 |
+
:param torch.Tensor value: (batch, time2, size)
|
365 |
+
:param torch.Tensor mask: (batch, time1, time2)
|
366 |
+
:param torch.nn.Dropout dropout:
|
367 |
+
:return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
|
368 |
+
weighted by the query dot key attention (batch, head, time1, time2)
|
369 |
+
"""
|
370 |
+
n_batch = query.size(0)
|
371 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
372 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
373 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
374 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
375 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
376 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
377 |
+
|
378 |
+
if self.rel_enc:
|
379 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
380 |
+
n_batch_pos = pos_emb.size(0)
|
381 |
+
p = self.linear_pos(pos_emb.to(query.dtype)).view(n_batch_pos, -1, self.h, self.d_k)
|
382 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
383 |
+
# (batch, head, time1, d_k)
|
384 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
385 |
+
# (batch, head, time1, d_k)
|
386 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
387 |
+
# compute attention score
|
388 |
+
# first compute matrix a and matrix c
|
389 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
390 |
+
# (batch, head, time1, time2)
|
391 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
392 |
+
# compute matrix b and matrix d
|
393 |
+
# (batch, head, time1, time2)
|
394 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
395 |
+
# Remove rel_shift since it is useless in speech recognition,
|
396 |
+
# and it requires special attention for streaming.
|
397 |
+
# matrix_bd = self.rel_shift(matrix_bd)
|
398 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
|
399 |
+
else:
|
400 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
|
401 |
+
self.d_k
|
402 |
+
) # (batch, head, time1, time2)
|
403 |
+
|
404 |
+
if mask is not None:
|
405 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
|
406 |
+
scores = scores.masked_fill(mask, self.min_value)
|
407 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
408 |
+
mask, 0.0
|
409 |
+
) # (batch, head, time1, time2)
|
410 |
+
else:
|
411 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
412 |
+
|
413 |
+
p_attn = self.dropout(attn)
|
414 |
+
|
415 |
+
x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
|
416 |
+
x = (
|
417 |
+
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
418 |
+
) # (batch, time1, d_model)
|
419 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
420 |
+
|
421 |
+
@torch.jit.export
|
422 |
+
def infer(self, query, key, value, pos_emb, buffer, buffer_index, buffer_out):
|
423 |
+
# type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
|
424 |
+
n_batch = query.size(0)
|
425 |
+
|
426 |
+
q = (
|
427 |
+
self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
|
428 |
+
) # (batch, head, len_q, d_k)
|
429 |
+
k = (
|
430 |
+
self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
|
431 |
+
) # (batch, head, len_k, d_k)
|
432 |
+
v = (
|
433 |
+
self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
|
434 |
+
) # (batch, head, len_v, d_k)
|
435 |
+
|
436 |
+
key_value_buffer = buffer[
|
437 |
+
buffer_index : buffer_index + self.key_buffer_size + self.value_buffer_size
|
438 |
+
].reshape([1, self.h, self.buffersize * 2, self.d_k])
|
439 |
+
key_buffer = torch.cat([key_value_buffer[:, :, : self.buffersize, :], k], dim=2)
|
440 |
+
value_buffer = torch.cat([key_value_buffer[:, :, self.buffersize :, :], v], dim=2)
|
441 |
+
buffer_out.append(
|
442 |
+
torch.cat(
|
443 |
+
[key_buffer[:, :, self.chunk_size :, :], value_buffer[:, :, self.chunk_size :, :]],
|
444 |
+
dim=2,
|
445 |
+
).reshape(-1)
|
446 |
+
)
|
447 |
+
buffer_index = buffer_index + self.key_buffer_size + self.value_buffer_size
|
448 |
+
|
449 |
+
if self.rel_enc:
|
450 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
451 |
+
n_batch_pos = pos_emb.size(0)
|
452 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
453 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
454 |
+
# (batch, head, time1, d_k)
|
455 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
456 |
+
# (batch, head, time1, d_k)
|
457 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
458 |
+
# compute attention score
|
459 |
+
# first compute matrix a and matrix c
|
460 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
461 |
+
# (batch, head, time1, time2)
|
462 |
+
matrix_ac = torch.matmul(q_with_bias_u, key_buffer.transpose(-2, -1))
|
463 |
+
# compute matrix b and matrix d
|
464 |
+
# (batch, head, time1, time2)
|
465 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
466 |
+
# Remove rel_shift since it is useless in speech recognition,
|
467 |
+
# and it requires special attention for streaming.
|
468 |
+
# matrix_bd = self.rel_shift(matrix_bd)
|
469 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
|
470 |
+
else:
|
471 |
+
scores = torch.matmul(q, key_buffer.transpose(-2, -1)) / math.sqrt(
|
472 |
+
self.d_k
|
473 |
+
) # (batch, head, len_q, buffersize)
|
474 |
+
|
475 |
+
attn = torch.softmax(scores, dim=-1)
|
476 |
+
|
477 |
+
x = torch.matmul(attn, value_buffer) # (batch, head, len_q, d_k)
|
478 |
+
x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
|
479 |
+
return self.linear_out(x), buffer, buffer_index, buffer_out # (batch, time1, d_model)
|
480 |
+
|
481 |
+
@torch.jit.export
|
482 |
+
def infer_mask(self, query, key, value, mask, buffer, buffer_index, buffer_out, is_static):
|
483 |
+
n_batch = query.size(0)
|
484 |
+
|
485 |
+
q = (
|
486 |
+
self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
|
487 |
+
) # (batch, head, len_q, d_k)
|
488 |
+
k = (
|
489 |
+
self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
|
490 |
+
) # (batch, head, len_k, d_k)
|
491 |
+
v = (
|
492 |
+
self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
|
493 |
+
) # (batch, head, len_v, d_k)
|
494 |
+
|
495 |
+
if is_static:
|
496 |
+
key_buffer = k
|
497 |
+
value_buffer = v
|
498 |
+
else:
|
499 |
+
key_value_buffer = buffer[
|
500 |
+
buffer_index : buffer_index + self.key_buffer_size + self.value_buffer_size
|
501 |
+
].reshape([1, self.h, self.buffersize * 2, self.d_k])
|
502 |
+
key_buffer = torch.cat([key_value_buffer[:, :, : self.buffersize, :], k], dim=2)
|
503 |
+
value_buffer = torch.cat([key_value_buffer[:, :, self.buffersize :, :], v], dim=2)
|
504 |
+
buffer_out.append(
|
505 |
+
torch.cat(
|
506 |
+
[
|
507 |
+
key_buffer[:, :, self.chunk_size :, :],
|
508 |
+
value_buffer[:, :, self.chunk_size :, :],
|
509 |
+
],
|
510 |
+
dim=2,
|
511 |
+
).reshape(-1)
|
512 |
+
)
|
513 |
+
buffer_index = buffer_index + self.key_buffer_size + self.value_buffer_size
|
514 |
+
|
515 |
+
scores = torch.matmul(q, key_buffer.transpose(-2, -1)) / math.sqrt(
|
516 |
+
self.d_k
|
517 |
+
) # (batch, head, len_q, buffersize)
|
518 |
+
|
519 |
+
if mask is not None:
|
520 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
|
521 |
+
scores = scores.masked_fill(mask, self.min_value)
|
522 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
523 |
+
mask, 0.0
|
524 |
+
) # (batch, head, time1, time2)
|
525 |
+
else:
|
526 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
527 |
+
|
528 |
+
x = torch.matmul(attn, value_buffer) # (batch, head, len_q, d_k)
|
529 |
+
x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
|
530 |
+
return self.linear_out(x), buffer_index, buffer_out # (batch, time1, d_model)
|
531 |
+
|
532 |
+
|
533 |
+
class SoftAttention(nn.Module):
|
534 |
+
def __init__(self, in_dim, hidden_dim):
|
535 |
+
super(SoftAttention, self).__init__()
|
536 |
+
self.q = torch.nn.Parameter(torch.rand([hidden_dim]), requires_grad=True)
|
537 |
+
self.wb = nn.Linear(in_dim, hidden_dim)
|
538 |
+
self.min_value = float(numpy.finfo(torch.tensor(0, dtype=torch.float32).numpy().dtype).min)
|
539 |
+
# buffer
|
540 |
+
self.window_size = 50
|
541 |
+
self.buffer_in = torch.zeros([1, self.window_size, in_dim], dtype=torch.float32)
|
542 |
+
self.buffer = torch.zeros([1, self.window_size], dtype=torch.float32)
|
543 |
+
self.buffer[:, :] = float(
|
544 |
+
numpy.finfo(torch.tensor(0, dtype=torch.float32).numpy().dtype).min
|
545 |
+
)
|
546 |
+
|
547 |
+
@torch.jit.unused
|
548 |
+
def forward(self, x, mask=None):
|
549 |
+
hidden = torch.tanh(self.wb(x)) # B T D
|
550 |
+
hidden = torch.einsum("btd,d->bt", hidden, self.q)
|
551 |
+
score = torch.softmax(hidden, dim=-1) # B T
|
552 |
+
if mask is not None:
|
553 |
+
score = score.masked_fill(mask, 0.0)
|
554 |
+
output = torch.einsum("bt,btd->bd", score, x)
|
555 |
+
return output
|
556 |
+
|
557 |
+
@torch.jit.export
|
558 |
+
def infer(self, x):
|
559 |
+
# type: (Tensor) -> Tensor
|
560 |
+
hidden = torch.tanh(self.wb(x)) # B T D
|
561 |
+
hidden = torch.einsum("btd,d->bt", hidden, self.q)
|
562 |
+
size = hidden.shape[1]
|
563 |
+
output = torch.zeros([size, x.shape[-1]])
|
564 |
+
for i in range(size):
|
565 |
+
self.buffer = torch.cat([self.buffer, hidden[:, i : i + 1]], dim=-1)
|
566 |
+
self.buffer = self.buffer[:, 1:]
|
567 |
+
score = torch.softmax(self.buffer, dim=-1) # B T
|
568 |
+
self.buffer_in = torch.cat([self.buffer_in, x[:, i : i + 1, :]], dim=1)
|
569 |
+
self.buffer_in = self.buffer_in[:, 1:]
|
570 |
+
output[i : i + 1] = torch.einsum("bt,btd->bd", score, self.buffer_in)
|
571 |
+
return output
|
vita/model/multimodal_encoder/whale/module/layer/conv1d.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class Conv1dLayer(nn.Module):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
input_dim,
|
10 |
+
output_dim,
|
11 |
+
kernel_size,
|
12 |
+
stride,
|
13 |
+
causal_conv,
|
14 |
+
dilation,
|
15 |
+
dropout_rate,
|
16 |
+
residual=True,
|
17 |
+
):
|
18 |
+
super(Conv1dLayer, self).__init__()
|
19 |
+
self.input_dim = input_dim
|
20 |
+
self.output_dim = output_dim
|
21 |
+
self.kernel_size = kernel_size
|
22 |
+
self.stride = stride
|
23 |
+
self.dilation = dilation
|
24 |
+
self.causal_conv = causal_conv
|
25 |
+
if causal_conv:
|
26 |
+
self.lorder = (kernel_size - 1) * self.dilation
|
27 |
+
self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
|
28 |
+
else:
|
29 |
+
assert (kernel_size - 1) % 2 == 0
|
30 |
+
self.lorder = ((kernel_size - 1) // 2) * self.dilation
|
31 |
+
self.left_padding = nn.ConstantPad1d((self.lorder, self.lorder), 0.0)
|
32 |
+
self.conv1d = nn.Conv1d(
|
33 |
+
self.input_dim, self.output_dim, self.kernel_size, self.stride, 0, self.dilation
|
34 |
+
)
|
35 |
+
self.bn = nn.BatchNorm1d(self.output_dim, eps=1e-3, momentum=0.99)
|
36 |
+
self.relu = nn.ReLU()
|
37 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
38 |
+
self.residual = residual
|
39 |
+
if self.input_dim != self.output_dim:
|
40 |
+
self.residual = False
|
41 |
+
|
42 |
+
# buffer = 1, self.input_dim, self.lorder
|
43 |
+
self.lorder = (kernel_size - 1) * self.dilation - (self.stride - 1)
|
44 |
+
self.buffer_size = 1 * self.input_dim * self.lorder
|
45 |
+
self.x_data_chache_size = self.lorder
|
46 |
+
self.x_data_buffer_size = self.input_dim * self.x_data_chache_size
|
47 |
+
|
48 |
+
@torch.jit.unused
|
49 |
+
def forward(self, x):
|
50 |
+
x_data = x
|
51 |
+
x = self.left_padding(x)
|
52 |
+
x = self.conv1d(x)
|
53 |
+
x = self.bn(x)
|
54 |
+
if self.stride == 1 and self.residual:
|
55 |
+
x = self.relu(x + x_data)
|
56 |
+
else:
|
57 |
+
x = self.relu(x)
|
58 |
+
x = self.dropout(x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
@torch.jit.export
|
62 |
+
def infer(self, x, buffer, buffer_index, buffer_out):
|
63 |
+
# type: (Tensor) -> Tensor
|
64 |
+
x_data = x.clone()
|
65 |
+
|
66 |
+
cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
|
67 |
+
[1, self.input_dim, self.lorder]
|
68 |
+
)
|
69 |
+
x = torch.cat([cnn_buffer, x], dim=2)
|
70 |
+
buffer_out.append(x[:, :, -self.lorder :].reshape(-1))
|
71 |
+
buffer_index = buffer_index + self.buffer_size
|
72 |
+
|
73 |
+
x = self.conv1d(x)
|
74 |
+
x = self.bn(x)
|
75 |
+
|
76 |
+
if self.stride == 1 and self.residual:
|
77 |
+
x_data_cnn_buffer = buffer[
|
78 |
+
buffer_index : buffer_index + self.x_data_buffer_size
|
79 |
+
].reshape([1, self.input_dim, self.x_data_chache_size])
|
80 |
+
x_data = torch.cat([x_data_cnn_buffer, x_data], dim=2)
|
81 |
+
buffer_out.append(x_data[:, :, -self.x_data_chache_size :].reshape(-1))
|
82 |
+
buffer_index = buffer_index + self.x_data_buffer_size
|
83 |
+
x_data = x_data[:, :, : -self.x_data_chache_size]
|
84 |
+
x = self.relu(x + x_data)
|
85 |
+
else:
|
86 |
+
x = self.relu(x)
|
87 |
+
|
88 |
+
return x, buffer, buffer_index, buffer_out
|
vita/model/multimodal_encoder/whale/module/layer/dtcblock.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import pdb
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
class DTCBlock(nn.Module):
|
11 |
+
def __init__(
|
12 |
+
self, input_dim, output_dim, kernel_size, stride, causal_conv, dilation, dropout_rate
|
13 |
+
):
|
14 |
+
super(DTCBlock, self).__init__()
|
15 |
+
self.input_dim = input_dim
|
16 |
+
self.output_dim = output_dim
|
17 |
+
self.kernel_size = kernel_size
|
18 |
+
self.stride = stride
|
19 |
+
self.dilation = dilation
|
20 |
+
if causal_conv:
|
21 |
+
self.padding = 0
|
22 |
+
self.lorder = (kernel_size - 1) * self.dilation
|
23 |
+
self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
|
24 |
+
else:
|
25 |
+
assert (kernel_size - 1) % 2 == 0
|
26 |
+
self.padding = ((kernel_size - 1) // 2) * self.dilation
|
27 |
+
self.lorder = 0
|
28 |
+
self.causal_conv = causal_conv
|
29 |
+
self.depthwise_conv = nn.Conv1d(
|
30 |
+
self.input_dim,
|
31 |
+
self.input_dim,
|
32 |
+
self.kernel_size,
|
33 |
+
self.stride,
|
34 |
+
self.padding,
|
35 |
+
self.dilation,
|
36 |
+
groups=self.input_dim,
|
37 |
+
)
|
38 |
+
self.point_conv_1 = nn.Conv1d(self.input_dim, self.input_dim, 1, 1, self.padding)
|
39 |
+
self.point_conv_2 = nn.Conv1d(self.input_dim, self.input_dim, 1, 1, self.padding)
|
40 |
+
self.bn_1 = nn.BatchNorm1d(self.input_dim)
|
41 |
+
self.bn_2 = nn.BatchNorm1d(self.input_dim)
|
42 |
+
self.bn_3 = nn.BatchNorm1d(self.input_dim)
|
43 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
44 |
+
|
45 |
+
# buffer = 1, self.input_dim, self.lorder
|
46 |
+
self.lorder = (kernel_size - 1) * self.dilation - (self.stride - 1)
|
47 |
+
self.buffer_size = 1 * self.input_dim * self.lorder
|
48 |
+
|
49 |
+
@torch.jit.unused
|
50 |
+
def forward(self, x):
|
51 |
+
x_in = x
|
52 |
+
x_data = x_in.transpose(1, 2)
|
53 |
+
if self.causal_conv:
|
54 |
+
x_data_pad = self.left_padding(x_data)
|
55 |
+
else:
|
56 |
+
x_data_pad = x_data
|
57 |
+
x_depth = self.depthwise_conv(x_data_pad)
|
58 |
+
x_bn_1 = self.bn_1(x_depth)
|
59 |
+
x_point_1 = self.point_conv_1(x_bn_1)
|
60 |
+
x_bn_2 = self.bn_2(x_point_1)
|
61 |
+
x_relu_2 = torch.relu(x_bn_2)
|
62 |
+
x_point_2 = self.point_conv_2(x_relu_2)
|
63 |
+
x_bn_3 = self.bn_3(x_point_2)
|
64 |
+
x_bn_3 = x_bn_3.transpose(1, 2)
|
65 |
+
if self.stride == 1:
|
66 |
+
x_relu_3 = torch.relu(x_bn_3 + x_in)
|
67 |
+
else:
|
68 |
+
x_relu_3 = torch.relu(x_bn_3)
|
69 |
+
x_drop = self.dropout(x_relu_3)
|
70 |
+
return x_drop
|
71 |
+
|
72 |
+
@torch.jit.export
|
73 |
+
def infer(self, x, buffer, buffer_index, buffer_out):
|
74 |
+
# type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
|
75 |
+
x_in = x
|
76 |
+
x = x_in.transpose(1, 2)
|
77 |
+
cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
|
78 |
+
[1, self.input_dim, self.lorder]
|
79 |
+
)
|
80 |
+
x = torch.cat([cnn_buffer, x], dim=2)
|
81 |
+
buffer_out.append(x[:, :, -self.lorder :].reshape(-1))
|
82 |
+
buffer_index = buffer_index + self.buffer_size
|
83 |
+
x = self.depthwise_conv(x)
|
84 |
+
x = self.bn_1(x)
|
85 |
+
x = self.point_conv_1(x)
|
86 |
+
x = self.bn_2(x)
|
87 |
+
x = torch.relu(x)
|
88 |
+
x = self.point_conv_2(x)
|
89 |
+
x = self.bn_3(x)
|
90 |
+
x = x.transpose(1, 2)
|
91 |
+
if self.stride == 1:
|
92 |
+
x = torch.relu(x + x_in)
|
93 |
+
else:
|
94 |
+
x = torch.relu(x)
|
95 |
+
return x, buffer, buffer_index, buffer_out
|
vita/model/multimodal_encoder/whale/module/layer/fsmn.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class FsmnLayer(nn.Module):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
input_dim,
|
10 |
+
out_dim,
|
11 |
+
hidden_dim,
|
12 |
+
left_frame=1,
|
13 |
+
right_frame=1,
|
14 |
+
left_dilation=1,
|
15 |
+
right_dilation=1,
|
16 |
+
):
|
17 |
+
super(FsmnLayer, self).__init__()
|
18 |
+
self.input_dim = input_dim
|
19 |
+
self.out_dim = out_dim
|
20 |
+
self.hidden_dim = hidden_dim
|
21 |
+
self.left_frame = left_frame
|
22 |
+
self.right_frame = right_frame
|
23 |
+
self.left_dilation = left_dilation
|
24 |
+
self.right_dilation = right_dilation
|
25 |
+
self.conv_in = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
|
26 |
+
if left_frame > 0:
|
27 |
+
self.pad_left = nn.ConstantPad1d([left_dilation * left_frame, 0], 0.0)
|
28 |
+
self.conv_left = nn.Conv1d(
|
29 |
+
hidden_dim,
|
30 |
+
hidden_dim,
|
31 |
+
kernel_size=left_frame + 1,
|
32 |
+
dilation=left_dilation,
|
33 |
+
bias=False,
|
34 |
+
groups=hidden_dim,
|
35 |
+
)
|
36 |
+
if right_frame > 0:
|
37 |
+
self.pad_right = nn.ConstantPad1d([-right_dilation, right_dilation * right_frame], 0.0)
|
38 |
+
self.conv_right = nn.Conv1d(
|
39 |
+
hidden_dim,
|
40 |
+
hidden_dim,
|
41 |
+
kernel_size=right_frame,
|
42 |
+
dilation=right_dilation,
|
43 |
+
bias=False,
|
44 |
+
groups=hidden_dim,
|
45 |
+
)
|
46 |
+
self.conv_out = nn.Conv1d(hidden_dim, out_dim, kernel_size=1)
|
47 |
+
|
48 |
+
# cache = 1, self.hidden_dim, left_frame * left_dilation + right_frame * right_dilation
|
49 |
+
self.cache_size = left_frame * left_dilation + right_frame * right_dilation
|
50 |
+
self.buffer_size = self.hidden_dim * self.cache_size
|
51 |
+
self.p_in_raw_chache_size = self.right_frame * self.right_dilation
|
52 |
+
self.p_in_raw_buffer_size = self.hidden_dim * self.p_in_raw_chache_size
|
53 |
+
self.hidden_chache_size = self.right_frame * self.right_dilation
|
54 |
+
self.hidden_buffer_size = self.hidden_dim * self.hidden_chache_size
|
55 |
+
|
56 |
+
@torch.jit.unused
|
57 |
+
def forward(self, x, hidden=None):
|
58 |
+
x_data = x.transpose(1, 2)
|
59 |
+
p_in = self.conv_in(x_data)
|
60 |
+
if self.left_frame > 0:
|
61 |
+
p_left = self.pad_left(p_in)
|
62 |
+
p_left = self.conv_left(p_left)
|
63 |
+
else:
|
64 |
+
p_left = 0
|
65 |
+
if self.right_frame > 0:
|
66 |
+
p_right = self.pad_right(p_in)
|
67 |
+
p_right = self.conv_right(p_right)
|
68 |
+
else:
|
69 |
+
p_right = 0
|
70 |
+
p_out = p_in + p_right + p_left
|
71 |
+
if hidden is not None:
|
72 |
+
p_out = hidden + p_out
|
73 |
+
out = F.relu(self.conv_out(p_out))
|
74 |
+
out = out.transpose(1, 2)
|
75 |
+
return out, p_out
|
76 |
+
|
77 |
+
@torch.jit.export
|
78 |
+
def infer(self, x, buffer, buffer_index, buffer_out, hidden=None):
|
79 |
+
# type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor]
|
80 |
+
p_in_raw = self.conv_in(x)
|
81 |
+
|
82 |
+
cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
|
83 |
+
[1, self.hidden_dim, self.cache_size]
|
84 |
+
)
|
85 |
+
p_in = torch.cat([cnn_buffer, p_in_raw], dim=2)
|
86 |
+
# buffer[buffer_index: buffer_index + self.buffer_size] = p_in[:, :, -self.cache_size:].reshape(-1)
|
87 |
+
buffer_out.append(p_in[:, :, -self.cache_size :].reshape(-1))
|
88 |
+
buffer_index = buffer_index + self.buffer_size
|
89 |
+
|
90 |
+
if self.left_frame > 0:
|
91 |
+
if self.right_frame > 0:
|
92 |
+
p_left = p_in[:, :, : -self.right_frame * self.right_dilation]
|
93 |
+
else:
|
94 |
+
p_left = p_in[:, :]
|
95 |
+
p_left_out = self.conv_left(p_left)
|
96 |
+
else:
|
97 |
+
p_left_out = torch.tensor([0])
|
98 |
+
if self.right_frame > 0:
|
99 |
+
p_right = p_in[:, :, self.left_frame * self.left_dilation + 1 :]
|
100 |
+
p_right_out = self.conv_right(p_right)
|
101 |
+
else:
|
102 |
+
p_right_out = torch.tensor([0])
|
103 |
+
|
104 |
+
if self.right_frame > 0:
|
105 |
+
p_in_raw_cnn_buffer = buffer[
|
106 |
+
buffer_index : buffer_index + self.p_in_raw_buffer_size
|
107 |
+
].reshape([1, self.hidden_dim, self.p_in_raw_chache_size])
|
108 |
+
p_in_raw = torch.cat([p_in_raw_cnn_buffer, p_in_raw], dim=2)
|
109 |
+
# buffer[buffer_index: buffer_index + self.p_in_raw_buffer_size] = p_in_raw[:, :, -self.p_in_raw_chache_size:].reshape(-1)
|
110 |
+
buffer_out.append(p_in_raw[:, :, -self.p_in_raw_chache_size :].reshape(-1))
|
111 |
+
buffer_index = buffer_index + self.p_in_raw_buffer_size
|
112 |
+
p_in_raw = p_in_raw[:, :, : -self.p_in_raw_chache_size]
|
113 |
+
p_out = p_in_raw + p_left_out + p_right_out
|
114 |
+
|
115 |
+
if hidden is not None:
|
116 |
+
if self.right_frame > 0:
|
117 |
+
hidden_cnn_buffer = buffer[
|
118 |
+
buffer_index : buffer_index + self.hidden_buffer_size
|
119 |
+
].reshape([1, self.hidden_dim, self.hidden_chache_size])
|
120 |
+
hidden = torch.cat([hidden_cnn_buffer, hidden], dim=2)
|
121 |
+
# buffer[buffer_index: buffer_index + self.hidden_buffer_size] = hidden[:, :, -self.hidden_chache_size:].reshape(-1)
|
122 |
+
buffer_out.append(hidden[:, :, -self.hidden_chache_size :].reshape(-1))
|
123 |
+
buffer_index = buffer_index + self.hidden_buffer_size
|
124 |
+
hidden = hidden[:, :, : -self.hidden_chache_size]
|
125 |
+
p_out = hidden + p_out
|
126 |
+
|
127 |
+
out = F.relu(self.conv_out(p_out))
|
128 |
+
|
129 |
+
return out, buffer, buffer_index, buffer_out, p_out
|
vita/model/multimodal_encoder/whale/utils.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import importlib
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
from distutils.util import strtobool as dist_strtobool
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import yaml
|
9 |
+
|
10 |
+
IGNORE_ID = -1
|
11 |
+
|
12 |
+
|
13 |
+
def assign_args_from_yaml(args, yaml_path, prefix_key=None):
|
14 |
+
with open(yaml_path) as f:
|
15 |
+
ydict = yaml.load(f, Loader=yaml.FullLoader)
|
16 |
+
if prefix_key is not None:
|
17 |
+
ydict = ydict[prefix_key]
|
18 |
+
for k, v in ydict.items():
|
19 |
+
k_args = k.replace("-", "_")
|
20 |
+
if hasattr(args, k_args):
|
21 |
+
setattr(args, k_args, ydict[k])
|
22 |
+
return args
|
23 |
+
|
24 |
+
|
25 |
+
def get_model_conf(model_path):
|
26 |
+
model_conf = os.path.dirname(model_path) + "/model.json"
|
27 |
+
with open(model_conf, "rb") as f:
|
28 |
+
print("reading a config file from " + model_conf)
|
29 |
+
confs = json.load(f)
|
30 |
+
# for asr, tts, mt
|
31 |
+
idim, odim, args = confs
|
32 |
+
return argparse.Namespace(**args)
|
33 |
+
|
34 |
+
|
35 |
+
def strtobool(x):
|
36 |
+
return bool(dist_strtobool(x))
|
37 |
+
|
38 |
+
|
39 |
+
def dynamic_import(import_path, alias=dict()):
|
40 |
+
"""dynamic import module and class
|
41 |
+
|
42 |
+
:param str import_path: syntax 'module_name:class_name'
|
43 |
+
e.g., 'espnet.transform.add_deltas:AddDeltas'
|
44 |
+
:param dict alias: shortcut for registered class
|
45 |
+
:return: imported class
|
46 |
+
"""
|
47 |
+
if import_path not in alias and ":" not in import_path:
|
48 |
+
raise ValueError(
|
49 |
+
"import_path should be one of {} or "
|
50 |
+
'include ":", e.g. "espnet.transform.add_deltas:AddDeltas" : '
|
51 |
+
"{}".format(set(alias), import_path)
|
52 |
+
)
|
53 |
+
if ":" not in import_path:
|
54 |
+
import_path = alias[import_path]
|
55 |
+
|
56 |
+
module_name, objname = import_path.split(":")
|
57 |
+
m = importlib.import_module(module_name)
|
58 |
+
return getattr(m, objname)
|
59 |
+
|
60 |
+
|
61 |
+
def set_deterministic_pytorch(args):
|
62 |
+
# seed setting
|
63 |
+
torch.manual_seed(args.seed)
|
64 |
+
|
65 |
+
torch.backends.cudnn.deterministic = False
|
66 |
+
torch.backends.cudnn.benchmark = False
|
67 |
+
|
68 |
+
|
69 |
+
def pad_list(xs, pad_value):
|
70 |
+
n_batch = len(xs)
|
71 |
+
max_len = max(x.size(0) for x in xs)
|
72 |
+
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
|
73 |
+
for i in range(n_batch):
|
74 |
+
pad[i, : xs[i].size(0)] = xs[i]
|
75 |
+
return pad
|
76 |
+
|
77 |
+
|
78 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
79 |
+
batch_size = lengths.size(0)
|
80 |
+
max_len = max_len if max_len > 0 else lengths.max().item()
|
81 |
+
seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
|
82 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
83 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
84 |
+
mask = seq_range_expand >= seq_length_expand
|
85 |
+
return mask
|
86 |
+
|
87 |
+
|
88 |
+
def subsequent_chunk_mask(
|
89 |
+
size: int,
|
90 |
+
ck_size: int,
|
91 |
+
num_l_cks: int = -1,
|
92 |
+
device: torch.device = torch.device("cpu"),
|
93 |
+
) -> torch.Tensor:
|
94 |
+
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
|
95 |
+
for i in range(size):
|
96 |
+
if num_l_cks < 0:
|
97 |
+
start = 0
|
98 |
+
else:
|
99 |
+
start = max((i // ck_size - num_l_cks) * ck_size, 0)
|
100 |
+
ending = min((i // ck_size + 1) * ck_size, size)
|
101 |
+
ret[i, start:ending] = True
|
102 |
+
return ret
|
103 |
+
|
104 |
+
|
105 |
+
def add_optional_chunk_mask(
|
106 |
+
xs: torch.Tensor,
|
107 |
+
masks: torch.Tensor,
|
108 |
+
use_dynamic_chunk: bool,
|
109 |
+
use_dynamic_left_chunk: bool,
|
110 |
+
decoding_chunk_size: int,
|
111 |
+
static_chunk_size: int,
|
112 |
+
num_decoding_left_chunks: int,
|
113 |
+
):
|
114 |
+
if use_dynamic_chunk:
|
115 |
+
max_len = xs.size(1)
|
116 |
+
if decoding_chunk_size < 0:
|
117 |
+
chunk_size = max_len
|
118 |
+
num_l_cks = -1
|
119 |
+
elif decoding_chunk_size > 0:
|
120 |
+
chunk_size = decoding_chunk_size
|
121 |
+
num_l_cks = num_decoding_left_chunks
|
122 |
+
else:
|
123 |
+
chunk_size = torch.randint(1, max_len, (1,)).item()
|
124 |
+
num_l_cks = -1
|
125 |
+
if chunk_size > max_len // 2:
|
126 |
+
chunk_size = max_len
|
127 |
+
else:
|
128 |
+
chunk_size = chunk_size % 25 + 1
|
129 |
+
if use_dynamic_left_chunk:
|
130 |
+
max_left_chunks = (max_len - 1) // chunk_size
|
131 |
+
num_l_cks = torch.randint(0, max_left_chunks, (1,)).item()
|
132 |
+
ck_masks = subsequent_chunk_mask(
|
133 |
+
xs.size(1), chunk_size, num_l_cks, xs.device
|
134 |
+
) # (L, L)
|
135 |
+
ck_masks = ck_masks.unsqueeze(0) # (1, L, L)
|
136 |
+
ck_masks = masks & ck_masks # (B, L, L)
|
137 |
+
elif static_chunk_size > 0:
|
138 |
+
num_l_cks = num_decoding_left_chunks
|
139 |
+
ck_masks = subsequent_chunk_mask(
|
140 |
+
xs.size(1), static_chunk_size, num_l_cks, xs.device
|
141 |
+
) # (L, L)
|
142 |
+
ck_masks = ck_masks.unsqueeze(0) # (1, L, L)
|
143 |
+
ck_masks = masks & ck_masks # (B, L, L)
|
144 |
+
else:
|
145 |
+
ck_masks = masks
|
146 |
+
return ck_masks
|
vita/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import re
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from timm.layers.norm_act import LayerNormAct2d
|
8 |
+
from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig
|
9 |
+
from torchvision.ops.misc import SqueezeExcitation as SELayer
|
10 |
+
|
11 |
+
|
12 |
+
class IdentityMap(nn.Module):
|
13 |
+
def __init__(self):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
def forward(self, x, *args, **kwargs):
|
17 |
+
return x
|
18 |
+
|
19 |
+
@property
|
20 |
+
def config(self):
|
21 |
+
return {"mm_projector_type": "identity"}
|
22 |
+
|
23 |
+
|
24 |
+
class Minigpt(nn.Module):
|
25 |
+
def __init__(self, config=None):
|
26 |
+
super(Minigpt, self).__init__()
|
27 |
+
# c*4 is the input size, and c is the output size for the linear layer
|
28 |
+
inc, ouc = config.mm_hidden_size, config.hidden_size
|
29 |
+
self.linear = nn.Linear(inc * 4, ouc)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
# x is the input tensor with shape [b, num_tokens, c]
|
33 |
+
b, num_tokens, c = x.shape
|
34 |
+
|
35 |
+
# Check if num_tokens is divisible by 4
|
36 |
+
if num_tokens % 4 != 0:
|
37 |
+
raise ValueError("num_tokens must be divisible by 4")
|
38 |
+
|
39 |
+
# Reshape x to [b, num_tokens/4, c*4]
|
40 |
+
x = x.view(b, num_tokens // 4, c * 4)
|
41 |
+
|
42 |
+
# Apply the linear transformation
|
43 |
+
x = self.linear(x)
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
class Vanilla(nn.Module):
|
48 |
+
def __init__(self, config=None):
|
49 |
+
super(Vanilla, self).__init__()
|
50 |
+
# c*4 is the input size, and c is the output size for the linear layer
|
51 |
+
inc, ouc = config.mm_hidden_size, config.hidden_size
|
52 |
+
self.linear = nn.Linear(inc * 4, ouc)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
b, num_tokens, c = x.shape
|
56 |
+
|
57 |
+
# Check if num_tokens is divisible by 4
|
58 |
+
if num_tokens % 4 != 0:
|
59 |
+
raise ValueError("num_tokens must be divisible by 4")
|
60 |
+
|
61 |
+
# First, reshape to [b, num_tokens//4, 4, c]
|
62 |
+
x = x.view(b, num_tokens // 4, 4, c)
|
63 |
+
|
64 |
+
# Then, permute to interleave the tokens
|
65 |
+
x = x.permute(0, 1, 3, 2).contiguous()
|
66 |
+
|
67 |
+
# Finally, reshape to [b, num_tokens//4, c*4] to interleave features of 4 tokens
|
68 |
+
x = x.view(b, num_tokens // 4, c * 4)
|
69 |
+
|
70 |
+
# Apply the linear transformation
|
71 |
+
x = self.linear(x)
|
72 |
+
return x
|
73 |
+
|
74 |
+
|
75 |
+
class LDPBlock(nn.Module):
|
76 |
+
# Lightweight Downsample Projector Block
|
77 |
+
|
78 |
+
def __init__(self, config=None):
|
79 |
+
super().__init__()
|
80 |
+
|
81 |
+
inc, ouc = config.mm_hidden_size, config.hidden_size
|
82 |
+
layer_norm = partial(LayerNormAct2d, act_layer=None)
|
83 |
+
se_layer = partial(SELayer, scale_activation=nn.Hardsigmoid)
|
84 |
+
self.mlp = nn.Sequential(nn.Identity(), nn.Linear(inc, ouc), nn.GELU(), nn.Linear(ouc, ouc))
|
85 |
+
self.mb_block = nn.Sequential(
|
86 |
+
nn.Identity(),
|
87 |
+
InvertedResidual(
|
88 |
+
InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 1, 1, 1), layer_norm, se_layer
|
89 |
+
),
|
90 |
+
InvertedResidual(
|
91 |
+
InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 2, 1, 1), layer_norm, se_layer
|
92 |
+
),
|
93 |
+
)
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
b, num_tokens, c = x.shape
|
97 |
+
h = int(math.sqrt(num_tokens))
|
98 |
+
x = self.mlp(x)
|
99 |
+
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
|
100 |
+
x = self.mb_block(x)
|
101 |
+
x = x.flatten(2).permute(0, 2, 1)
|
102 |
+
return x
|
103 |
+
|
104 |
+
|
105 |
+
class LDPNetProjector(nn.Module):
|
106 |
+
def __init__(self, config=None):
|
107 |
+
super().__init__()
|
108 |
+
self.model = LDPBlock(config)
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
return self.model(x)
|
112 |
+
|
113 |
+
|
114 |
+
class SPP(nn.Module):
|
115 |
+
def __init__(self, config=None, projector_type="v1"):
|
116 |
+
super().__init__()
|
117 |
+
|
118 |
+
self.projector_type = projector_type
|
119 |
+
|
120 |
+
inc, ouc = config.mm_hidden_size, config.hidden_size
|
121 |
+
self.linear_0 = nn.Linear(inc, inc)
|
122 |
+
|
123 |
+
self.linear_1 = nn.Linear(inc, ouc)
|
124 |
+
|
125 |
+
self.pooling = nn.AvgPool2d(kernel_size=2)
|
126 |
+
|
127 |
+
self.linear_2 = nn.Linear(ouc, ouc)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
b, num_tokens, c = x.shape
|
131 |
+
h = int(math.sqrt(num_tokens))
|
132 |
+
if "v1" in self.projector_type:
|
133 |
+
x = self.linear_1(x)
|
134 |
+
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
|
135 |
+
x = self.pooling(x)
|
136 |
+
x = x.flatten(2).permute(0, 2, 1)
|
137 |
+
x = self.linear_2(x)
|
138 |
+
elif "v2" in self.projector_type:
|
139 |
+
x = self.linear_1(x)
|
140 |
+
x = self.linear_2(x)
|
141 |
+
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
|
142 |
+
x = self.pooling(x)
|
143 |
+
x = x.flatten(2).permute(0, 2, 1)
|
144 |
+
elif "v3" in self.projector_type:
|
145 |
+
x = self.linear_0(x)
|
146 |
+
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
|
147 |
+
x = self.pooling(x)
|
148 |
+
x = x.flatten(2).permute(0, 2, 1)
|
149 |
+
x = self.linear_1(x)
|
150 |
+
x = self.linear_2(x)
|
151 |
+
return x
|
152 |
+
|
153 |
+
|
154 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
155 |
+
projector_type = getattr(config, "mm_projector_type", "mlp2x_gelu")
|
156 |
+
|
157 |
+
if projector_type == "linear":
|
158 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
159 |
+
|
160 |
+
elif projector_type.startswith("mlp"):
|
161 |
+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
|
162 |
+
if mlp_gelu_match:
|
163 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
164 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
165 |
+
for _ in range(1, mlp_depth):
|
166 |
+
modules.append(nn.GELU())
|
167 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
168 |
+
return nn.Sequential(*modules)
|
169 |
+
|
170 |
+
elif projector_type.startswith("spp"):
|
171 |
+
return SPP(config, projector_type)
|
172 |
+
|
173 |
+
elif projector_type == "ldp":
|
174 |
+
return LDPNetProjector(config)
|
175 |
+
|
176 |
+
elif projector_type == "vanilla":
|
177 |
+
return Vanilla(config)
|
178 |
+
|
179 |
+
elif projector_type == "minigpt":
|
180 |
+
return Minigpt(config)
|
181 |
+
|
182 |
+
elif projector_type == "identity":
|
183 |
+
return IdentityMap()
|
184 |
+
|
185 |
+
raise ValueError(f"Unknown projector type: {projector_type}")
|
vita/model/vita_arch.py
ADDED
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from vita.constants import AUDIO_TOKEN_INDEX, IGNORE_INDEX, IMAGE_TOKEN_INDEX
|
8 |
+
|
9 |
+
from .multimodal_encoder.builder import build_audio_encoder, build_vision_tower
|
10 |
+
from .multimodal_projector.builder import build_vision_projector
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
class VITAMetaModel:
|
14 |
+
def __init__(self, config):
|
15 |
+
super(VITAMetaModel, self).__init__(config)
|
16 |
+
|
17 |
+
if hasattr(config, "mm_vision_tower"):
|
18 |
+
self.vision_tower = build_vision_tower(
|
19 |
+
config, delay_load=False#not getattr(config, "continuous_training", False)
|
20 |
+
)
|
21 |
+
if getattr(config, "continuous_training", False):
|
22 |
+
config.continuous_training = False
|
23 |
+
self.mm_projector = build_vision_projector(config)
|
24 |
+
|
25 |
+
if hasattr(config, "mm_audio_encoder"):
|
26 |
+
self.audio_encoder = build_audio_encoder(config)
|
27 |
+
|
28 |
+
def get_vision_tower(self):
|
29 |
+
vision_tower = getattr(self, "vision_tower", None)
|
30 |
+
if type(vision_tower) is list:
|
31 |
+
vision_tower = vision_tower[0]
|
32 |
+
return vision_tower
|
33 |
+
|
34 |
+
def get_audio_encoder(self):
|
35 |
+
audio_encoder = getattr(self, "audio_encoder", None)
|
36 |
+
return audio_encoder
|
37 |
+
|
38 |
+
def initialize_vision_modules(self, model_args):
|
39 |
+
vision_tower = model_args.vision_tower
|
40 |
+
|
41 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
42 |
+
|
43 |
+
self.config.mm_vision_tower = vision_tower
|
44 |
+
|
45 |
+
if self.get_vision_tower() is None:
|
46 |
+
vision_tower = build_vision_tower(model_args)
|
47 |
+
self.vision_tower = vision_tower
|
48 |
+
else:
|
49 |
+
vision_tower = self.vision_tower
|
50 |
+
#vision_tower.load_model()
|
51 |
+
|
52 |
+
self.config.use_mm_proj = True
|
53 |
+
self.config.mm_projector_type = getattr(model_args, "mm_projector_type")
|
54 |
+
self.config.mm_hidden_size = vision_tower.hidden_size
|
55 |
+
|
56 |
+
if getattr(self, "mm_projector", None) is None:
|
57 |
+
self.mm_projector = build_vision_projector(self.config)
|
58 |
+
else:
|
59 |
+
# In case it is frozen by LoRA
|
60 |
+
for p in self.mm_projector.parameters():
|
61 |
+
p.requires_grad = True
|
62 |
+
|
63 |
+
if pretrain_mm_mlp_adapter is not None:
|
64 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
|
65 |
+
|
66 |
+
def get_w(weights, keyword):
|
67 |
+
return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
|
68 |
+
|
69 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
|
70 |
+
|
71 |
+
def initialize_audio_modules(self, model_args):
|
72 |
+
audio_encoder = model_args.audio_encoder
|
73 |
+
|
74 |
+
pretrain_audio_mlp_adapter = model_args.pretrain_audio_mlp_adapter
|
75 |
+
|
76 |
+
setattr(self.config, "mm_audio_encoder", audio_encoder)
|
77 |
+
|
78 |
+
audio_encoder = build_audio_encoder(self.config)
|
79 |
+
self.audio_encoder = audio_encoder
|
80 |
+
|
81 |
+
load_audio_ckpt_from_mllm = True
|
82 |
+
if load_audio_ckpt_from_mllm:
|
83 |
+
from safetensors.torch import load_file
|
84 |
+
import os
|
85 |
+
audio_weights = {}
|
86 |
+
for file_name in os.listdir(model_args.model_name_or_path):
|
87 |
+
if file_name.endswith('safetensors'):
|
88 |
+
audio_weights.update(
|
89 |
+
{k[20:]: v for k, v in load_file(os.path.join(model_args.model_name_or_path, file_name)).items() if
|
90 |
+
k.startswith('model.audio_encoder.')})
|
91 |
+
self.audio_encoder.load_state_dict(audio_weights, strict=True)
|
92 |
+
|
93 |
+
#load_audio_ckpt = True
|
94 |
+
#if self.get_audio_encoder() is None or load_audio_ckpt or model_args.audio_prompt_finetune:
|
95 |
+
# audio_encoder = build_audio_encoder(self.config)
|
96 |
+
# self.audio_encoder = audio_encoder
|
97 |
+
|
98 |
+
#load_audio_prompt_weight = False #True
|
99 |
+
#if load_audio_prompt_weight:
|
100 |
+
# from safetensors.torch import load_file
|
101 |
+
# import os
|
102 |
+
# audio_weights = {}
|
103 |
+
# for file_name in os.listdir(model_args.model_name_or_path):
|
104 |
+
# if file_name.endswith('safetensors'):
|
105 |
+
# audio_weights.update(
|
106 |
+
# {k[38:]: v for k, v in load_file(os.path.join(model_args.model_name_or_path, file_name)).items() if
|
107 |
+
# k.startswith('model.audio_encoder.prompt_embeddings')})
|
108 |
+
# self.audio_encoder.prompt_embeddings.load_state_dict(audio_weights, strict=True)
|
109 |
+
|
110 |
+
#checkpoint = torch.load(model_args.audio_encoder + "/final.pt", map_location="cpu")
|
111 |
+
#model_dict = self.audio_encoder.state_dict()
|
112 |
+
#for key in model_dict.keys():
|
113 |
+
# if key in checkpoint.keys():
|
114 |
+
# if model_dict[key].shape == checkpoint[key].shape:
|
115 |
+
# model_dict[key] = checkpoint[key]
|
116 |
+
# else:
|
117 |
+
# print(
|
118 |
+
# "Key {} has different shape, {} VS {}".format(
|
119 |
+
# key, model_dict[key].shape, checkpoint[key].shape
|
120 |
+
# )
|
121 |
+
# )
|
122 |
+
# else:
|
123 |
+
# print("Key {} has not in resume model".format(key))
|
124 |
+
#self.audio_encoder.load_state_dict(model_dict)
|
125 |
+
|
126 |
+
if pretrain_audio_mlp_adapter is not None:
|
127 |
+
audio_projector_weights = torch.load(pretrain_audio_mlp_adapter, map_location="cpu")
|
128 |
+
|
129 |
+
def get_w(weights, keyword):
|
130 |
+
return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
|
131 |
+
|
132 |
+
self.audio_encoder.adpter.load_state_dict(get_w(audio_projector_weights, "audio_encoder.adpter"))
|
133 |
+
|
134 |
+
|
135 |
+
class VITAMetaForCausalLM(ABC):
|
136 |
+
@abstractmethod
|
137 |
+
def get_model(self):
|
138 |
+
pass
|
139 |
+
|
140 |
+
def get_vision_tower(self):
|
141 |
+
return self.get_model().get_vision_tower()
|
142 |
+
|
143 |
+
def get_audio_encoder(self):
|
144 |
+
return self.get_model().get_audio_encoder()
|
145 |
+
|
146 |
+
def pool_feats(self, x, out_size):
|
147 |
+
ndim = x.ndim
|
148 |
+
if ndim == 2:
|
149 |
+
x = x.unsqueeze(0)
|
150 |
+
b, num_tokens, c = x.shape
|
151 |
+
h = int(math.sqrt(num_tokens))
|
152 |
+
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
|
153 |
+
x = F.interpolate(x, size=out_size, mode='bilinear', align_corners=False)
|
154 |
+
num_tokens = x.shape[2] * x.shape[3] # Recalculate the number of tokens after pooling
|
155 |
+
x = x.reshape(b, c, num_tokens).permute(0, 2, 1)
|
156 |
+
if ndim == 2:
|
157 |
+
x = x.squeeze(0)
|
158 |
+
return x
|
159 |
+
|
160 |
+
def encode_images(self, images):
|
161 |
+
image_features = self.get_model().get_vision_tower()(images)
|
162 |
+
#image_features = self.pool_feats(image_features)
|
163 |
+
image_features = self.get_model().mm_projector(image_features)
|
164 |
+
return image_features
|
165 |
+
|
166 |
+
def encode_images_frameCat(self, images):
|
167 |
+
image_features = self.get_model().get_vision_tower()(images)
|
168 |
+
assert len(image_features) % 5 == 0
|
169 |
+
|
170 |
+
concatenated_features = []
|
171 |
+
for i in range(0, len(image_features), 5):
|
172 |
+
tensors_to_concat = [image_features[j] for j in range(i, i + 5)]
|
173 |
+
concatenated_tensor = torch.cat(tensors_to_concat, dim=-1)
|
174 |
+
concatenated_features.append(concatenated_tensor)
|
175 |
+
concatenated_features = torch.stack(concatenated_features)
|
176 |
+
image_features = concatenated_features
|
177 |
+
|
178 |
+
image_features = self.get_model().mm_projector(image_features)
|
179 |
+
return image_features
|
180 |
+
|
181 |
+
def slow_fast_pooling0(self, temp_img_feats):
|
182 |
+
num_frame = len(temp_img_feats)
|
183 |
+
if num_frame <= 30:
|
184 |
+
slow_token_num = max([e for e in [256, 225, 196, 169] if e <= 5200/num_frame])
|
185 |
+
fast_token_num = slow_token_num
|
186 |
+
elif num_frame <= 45:
|
187 |
+
slow_token_num = 169
|
188 |
+
fast_token_num = 81
|
189 |
+
elif num_frame <= 64:
|
190 |
+
slow_token_num = 169
|
191 |
+
fast_token_num = 49
|
192 |
+
else:
|
193 |
+
raise ValueError("The number of frames is too large!")
|
194 |
+
|
195 |
+
if num_frame <= 30:
|
196 |
+
num_slow = num_frame
|
197 |
+
else:
|
198 |
+
num_slow = int((5200 - fast_token_num * num_frame) / (slow_token_num - fast_token_num))
|
199 |
+
num_fast = num_frame - num_slow
|
200 |
+
slow_index = list(np.linspace(0, num_frame, num=num_slow, dtype=int))
|
201 |
+
|
202 |
+
new_img_feats = []
|
203 |
+
for i, feat in enumerate(temp_img_feats):
|
204 |
+
if i in slow_index:
|
205 |
+
sqrt_len = int(math.sqrt(slow_token_num))
|
206 |
+
else:
|
207 |
+
sqrt_len = int(math.sqrt(fast_token_num))
|
208 |
+
if sqrt_len != 16:
|
209 |
+
feat = self.pool_feats(feat, out_size=(sqrt_len, sqrt_len))
|
210 |
+
new_img_feats.append(feat)
|
211 |
+
|
212 |
+
return new_img_feats
|
213 |
+
|
214 |
+
def slow_fast_pooling1(self, temp_img_feats):
|
215 |
+
num_frame = len(temp_img_feats)
|
216 |
+
if num_frame <= 28:
|
217 |
+
slow_token_num = max([e for e in [256, 225, 196, 169, 144] if e <= 4096/num_frame])
|
218 |
+
fast_token_num = slow_token_num
|
219 |
+
elif num_frame <= 40:
|
220 |
+
slow_token_num = 144
|
221 |
+
fast_token_num = 81
|
222 |
+
elif num_frame <= 64:
|
223 |
+
slow_token_num = 144
|
224 |
+
fast_token_num = 49
|
225 |
+
else:
|
226 |
+
raise ValueError("The number of frames is too large!")
|
227 |
+
|
228 |
+
if num_frame <= 28:
|
229 |
+
num_slow = num_frame
|
230 |
+
else:
|
231 |
+
num_slow = int((4096 - fast_token_num * num_frame) / (slow_token_num - fast_token_num))
|
232 |
+
num_fast = num_frame - num_slow
|
233 |
+
slow_index = list(np.linspace(0, num_frame, num=num_slow, dtype=int))
|
234 |
+
|
235 |
+
new_img_feats = []
|
236 |
+
for i, feat in enumerate(temp_img_feats):
|
237 |
+
if i in slow_index:
|
238 |
+
sqrt_len = int(math.sqrt(slow_token_num))
|
239 |
+
else:
|
240 |
+
sqrt_len = int(math.sqrt(fast_token_num))
|
241 |
+
if sqrt_len != 16:
|
242 |
+
feat = self.pool_feats(feat, out_size=(sqrt_len, sqrt_len))
|
243 |
+
new_img_feats.append(feat)
|
244 |
+
|
245 |
+
return new_img_feats
|
246 |
+
|
247 |
+
def slow_fast_pooling(self, temp_img_feats):
|
248 |
+
num_frame = len(temp_img_feats)
|
249 |
+
slow_token_num = 144
|
250 |
+
fast_token_num = 49
|
251 |
+
|
252 |
+
slow_index = list(range(0, num_frame, 4))
|
253 |
+
|
254 |
+
new_img_feats = []
|
255 |
+
for i, feat in enumerate(temp_img_feats):
|
256 |
+
if i in slow_index:
|
257 |
+
sqrt_len = int(math.sqrt(slow_token_num))
|
258 |
+
else:
|
259 |
+
sqrt_len = int(math.sqrt(fast_token_num))
|
260 |
+
if sqrt_len != 16:
|
261 |
+
feat = self.pool_feats(feat, out_size=(sqrt_len, sqrt_len))
|
262 |
+
new_img_feats.append(feat)
|
263 |
+
|
264 |
+
return new_img_feats
|
265 |
+
|
266 |
+
def slow_fast_pooling3(self, temp_img_feats):
|
267 |
+
num_frame = len(temp_img_feats)
|
268 |
+
slow_token_num = 144
|
269 |
+
fast_token_num = 36
|
270 |
+
|
271 |
+
slow_index = list(range(0, num_frame, 16))
|
272 |
+
|
273 |
+
new_img_feats = []
|
274 |
+
for i, feat in enumerate(temp_img_feats):
|
275 |
+
if i in slow_index:
|
276 |
+
sqrt_len = int(math.sqrt(slow_token_num))
|
277 |
+
else:
|
278 |
+
sqrt_len = int(math.sqrt(fast_token_num))
|
279 |
+
if sqrt_len != 16:
|
280 |
+
feat = self.pool_feats(feat, out_size=(sqrt_len, sqrt_len))
|
281 |
+
new_img_feats.append(feat)
|
282 |
+
|
283 |
+
return new_img_feats
|
284 |
+
|
285 |
+
def slow_fast(self, image_features, sf_masks):
|
286 |
+
new_image_features = []
|
287 |
+
temp_img_feats = [] # 初始化 temp_img_feats 在循环外
|
288 |
+
for i, img_feat in enumerate(image_features):
|
289 |
+
if i == 0 or sf_masks[i] != sf_masks[i-1]:
|
290 |
+
if temp_img_feats: # 如果 temp_img_feats 不为空,则添加到 new_image_features
|
291 |
+
if sf_masks[i-1] > 0:
|
292 |
+
temp_img_feats = self.slow_fast_pooling(temp_img_feats)
|
293 |
+
new_image_features.append(temp_img_feats)
|
294 |
+
temp_img_feats = [img_feat] # 重新初始化 temp_img_feats
|
295 |
+
else:
|
296 |
+
temp_img_feats.append(img_feat)
|
297 |
+
if temp_img_feats: # 处理最后一个子列表
|
298 |
+
if sf_masks[-1] > 0:
|
299 |
+
temp_img_feats = self.slow_fast_pooling(temp_img_feats)
|
300 |
+
new_image_features.append(temp_img_feats)
|
301 |
+
|
302 |
+
output_features = []
|
303 |
+
for e in new_image_features:
|
304 |
+
output_features += e
|
305 |
+
|
306 |
+
return output_features
|
307 |
+
|
308 |
+
def prepare_inputs_labels_for_multimodal(
|
309 |
+
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, audios, sf_masks, shared_v_pid_stride=None
|
310 |
+
):
|
311 |
+
vision_tower = self.get_vision_tower()
|
312 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
313 |
+
if (
|
314 |
+
past_key_values is not None
|
315 |
+
and vision_tower is not None
|
316 |
+
and images is not None
|
317 |
+
and input_ids.shape[1] == 1
|
318 |
+
):
|
319 |
+
target_shape = past_key_values[-1][-1].shape[-2] + 1
|
320 |
+
attention_mask = torch.cat(
|
321 |
+
(
|
322 |
+
attention_mask,
|
323 |
+
torch.ones(
|
324 |
+
(attention_mask.shape[0], target_shape - attention_mask.shape[1]),
|
325 |
+
dtype=attention_mask.dtype,
|
326 |
+
device=attention_mask.device,
|
327 |
+
),
|
328 |
+
),
|
329 |
+
dim=1,
|
330 |
+
)
|
331 |
+
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
332 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
333 |
+
|
334 |
+
if type(images) is list or images.ndim == 5:
|
335 |
+
concat_images = torch.cat([image for image in images], dim=0)
|
336 |
+
image_features = self.encode_images(concat_images)
|
337 |
+
split_sizes = [image.shape[0] for image in images]
|
338 |
+
image_features = torch.split(image_features, split_sizes, dim=0)
|
339 |
+
image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
|
340 |
+
else:
|
341 |
+
image_features = self.encode_images(images).to(self.device)
|
342 |
+
|
343 |
+
image_features = [e for e in image_features]
|
344 |
+
if sf_masks is not None:
|
345 |
+
assert len(image_features) == len(sf_masks)
|
346 |
+
image_features = self.slow_fast(image_features, sf_masks)
|
347 |
+
|
348 |
+
audio_encoder = self.get_audio_encoder()
|
349 |
+
if audios is not None:
|
350 |
+
audio_features = audio_encoder(audios["audios"], audios["lengths"])
|
351 |
+
state_labels = audios.get("state_labels", None)
|
352 |
+
lengths_for_llm = audios["lengths_for_llm"]
|
353 |
+
if state_labels is not None:
|
354 |
+
assert len(audio_features["inputs_embeds"]) == len(state_labels) == len(lengths_for_llm)
|
355 |
+
else:
|
356 |
+
audio_features, state_labels, lengths_for_llm = None, None, None
|
357 |
+
|
358 |
+
# Let's just add dummy tensors if they do not exist,
|
359 |
+
# it is a headache to deal with None all the time.
|
360 |
+
# But it is not ideal, and if you have a better idea,
|
361 |
+
# please open an issue / submit a PR, thanks.
|
362 |
+
_labels = labels
|
363 |
+
_position_ids = position_ids
|
364 |
+
_attention_mask = attention_mask
|
365 |
+
if attention_mask is None:
|
366 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
367 |
+
else:
|
368 |
+
attention_mask = attention_mask.bool()
|
369 |
+
if position_ids is None:
|
370 |
+
position_ids = torch.arange(
|
371 |
+
0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
|
372 |
+
)
|
373 |
+
if labels is None:
|
374 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
375 |
+
|
376 |
+
# remove the padding using attention_mask -- TODO: double check
|
377 |
+
input_ids = [
|
378 |
+
cur_input_ids[cur_attention_mask]
|
379 |
+
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
|
380 |
+
]
|
381 |
+
labels = [
|
382 |
+
cur_labels[cur_attention_mask]
|
383 |
+
for cur_labels, cur_attention_mask in zip(labels, attention_mask)
|
384 |
+
]
|
385 |
+
|
386 |
+
new_input_embeds = []
|
387 |
+
new_labels = []
|
388 |
+
v_start_end = []
|
389 |
+
cur_image_idx = 0
|
390 |
+
cur_audio_idx = 0
|
391 |
+
assert (
|
392 |
+
sum([(cur == IMAGE_TOKEN_INDEX).sum() for cur in input_ids])
|
393 |
+
+ sum([(IMAGE_TOKEN_INDEX not in cur) for cur in input_ids])
|
394 |
+
== len(image_features)
|
395 |
+
), input_ids
|
396 |
+
assert (
|
397 |
+
sum([(cur == AUDIO_TOKEN_INDEX).sum() for cur in input_ids])
|
398 |
+
+ sum([(AUDIO_TOKEN_INDEX not in cur) for cur in input_ids])
|
399 |
+
== audio_features["inputs_embeds"].shape[0]
|
400 |
+
), input_ids
|
401 |
+
|
402 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
403 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
404 |
+
num_audio_frames = (cur_input_ids == AUDIO_TOKEN_INDEX).sum()
|
405 |
+
if num_images == 0 and num_audio_frames == 0:
|
406 |
+
cur_image_features = image_features[cur_image_idx]
|
407 |
+
cur_audio_features = audio_features["inputs_embeds"][cur_audio_idx]
|
408 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
409 |
+
cur_input_embeds = torch.cat(
|
410 |
+
[cur_input_embeds_1, cur_image_features[0:0], cur_audio_features[0:0]], dim=0
|
411 |
+
)
|
412 |
+
new_input_embeds.append(cur_input_embeds)
|
413 |
+
new_labels.append(labels[batch_idx])
|
414 |
+
cur_image_idx += 1
|
415 |
+
cur_audio_idx += 1
|
416 |
+
continue
|
417 |
+
|
418 |
+
image_audio_token_indices = (
|
419 |
+
[-1]
|
420 |
+
+ torch.where(
|
421 |
+
(cur_input_ids == IMAGE_TOKEN_INDEX) | (cur_input_ids == AUDIO_TOKEN_INDEX)
|
422 |
+
)[0].tolist()
|
423 |
+
+ [cur_input_ids.shape[0]]
|
424 |
+
)
|
425 |
+
cur_input_ids_noim_noau = []
|
426 |
+
cur_labels = labels[batch_idx]
|
427 |
+
cur_labels_noim_noau = []
|
428 |
+
for i in range(len(image_audio_token_indices) - 1):
|
429 |
+
cur_input_ids_noim_noau.append(
|
430 |
+
cur_input_ids[
|
431 |
+
image_audio_token_indices[i] + 1 : image_audio_token_indices[i + 1]
|
432 |
+
]
|
433 |
+
)
|
434 |
+
cur_labels_noim_noau.append(
|
435 |
+
cur_labels[image_audio_token_indices[i] + 1 : image_audio_token_indices[i + 1]]
|
436 |
+
)
|
437 |
+
|
438 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim_noau]
|
439 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim_noau))
|
440 |
+
cur_input_embeds_no_im_no_au = torch.split(cur_input_embeds, split_sizes, dim=0)
|
441 |
+
cur_new_input_embeds = []
|
442 |
+
cur_new_labels = []
|
443 |
+
cur_v_start_end = []
|
444 |
+
for i in range(num_images + num_audio_frames + 1):
|
445 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im_no_au[i])
|
446 |
+
cur_new_labels.append(cur_labels_noim_noau[i])
|
447 |
+
if i < num_images + num_audio_frames:
|
448 |
+
if cur_input_ids[image_audio_token_indices[i + 1]] == IMAGE_TOKEN_INDEX:
|
449 |
+
cur_image_features = image_features[cur_image_idx]
|
450 |
+
cur_image_idx += 1
|
451 |
+
cur_new_input_embeds.append(cur_image_features)
|
452 |
+
cur_new_labels.append(
|
453 |
+
torch.full(
|
454 |
+
(cur_image_features.shape[0],),
|
455 |
+
IGNORE_INDEX,
|
456 |
+
device=cur_labels.device,
|
457 |
+
dtype=cur_labels.dtype,
|
458 |
+
)
|
459 |
+
)
|
460 |
+
if shared_v_pid_stride:
|
461 |
+
start = sum([x.shape[0] for x in cur_new_labels[:-1]])
|
462 |
+
end = start + cur_new_labels[-1].shape[0]
|
463 |
+
cur_v_start_end.append((start, end))
|
464 |
+
elif cur_input_ids[image_audio_token_indices[i + 1]] == AUDIO_TOKEN_INDEX:
|
465 |
+
cur_lengths_for_llm = lengths_for_llm[cur_audio_idx]
|
466 |
+
cur_audio_features = audio_features["inputs_embeds"][cur_audio_idx]
|
467 |
+
if getattr(self.config, "audio_prompt_num", None):#self.config.audio_prompt_num:
|
468 |
+
cur_lengths_for_llm = cur_lengths_for_llm + self.config.audio_prompt_num
|
469 |
+
cur_audio_features = cur_audio_features[:cur_lengths_for_llm]
|
470 |
+
if state_labels is not None:
|
471 |
+
cur_state_label = state_labels[cur_audio_idx]
|
472 |
+
cur_audio_idx += 1
|
473 |
+
cur_new_input_embeds.append(cur_audio_features)
|
474 |
+
cur_new_labels.append(
|
475 |
+
torch.full(
|
476 |
+
(cur_audio_features.shape[0],),
|
477 |
+
IGNORE_INDEX,
|
478 |
+
device=cur_labels.device,
|
479 |
+
dtype=cur_labels.dtype,
|
480 |
+
)
|
481 |
+
)
|
482 |
+
if state_labels is not None:
|
483 |
+
cur_new_labels[-1][-1] = cur_state_label
|
484 |
+
else:
|
485 |
+
raise ValueError
|
486 |
+
|
487 |
+
if num_images != 0 and num_audio_frames == 0:
|
488 |
+
cur_audio_features = audio_features["inputs_embeds"][cur_audio_idx]
|
489 |
+
cur_audio_idx += 1
|
490 |
+
cur_new_input_embeds.append(cur_audio_features[0:0])
|
491 |
+
elif num_images == 0 and num_audio_frames != 0:
|
492 |
+
cur_image_features = image_features[cur_image_idx]
|
493 |
+
cur_image_idx += 1
|
494 |
+
cur_new_input_embeds.append(cur_image_features[0:0])
|
495 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
496 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
497 |
+
|
498 |
+
new_input_embeds.append(cur_new_input_embeds)
|
499 |
+
new_labels.append(cur_new_labels)
|
500 |
+
|
501 |
+
if shared_v_pid_stride:
|
502 |
+
cur_v_start_end = merge_consecutive_tuples(cur_v_start_end)
|
503 |
+
v_start_end.append(cur_v_start_end)
|
504 |
+
|
505 |
+
assert cur_image_idx == len(image_features)
|
506 |
+
assert cur_audio_idx == audio_features["inputs_embeds"].shape[0]
|
507 |
+
if state_labels is not None:
|
508 |
+
assert cur_audio_idx == len(state_labels)
|
509 |
+
if state_labels is not None:
|
510 |
+
assert (
|
511 |
+
sum([(cur == AUDIO_TOKEN_INDEX).sum() for cur in input_ids])
|
512 |
+
== sum([(cur == -101).sum() for cur in new_labels]) + sum([(cur == -102).sum() for cur in new_labels])
|
513 |
+
), (input_ids, sum([(cur == AUDIO_TOKEN_INDEX).sum() for cur in input_ids]), sum([(cur == -101).sum() for cur in new_labels]), sum([(cur == -102).sum() for cur in new_labels]), new_labels.shape)
|
514 |
+
|
515 |
+
# Truncate sequences to max length as image embeddings can make the sequence longer
|
516 |
+
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
|
517 |
+
if tokenizer_model_max_length is not None:
|
518 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
519 |
+
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
520 |
+
|
521 |
+
# Combine them
|
522 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
523 |
+
batch_size = len(new_input_embeds)
|
524 |
+
|
525 |
+
new_input_embeds_padded = []
|
526 |
+
new_labels_padded = torch.full(
|
527 |
+
(batch_size, max_len),
|
528 |
+
IGNORE_INDEX,
|
529 |
+
dtype=new_labels[0].dtype,
|
530 |
+
device=new_labels[0].device,
|
531 |
+
)
|
532 |
+
attention_mask = torch.zeros(
|
533 |
+
(batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
|
534 |
+
)
|
535 |
+
position_ids = torch.zeros(
|
536 |
+
(batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
|
537 |
+
)
|
538 |
+
|
539 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
540 |
+
cur_len = cur_new_embed.shape[0]
|
541 |
+
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
|
542 |
+
new_input_embeds_padded.append(
|
543 |
+
torch.cat(
|
544 |
+
(
|
545 |
+
torch.zeros(
|
546 |
+
(max_len - cur_len, cur_new_embed.shape[1]),
|
547 |
+
dtype=cur_new_embed.dtype,
|
548 |
+
device=cur_new_embed.device,
|
549 |
+
),
|
550 |
+
cur_new_embed,
|
551 |
+
),
|
552 |
+
dim=0,
|
553 |
+
)
|
554 |
+
)
|
555 |
+
if cur_len > 0:
|
556 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
557 |
+
attention_mask[i, -cur_len:] = True
|
558 |
+
position_ids[i, -cur_len:] = torch.arange(
|
559 |
+
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
|
560 |
+
)
|
561 |
+
else:
|
562 |
+
new_input_embeds_padded.append(
|
563 |
+
torch.cat(
|
564 |
+
(
|
565 |
+
cur_new_embed,
|
566 |
+
torch.zeros(
|
567 |
+
(max_len - cur_len, cur_new_embed.shape[1]),
|
568 |
+
dtype=cur_new_embed.dtype,
|
569 |
+
device=cur_new_embed.device,
|
570 |
+
),
|
571 |
+
),
|
572 |
+
dim=0,
|
573 |
+
)
|
574 |
+
)
|
575 |
+
if cur_len > 0:
|
576 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
577 |
+
attention_mask[i, :cur_len] = True
|
578 |
+
if shared_v_pid_stride is None:
|
579 |
+
position_ids[i, :cur_len] = torch.arange(
|
580 |
+
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
|
581 |
+
)
|
582 |
+
else:
|
583 |
+
cur_v_start_end = v_start_end[i]
|
584 |
+
cur_shared_position_ids = make_shared_position_ids(cur_v_start_end, cur_len, shared_v_pid_stride)
|
585 |
+
position_ids[i, :cur_len] = cur_shared_position_ids
|
586 |
+
|
587 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
588 |
+
|
589 |
+
if _labels is None:
|
590 |
+
new_labels = None
|
591 |
+
else:
|
592 |
+
new_labels = new_labels_padded
|
593 |
+
|
594 |
+
if _attention_mask is None:
|
595 |
+
attention_mask = None
|
596 |
+
else:
|
597 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
598 |
+
|
599 |
+
if _position_ids is None and shared_v_pid_stride is None:
|
600 |
+
position_ids = None
|
601 |
+
|
602 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
603 |
+
|
604 |
+
|
605 |
+
def merge_consecutive_tuples(tuples_list):
|
606 |
+
if not tuples_list:
|
607 |
+
return []
|
608 |
+
|
609 |
+
# 首先对列表按照起点索引进行排序
|
610 |
+
sorted_tuples = sorted(tuples_list, key=lambda x: x[0])
|
611 |
+
|
612 |
+
# 初始化合并后的列表
|
613 |
+
merged_tuples = [sorted_tuples[0]]
|
614 |
+
|
615 |
+
for current_start, current_end in sorted_tuples[1:]:
|
616 |
+
last_merged_start, last_merged_end = merged_tuples[-1]
|
617 |
+
if current_start <= last_merged_end: # 如果当前元组的起点小于等于上一个合并元组的终点
|
618 |
+
# 合并这两个元组
|
619 |
+
new_start, new_end = merged_tuples[-1][0], max(last_merged_end, current_end)
|
620 |
+
merged_tuples[-1] = (new_start, new_end)
|
621 |
+
else:
|
622 |
+
# 如果当前元组不连续,直接添加到合并后的列表中
|
623 |
+
merged_tuples.append((current_start, current_end))
|
624 |
+
|
625 |
+
return merged_tuples
|
626 |
+
|
627 |
+
|
628 |
+
def make_shared_position_ids(cur_v_start_end, cur_len, shared_v_pid_stride):
|
629 |
+
position_ids = torch.tensor([1.0] * cur_len)
|
630 |
+
|
631 |
+
for start, end in cur_v_start_end:
|
632 |
+
position_ids[start:end] = 1/shared_v_pid_stride
|
633 |
+
v_mod = (end - start) % shared_v_pid_stride
|
634 |
+
if v_mod != 0:
|
635 |
+
position_ids[end-v_mod:end] = 1 / v_mod
|
636 |
+
position_ids = position_ids.cumsum(dim=0)
|
637 |
+
position_ids = torch.ceil(position_ids).long() - 1
|
638 |
+
|
639 |
+
return position_ids
|
vita/model/vita_tts/adapter.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
import copy
|
4 |
+
import re
|
5 |
+
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn.utils.rnn import pad_sequence
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
class CNNAdapter(torch.nn.Module):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
enc_out_dim: int = 512,
|
14 |
+
llm_embed_dim: int = 4096,
|
15 |
+
kernel_size: int = 5,
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
|
20 |
+
self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
|
21 |
+
|
22 |
+
self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
|
23 |
+
self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 1, 0)
|
24 |
+
|
25 |
+
self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
|
26 |
+
self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
|
27 |
+
|
28 |
+
self.relu1 = nn.ReLU()
|
29 |
+
self.relu2 = nn.ReLU()
|
30 |
+
|
31 |
+
self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
|
32 |
+
|
33 |
+
def forward(self, x, mask_pad):
|
34 |
+
"""
|
35 |
+
x: B, T, enc_out_dim
|
36 |
+
mask: (B, T) or (B, 1, T)
|
37 |
+
"""
|
38 |
+
x = x.transpose(1, 2) # B, channels, T
|
39 |
+
|
40 |
+
# mask batch padding
|
41 |
+
if mask_pad.size(2) > 0: # time > 0
|
42 |
+
x.masked_fill_(~mask_pad, 0.0)
|
43 |
+
|
44 |
+
x = self.left_padding1(x)
|
45 |
+
x = self.conv1d1(x)
|
46 |
+
x = self.bn1(x)
|
47 |
+
x = self.relu1(x)
|
48 |
+
|
49 |
+
x = self.left_padding2(x)
|
50 |
+
x = self.conv1d2(x)
|
51 |
+
x = self.bn2(x)
|
52 |
+
x = self.relu2(x)
|
53 |
+
|
54 |
+
x = x.transpose(1, 2)
|
55 |
+
x = self.project(x)
|
56 |
+
|
57 |
+
return x, mask_pad
|
58 |
+
|
59 |
+
class LinearAdapter(torch.nn.Module):
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
enc_out_dim: int = 512,
|
63 |
+
llm_embed_dim: int = 4096,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
|
67 |
+
self.adpter = torch.nn.Linear(enc_out_dim, llm_embed_dim)
|
68 |
+
|
69 |
+
def forward(self, x, mask_pad):
|
70 |
+
return self.adpter(x), mask_pad
|
71 |
+
|
72 |
+
class CNNSubsampling(torch.nn.Module):
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
enc_out_dim: int = 512,
|
76 |
+
llm_embed_dim: int = 4096,
|
77 |
+
kernel_size: int = 5,
|
78 |
+
activation_func: str = 'relu',
|
79 |
+
norm: str = 'batch',
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
self.kernel_size = kernel_size
|
84 |
+
if enc_out_dim * 4 < llm_embed_dim:
|
85 |
+
self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
|
86 |
+
self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
|
87 |
+
self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
|
88 |
+
self.relu1 = nn.ReLU()
|
89 |
+
|
90 |
+
self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
|
91 |
+
self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 2, 0)
|
92 |
+
self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
|
93 |
+
self.relu2 = nn.ReLU()
|
94 |
+
|
95 |
+
self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
|
96 |
+
self.cnn_num = 2
|
97 |
+
else:
|
98 |
+
self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
|
99 |
+
self.conv1d2 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 2, 0)
|
100 |
+
if norm == 'batch':
|
101 |
+
self.bn2 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
|
102 |
+
elif norm == 'layer':
|
103 |
+
self.bn2 = nn.LayerNorm(2 * enc_out_dim, eps=1e-3)
|
104 |
+
if activation_func == 'gelu':
|
105 |
+
self.relu2 = nn.GELU()
|
106 |
+
else:
|
107 |
+
self.relu2 = nn.ReLU()
|
108 |
+
|
109 |
+
self.project = nn.Linear(2 * enc_out_dim, llm_embed_dim)
|
110 |
+
self.cnn_num = 1
|
111 |
+
|
112 |
+
def forward(self, x, mask_pad, cache=None, return_cache=False):
|
113 |
+
"""
|
114 |
+
x: B, T, enc_out_dim
|
115 |
+
mask: (B, T) or (B, 1, T)
|
116 |
+
"""
|
117 |
+
x = x.transpose(1, 2) # B, channels, T
|
118 |
+
|
119 |
+
# mask batch padding
|
120 |
+
if mask_pad.size(2) > 0: # time > 0
|
121 |
+
x.masked_fill_(~mask_pad, 0.0)
|
122 |
+
|
123 |
+
if self.cnn_num == 2:
|
124 |
+
if cache is None:
|
125 |
+
x = self.left_padding1(x)
|
126 |
+
else:
|
127 |
+
x = torch.cat((cache[1], x), dim=2)
|
128 |
+
if cache is not None:
|
129 |
+
cache[1] = x[:, :, 1-self.kernel_size:]
|
130 |
+
else:
|
131 |
+
cache = [None, x[:, :, 1-self.kernel_size:]]
|
132 |
+
x = self.conv1d1(x)
|
133 |
+
x = self.bn1(x)
|
134 |
+
x = self.relu1(x)
|
135 |
+
|
136 |
+
if cache is None or cache[0] is None:
|
137 |
+
x = self.left_padding2(x)
|
138 |
+
else:
|
139 |
+
x = torch.cat((cache[0], x), dim=2)
|
140 |
+
if cache is not None:
|
141 |
+
cache[0] = x[:, :, 1-self.kernel_size:]
|
142 |
+
else:
|
143 |
+
cache = [x[:, :, 1-self.kernel_size:]]
|
144 |
+
x = self.conv1d2(x)
|
145 |
+
if isinstance(self.bn2, nn.LayerNorm):
|
146 |
+
x = x.transpose(1, 2)
|
147 |
+
x = self.bn2(x)
|
148 |
+
if isinstance(self.bn2, nn.LayerNorm):
|
149 |
+
x = x.transpose(1, 2)
|
150 |
+
x = self.relu2(x)
|
151 |
+
|
152 |
+
x = x.transpose(1, 2)
|
153 |
+
x = self.project(x)
|
154 |
+
|
155 |
+
if return_cache:
|
156 |
+
return x, mask_pad[:, :, 0::2], cache
|
157 |
+
return x, mask_pad[:, :, 0::2]
|
vita/model/vita_tts/audioLLM.py
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
import copy
|
4 |
+
import re
|
5 |
+
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn.utils.rnn import pad_sequence
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from collections import defaultdict
|
11 |
+
from typing import Dict, List, Optional, Tuple
|
12 |
+
|
13 |
+
from transformers import AutoModelForCausalLM
|
14 |
+
from transformers import AutoTokenizer
|
15 |
+
|
16 |
+
from vita.model.vita_tts.adapter import *
|
17 |
+
|
18 |
+
IGNORE_ID = -1
|
19 |
+
|
20 |
+
class AudioLLM(torch.nn.Module):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
encoder: torch.nn.Module,
|
24 |
+
llm_path: str,
|
25 |
+
freeze_llm: bool = True,
|
26 |
+
enc_out_dim: int = 512,
|
27 |
+
llm_embed_dim: int = 4096,
|
28 |
+
kernel_size: int = 3,
|
29 |
+
IGNORE_ID: int = -100,
|
30 |
+
adpter_type: str = 'cnn',
|
31 |
+
add_audio_bos_eos: bool = False,
|
32 |
+
task_num: int = 10,
|
33 |
+
add_ctc_prompt_ratio: float = 0.0,
|
34 |
+
lang_dict: dict = None,
|
35 |
+
ctc: torch.nn.Module = None,
|
36 |
+
tokenize_ctc_char: bool = False,
|
37 |
+
task_before_audio: bool = False,
|
38 |
+
hyp_before_task: bool = False,
|
39 |
+
prompt_finetune: bool = False,
|
40 |
+
add_prompt_before: bool = False,
|
41 |
+
prompt_num: int = 5,
|
42 |
+
prefix_finetune: bool = False,
|
43 |
+
prefix_num: int = 5,
|
44 |
+
llm_head_num: int = 32,
|
45 |
+
num_key_value_heads: int = None,
|
46 |
+
task_type: str = 'prompt',
|
47 |
+
freeze_encoder: bool = False,
|
48 |
+
freeze_adpter: bool = False,
|
49 |
+
activation_func: str = 'relu',
|
50 |
+
norm: str = 'batch',
|
51 |
+
use_lora: bool = False,
|
52 |
+
clone_encoder: torch.nn.Module = None,
|
53 |
+
chat_template: str = None,
|
54 |
+
predict_usr_state: int = 0,
|
55 |
+
chunk_size: int = -1,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
|
59 |
+
self.encoder = encoder
|
60 |
+
self.llm_decoder = AutoModelForCausalLM.from_pretrained(llm_path,
|
61 |
+
torch_dtype="auto",
|
62 |
+
trust_remote_code=True)
|
63 |
+
self.tokenizer = AutoTokenizer.from_pretrained(llm_path,
|
64 |
+
trust_remote_code=True)
|
65 |
+
self.freeze_llm = freeze_llm
|
66 |
+
self.enc_out_dim = enc_out_dim
|
67 |
+
self.llm_embed_dim = llm_embed_dim
|
68 |
+
self.IGNORE_ID = IGNORE_ID
|
69 |
+
self.add_audio_bos_eos = add_audio_bos_eos
|
70 |
+
self.add_ctc_prompt_ratio = add_ctc_prompt_ratio
|
71 |
+
self.lang_dict = lang_dict
|
72 |
+
self.tokenize_ctc_char = tokenize_ctc_char
|
73 |
+
self.task_before_audio = task_before_audio
|
74 |
+
self.hyp_before_task = hyp_before_task
|
75 |
+
self.prompt_finetune = prompt_finetune
|
76 |
+
self.add_prompt_before = add_prompt_before
|
77 |
+
self.prompt_num = prompt_num
|
78 |
+
self.prefix_finetune = prefix_finetune
|
79 |
+
self.prefix_num = prefix_num
|
80 |
+
self.llm_head_num = llm_head_num
|
81 |
+
if num_key_value_heads is None:
|
82 |
+
self.num_key_value_heads = llm_head_num
|
83 |
+
else:
|
84 |
+
self.num_key_value_heads = num_key_value_heads
|
85 |
+
self.kv_cache_dim = llm_embed_dim // self.llm_head_num * self.num_key_value_heads
|
86 |
+
self.task_type = task_type
|
87 |
+
self.freeze_encoder = freeze_encoder
|
88 |
+
self.freeze_adpter = freeze_adpter
|
89 |
+
self.predict_usr_state = predict_usr_state
|
90 |
+
self.chunk_size = chunk_size
|
91 |
+
|
92 |
+
if not hasattr(self.tokenizer, "eod_id"):
|
93 |
+
self.tokenizer.eod_id = self.tokenizer.eos_token_id
|
94 |
+
if not hasattr(self.llm_decoder, "transformer"):
|
95 |
+
self.llm_decoder.transformer = self.llm_decoder.model
|
96 |
+
self.llm_decoder.transformer.h = self.llm_decoder.transformer.layers
|
97 |
+
if not hasattr(self.llm_decoder.transformer, "wte"):
|
98 |
+
self.llm_decoder.transformer.wte = \
|
99 |
+
self.llm_decoder.transformer.embed_tokens
|
100 |
+
|
101 |
+
# for chat mode
|
102 |
+
if chat_template is not None:
|
103 |
+
self.tokenizer.eod_id = self.tokenizer('<|im_end|>'
|
104 |
+
)['input_ids'][0]
|
105 |
+
self.chat_template = {}
|
106 |
+
chat_template = chat_template.split('<audio>')
|
107 |
+
chat_prefix = chat_template[0].split('<|im_end|>')
|
108 |
+
chat_role = chat_prefix[0] + '<|im_end|>'
|
109 |
+
self.chat_template['role'] = self.tokenizer(
|
110 |
+
[chat_role], return_tensors="pt")['input_ids']
|
111 |
+
self.chat_template['prefix'] = self.tokenizer(
|
112 |
+
[chat_prefix[1]], return_tensors="pt")['input_ids']
|
113 |
+
self.chat_template['suffix'] = self.tokenizer(
|
114 |
+
[chat_template[1]], return_tensors="pt")['input_ids']
|
115 |
+
else:
|
116 |
+
self.chat_template = None
|
117 |
+
|
118 |
+
# for CTC prompt
|
119 |
+
if self.add_ctc_prompt_ratio > 0.0:
|
120 |
+
assert lang_dict is not None
|
121 |
+
assert ctc is not None
|
122 |
+
self.ctc = ctc.eval()
|
123 |
+
if clone_encoder is None:
|
124 |
+
self.clone_encoder = copy.deepcopy(encoder)
|
125 |
+
else:
|
126 |
+
self.clone_encoder = clone_encoder
|
127 |
+
self.clone_encoder.eval()
|
128 |
+
for (name, param) in self.clone_encoder.named_parameters():
|
129 |
+
param.requires_grad = False
|
130 |
+
for (name, param) in self.ctc.named_parameters():
|
131 |
+
param.requires_grad = False
|
132 |
+
else:
|
133 |
+
self.clone_encoder = None
|
134 |
+
|
135 |
+
if self.freeze_llm:
|
136 |
+
self.llm_decoder.eval()
|
137 |
+
for (name, param) in self.llm_decoder.named_parameters():
|
138 |
+
param.requires_grad = False
|
139 |
+
|
140 |
+
if use_lora:
|
141 |
+
config = LoraConfig(
|
142 |
+
r=lora_r,
|
143 |
+
lora_alpha=lora_alpha,
|
144 |
+
target_modules=UNET_TARGET_MODULES,
|
145 |
+
lora_dropout=args.lora_dropout,
|
146 |
+
bias=args.lora_bias,
|
147 |
+
)
|
148 |
+
|
149 |
+
if adpter_type == 'cnn':
|
150 |
+
self.adpter = CNNAdapter(enc_out_dim, llm_embed_dim, kernel_size)
|
151 |
+
elif adpter_type == 'linear':
|
152 |
+
self.adpter = LinearAdapter(enc_out_dim, llm_embed_dim)
|
153 |
+
elif adpter_type == 'subsampling':
|
154 |
+
self.adpter = CNNSubsampling(enc_out_dim, llm_embed_dim,
|
155 |
+
kernel_size, activation_func, norm)
|
156 |
+
|
157 |
+
self.task_embeddings = torch.nn.Embedding(task_num, llm_embed_dim)
|
158 |
+
if task_type == 'prefix':
|
159 |
+
self.prefix_embeddings = nn.ModuleList(
|
160 |
+
[
|
161 |
+
torch.nn.ModuleList(
|
162 |
+
[nn.Embedding(task_num, self.kv_cache_dim),
|
163 |
+
nn.Embedding(task_num, self.kv_cache_dim)]
|
164 |
+
)
|
165 |
+
for i in range(len(self.llm_decoder.transformer.h))
|
166 |
+
]
|
167 |
+
)
|
168 |
+
|
169 |
+
if self.prompt_finetune or self.prefix_finetune:
|
170 |
+
if self.prompt_finetune:
|
171 |
+
self.prompt_embeddings = nn.Embedding(prompt_num, llm_embed_dim)
|
172 |
+
self.prompt_ids = torch.Tensor([i for i in range(prompt_num)]).long()
|
173 |
+
if self.prefix_finetune:
|
174 |
+
self.prefix_embeddings = nn.ModuleList(
|
175 |
+
[
|
176 |
+
torch.nn.ModuleList(
|
177 |
+
[nn.Embedding(prefix_num, self.kv_cache_dim),
|
178 |
+
nn.Embedding(prefix_num, self.kv_cache_dim)]
|
179 |
+
)
|
180 |
+
for i in range(len(self.llm_decoder.transformer.h))
|
181 |
+
]
|
182 |
+
)
|
183 |
+
self.prefix_ids = torch.Tensor([i for i in range(prefix_num)]).long()
|
184 |
+
|
185 |
+
if self.freeze_encoder:
|
186 |
+
self.encoder.eval()
|
187 |
+
for (name, param) in self.encoder.named_parameters():
|
188 |
+
param.requires_grad = False
|
189 |
+
if self.freeze_adpter:
|
190 |
+
self.adpter.eval()
|
191 |
+
for (name, param) in self.adpter.named_parameters():
|
192 |
+
param.requires_grad = False
|
193 |
+
|
194 |
+
if self.predict_usr_state:
|
195 |
+
self.predictor_head = torch.nn.Linear(llm_embed_dim, predict_usr_state)
|
196 |
+
else:
|
197 |
+
self.predictor_head = None
|
198 |
+
|
199 |
+
# define task ids
|
200 |
+
self.task_ids = {
|
201 |
+
"sot": 0,
|
202 |
+
"transcribe": 1,
|
203 |
+
"translate": 2,
|
204 |
+
"zh": 3,
|
205 |
+
"en": 4,
|
206 |
+
"audio": 5,
|
207 |
+
"/audio": 6,
|
208 |
+
"hyps": 7,
|
209 |
+
"/hyps": 8,
|
210 |
+
}
|
211 |
+
|
212 |
+
def set_system_role(
|
213 |
+
self,
|
214 |
+
extra_inputs: Optional[dict] = None,
|
215 |
+
):
|
216 |
+
# Ensure 'past_key_values' does not exist in extra_inputs, raise an exception if it does
|
217 |
+
assert extra_inputs.get('past_key_values', None) is None, "past key values already exist!!!"
|
218 |
+
|
219 |
+
# If 'role' key is present in extra_inputs, use that role as the chat prefix
|
220 |
+
if extra_inputs.get('role', None) is not None:
|
221 |
+
chat_prefix = self.tokenizer([extra_inputs['role']],
|
222 |
+
return_tensors="pt")['input_ids'].to('cuda') # Convert role to tokens and move to CUDA device
|
223 |
+
else:
|
224 |
+
# If no 'role' is provided, use the default chat template and remove the last token (<|im_end|>)
|
225 |
+
chat_prefix = self.chat_template['role'][:, :-1].to('cuda')
|
226 |
+
|
227 |
+
# Use the LLM decoder's word embedding layer to convert the chat prefix into embeddings
|
228 |
+
inputs_embeds = self.llm_decoder.transformer.wte(chat_prefix)
|
229 |
+
|
230 |
+
# Create an attention mask with the same shape as the chat prefix, all values set to True
|
231 |
+
attention_mask = torch.full(chat_prefix.shape,
|
232 |
+
True).to(inputs_embeds.device)
|
233 |
+
|
234 |
+
# Prepare the input dictionary containing embeddings and attention mask
|
235 |
+
inputs = {
|
236 |
+
'inputs_embeds': inputs_embeds.half(), # Convert embeddings to half precision floats
|
237 |
+
'attention_mask': attention_mask,
|
238 |
+
}
|
239 |
+
|
240 |
+
# Call the _generate_one_step method to generate one step output, including past_key_values, etc.
|
241 |
+
_, past_key_values, stat, _ = self._generate_one_step(
|
242 |
+
copy.deepcopy(inputs), "sl")
|
243 |
+
|
244 |
+
# Return the generated past_key_values
|
245 |
+
return past_key_values
|
246 |
+
|
247 |
+
def recognize(
|
248 |
+
self,
|
249 |
+
speech: torch.Tensor,
|
250 |
+
speech_lengths: torch.Tensor,
|
251 |
+
extra_inputs: Optional[dict] = None,
|
252 |
+
):
|
253 |
+
assert extra_inputs.get('past_key_values', None) is not None, "must set system role first!!!"
|
254 |
+
|
255 |
+
buffer = extra_inputs.get('encoder_cache', None)
|
256 |
+
cnn_cache = extra_inputs.get('adapter_cache', None)
|
257 |
+
pe_index = extra_inputs.get('pe_index', 0)
|
258 |
+
if extra_inputs['stat'] == 'sl' or extra_inputs['stat'] == 'cl':
|
259 |
+
# Encoder
|
260 |
+
|
261 |
+
if buffer is None:
|
262 |
+
buffer = [None] * self.encoder.enc[1].num_blocks
|
263 |
+
|
264 |
+
encoder_out, buffer, _, _, pe_index = self.encoder.infer(speech, buffer,
|
265 |
+
0, None, pe_index)
|
266 |
+
|
267 |
+
encoder_mask = torch.full(encoder_out.shape[:2], True).unsqueeze(1
|
268 |
+
).to(encoder_out.device)
|
269 |
+
|
270 |
+
# adapter
|
271 |
+
inputs_embeds, encoder_mask, cnn_cache = self.adpter(encoder_out, encoder_mask,
|
272 |
+
cache=cnn_cache, return_cache=True) # 1, T, D
|
273 |
+
|
274 |
+
attention_mask = encoder_mask.squeeze(1) # 1, T
|
275 |
+
|
276 |
+
# prompt
|
277 |
+
if extra_inputs['stat'] == 'sl':
|
278 |
+
if self.prompt_finetune:
|
279 |
+
prompt_ids = self.prompt_ids.repeat(1, 1).to(inputs_embeds.device)
|
280 |
+
prompt_embeds = self.prompt_embeddings(
|
281 |
+
prompt_ids.to(inputs_embeds.device)) # B, 5, D
|
282 |
+
prompt_mask = torch.full(prompt_ids.shape,
|
283 |
+
True).to(inputs_embeds.device) # B, 5
|
284 |
+
|
285 |
+
if self.add_prompt_before:
|
286 |
+
inputs_embeds = torch.cat((prompt_embeds, inputs_embeds), 1) # B, (T+5), D
|
287 |
+
attention_mask = torch.cat((prompt_mask, attention_mask), 1) # B, (T+5)
|
288 |
+
|
289 |
+
# chat mode
|
290 |
+
if self.chat_template is not None:
|
291 |
+
if extra_inputs['stat'] == 'sl':
|
292 |
+
chat_prefix = self.chat_template['prefix'].to(
|
293 |
+
inputs_embeds.device)
|
294 |
+
chat_prefix = torch.cat((torch.tensor([[self.tokenizer.eod_id]]
|
295 |
+
).to(inputs_embeds.device), chat_prefix), 1)
|
296 |
+
chat_prefix_embeds = self.llm_decoder.transformer.wte(chat_prefix)
|
297 |
+
chat_prefix_mask = torch.full(chat_prefix.shape,
|
298 |
+
True).to(inputs_embeds.device)
|
299 |
+
inputs_embeds = torch.cat((chat_prefix_embeds, inputs_embeds), 1)
|
300 |
+
attention_mask = torch.cat((chat_prefix_mask, attention_mask), 1)
|
301 |
+
if extra_inputs['stat'] == 'ss':
|
302 |
+
chat_suffix = self.chat_template['suffix'].to('cuda')
|
303 |
+
chat_suffix_embeds = self.llm_decoder.transformer.wte(chat_suffix)
|
304 |
+
chat_suffix_mask = torch.full(chat_suffix.shape, True).to('cuda')
|
305 |
+
inputs_embeds = chat_suffix_embeds
|
306 |
+
attention_mask = chat_suffix_mask
|
307 |
+
|
308 |
+
if extra_inputs['stat'] != 'cs':
|
309 |
+
inputs = {
|
310 |
+
'inputs_embeds': inputs_embeds.half(),
|
311 |
+
'attention_mask': attention_mask,
|
312 |
+
}
|
313 |
+
else:
|
314 |
+
attention_mask = torch.full([1, 1], True).to('cuda')
|
315 |
+
inputs = {
|
316 |
+
'input_ids': extra_inputs['last_id'],
|
317 |
+
'attention_mask': attention_mask
|
318 |
+
}
|
319 |
+
|
320 |
+
# add kv cache
|
321 |
+
inputs['past_key_values'] = extra_inputs['past_key_values']
|
322 |
+
past_mask = torch.full([1, inputs['past_key_values'][0][0].size(2)],
|
323 |
+
True).to('cuda')
|
324 |
+
attention_mask = torch.cat((past_mask, attention_mask), 1)
|
325 |
+
inputs['attention_mask'] = attention_mask
|
326 |
+
|
327 |
+
top_p = extra_inputs.get('top_p', 1.0)
|
328 |
+
top_k = extra_inputs.get('top_k', 0)
|
329 |
+
temperature = extra_inputs.get('temperature', 1.0)
|
330 |
+
|
331 |
+
last_id, past_key_values, stat, hidden_state = self._generate_one_step(copy.deepcopy(inputs),
|
332 |
+
extra_inputs['stat'],
|
333 |
+
top_p=top_p,
|
334 |
+
top_k=top_k,
|
335 |
+
temperature=temperature)
|
336 |
+
|
337 |
+
return last_id, stat, past_key_values, cnn_cache, buffer, pe_index, hidden_state
|
338 |
+
|
339 |
+
def _post_decode(self, output, temperature=1.0, top_k=0, top_p=0.0):
|
340 |
+
"""
|
341 |
+
Decoding function, based on the posterior probability output,
|
342 |
+
uses top_k, top_p, and temperature parameters for sampling.
|
343 |
+
|
344 |
+
Parameters:
|
345 |
+
- output: torch.Tensor, shaped as (1, 1, D), represents the posterior probability output by the model.
|
346 |
+
- top_k: int, indicates selecting the top k tokens with the highest probability for sampling.
|
347 |
+
If 0, no top_k filtering is performed.
|
348 |
+
- top_p: float, indicates selecting tokens with cumulative probability not exceeding p for sampling.
|
349 |
+
If 0.0, no top_p filtering is performed.
|
350 |
+
- temperature: float, represents the sampling temperature parameter.
|
351 |
+
The higher the value, the more random the sampling;
|
352 |
+
the lower the value, the more deterministic the sampling.
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
- Selected token index.
|
356 |
+
"""
|
357 |
+
output = output.squeeze(0).squeeze(0)
|
358 |
+
|
359 |
+
# temperature
|
360 |
+
if temperature != 1.0:
|
361 |
+
output = output / temperature
|
362 |
+
|
363 |
+
probs = torch.nn.functional.softmax(output, dim=-1)
|
364 |
+
|
365 |
+
# top_k
|
366 |
+
if top_k > 0:
|
367 |
+
top_k_probs, top_k_indices = torch.topk(probs, top_k)
|
368 |
+
probs = torch.zeros_like(probs).scatter_(0, top_k_indices, top_k_probs)
|
369 |
+
probs = probs / probs.sum()
|
370 |
+
|
371 |
+
# top_p
|
372 |
+
if top_p > 0.0:
|
373 |
+
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
|
374 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
375 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
376 |
+
if sorted_indices_to_remove[0]:
|
377 |
+
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
|
378 |
+
sorted_indices_to_remove[0] = 0
|
379 |
+
|
380 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
381 |
+
probs[indices_to_remove] = 0
|
382 |
+
probs = probs / probs.sum()
|
383 |
+
|
384 |
+
token_index = torch.multinomial(probs, 1)
|
385 |
+
return token_index.unsqueeze(0)
|
386 |
+
|
387 |
+
def _generate_one_step(
|
388 |
+
self,
|
389 |
+
inputs,
|
390 |
+
stat,
|
391 |
+
top_p: float = 1.0,
|
392 |
+
top_k: int = 0,
|
393 |
+
temperature: float = 1.0,
|
394 |
+
):
|
395 |
+
"""
|
396 |
+
Generates the model's next output based on the current input and state.
|
397 |
+
|
398 |
+
Parameters:
|
399 |
+
- inputs: The input tensor containing the model's input data.
|
400 |
+
- stat: The current state information used to control the generation process.
|
401 |
+
- top_p: The threshold for controlling top-p sampling.
|
402 |
+
- top_k: The threshold for controlling top-k sampling.
|
403 |
+
- temperature: Controls the randomness of sampling.
|
404 |
+
|
405 |
+
Returns:
|
406 |
+
- last_id: The index of the last generated token.
|
407 |
+
- stat: The updated state information.
|
408 |
+
- past_key_values: The model's historical key-value pairs, used for cross-step memory.
|
409 |
+
- hidden_state: The model's hidden state, used to maintain cross-step contextual information.
|
410 |
+
"""
|
411 |
+
outputs = self.llm_decoder.model(**inputs)
|
412 |
+
if stat == 'sl' or stat == 'cl':
|
413 |
+
state_logits = self.predictor_head(
|
414 |
+
outputs['last_hidden_state'])[0, :]
|
415 |
+
prob = F.softmax(state_logits[:, :-1])
|
416 |
+
state_prob = prob[-1].clone()
|
417 |
+
state_1 = state_prob[1]
|
418 |
+
state_2 = state_prob[2]
|
419 |
+
print("State 1 prob: {:.4f}, State 2 prob: {:.4f}".format(state_1.item(), state_2.item()))
|
420 |
+
if state_2 > 0.5:
|
421 |
+
return None, outputs['past_key_values'], 'el', None
|
422 |
+
if state_1 > 0.5:
|
423 |
+
return None, outputs['past_key_values'], 'ss', None
|
424 |
+
return None, outputs['past_key_values'], 'cl', None
|
425 |
+
|
426 |
+
last_logit = self.llm_decoder.lm_head(outputs['last_hidden_state'][:, -1:, :])
|
427 |
+
last_id = self._post_decode(last_logit, temperature=temperature, top_k=top_k, top_p=top_p)
|
428 |
+
return_tts_state = outputs['last_hidden_state'][:, -1:, :]
|
429 |
+
|
430 |
+
if last_id[0][0] == self.tokenizer.eod_id:
|
431 |
+
return None, outputs['past_key_values'], 'sl', return_tts_state
|
432 |
+
else:
|
433 |
+
return last_id, outputs['past_key_values'], 'cs', return_tts_state
|
vita/model/vita_tts/decoder/decoder.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from typing import Dict, List, Tuple, Optional, Union
|
7 |
+
from transformers import LlamaConfig
|
8 |
+
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding
|
9 |
+
from transformers.cache_utils import DynamicCache
|
10 |
+
|
11 |
+
from vita.model.vita_tts.encoder.encoder import add_encoder_args
|
12 |
+
from vita.model.vita_tts.masks import *
|
13 |
+
|
14 |
+
IGNORE_ID = -1
|
15 |
+
|
16 |
+
class CrossEntropyLoss(torch.nn.Module):
|
17 |
+
def __init__(self, ignore_index=-1):
|
18 |
+
super(CrossEntropyLoss, self).__init__()
|
19 |
+
self.criterion = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=ignore_index)
|
20 |
+
|
21 |
+
def forward(self, logits, target, target_subsampling_factor=1):
|
22 |
+
"""
|
23 |
+
logits: B*T1*D
|
24 |
+
target: B*T2
|
25 |
+
"""
|
26 |
+
logits = logits[:, :target.shape[1], :]
|
27 |
+
logits = logits.transpose(1, 2)
|
28 |
+
target = target.to(torch.long)
|
29 |
+
loss = self.criterion(logits, target)
|
30 |
+
return loss
|
31 |
+
|
32 |
+
class LLM2TTSCodecAR(torch.nn.Module):
|
33 |
+
"""E2E module.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
idim (int): dimension of inputs
|
37 |
+
odim (int): dimension of outputs
|
38 |
+
args (namespace): argument Namespace containing options
|
39 |
+
|
40 |
+
"""
|
41 |
+
|
42 |
+
@staticmethod
|
43 |
+
def add_arguments(parser):
|
44 |
+
"""Extend arguments for transducer."""
|
45 |
+
group = parser.add_argument_group("TDNN model setting")
|
46 |
+
|
47 |
+
group.add_argument('--encoder-pre-norm-type',
|
48 |
+
default='ln', type=str, help="Type of input norm.")
|
49 |
+
group.add_argument('--encoder-drop-rate', default=0.0,
|
50 |
+
type=float, help="Dropout rate for output.")
|
51 |
+
group.add_argument('--encoder-criterion', default='cross-entropy',
|
52 |
+
type=str, help="Criterion for output")
|
53 |
+
group.add_argument('--encoder-upsample-rate', default=1, type=int)
|
54 |
+
group.add_argument('--kv-cache-prefix-finetune', default=0, type=int)
|
55 |
+
|
56 |
+
group = add_encoder_args(group)
|
57 |
+
|
58 |
+
return parser
|
59 |
+
|
60 |
+
def __init__(self, idim, odim, args):
|
61 |
+
"""Initialize transducer modules.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
idim (int): dimension of inputs
|
65 |
+
odim (int): dimension of outputs
|
66 |
+
args (Namespace): argument Namespace containing options
|
67 |
+
|
68 |
+
"""
|
69 |
+
super(LLM2TTSCodecAR, self).__init__()
|
70 |
+
self.idim = args.idim
|
71 |
+
self.odim = args.odim
|
72 |
+
self.encoder_pre_norm_type = args.encoder_pre_norm_type
|
73 |
+
self.encoder_drop_rate = args.encoder_drop_rate
|
74 |
+
self.encoder_criterion = args.encoder_criterion
|
75 |
+
self.encoder_upsample_rate = args.encoder_upsample_rate
|
76 |
+
self.reporter = None
|
77 |
+
|
78 |
+
self.vocab_size = self.odim
|
79 |
+
config = LlamaConfig(vocab_size=self.vocab_size + 4, hidden_size=args.transformer_attention_dim,
|
80 |
+
intermediate_size=args.transformer_linear_units,
|
81 |
+
num_hidden_layers=args.transformer_num_blocks,
|
82 |
+
num_attention_heads=args.transformer_attention_heads, max_position_embeddings=2048,
|
83 |
+
bos_token_id=self.vocab_size + 1,
|
84 |
+
eos_token_id=self.vocab_size + 2, pad_token_id=self.vocab_size + 3,
|
85 |
+
attention_dropout=args.transformer_dropout_rate)
|
86 |
+
|
87 |
+
self.embedding = nn.Embedding(self.vocab_size + 4, self.idim, padding_idx=self.vocab_size + 3)
|
88 |
+
self.init_pre_nn(config)
|
89 |
+
|
90 |
+
self.layers = nn.ModuleList(
|
91 |
+
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
92 |
+
)
|
93 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
94 |
+
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
95 |
+
|
96 |
+
self.dropout = nn.Dropout(p=self.encoder_drop_rate)
|
97 |
+
self.out_fnn = nn.Linear(args.encoder_output_dim, self.vocab_size + 4)
|
98 |
+
|
99 |
+
self.kv_cache_prefix_finetune = args.kv_cache_prefix_finetune
|
100 |
+
if self.kv_cache_prefix_finetune:
|
101 |
+
self.init_kv_cache_prefix(config)
|
102 |
+
self.embedding.eval()
|
103 |
+
self.layers.eval()
|
104 |
+
self.norm.eval()
|
105 |
+
self.rotary_emb.eval()
|
106 |
+
self.out_fnn.eval()
|
107 |
+
for (name, param) in self.embedding.named_parameters():
|
108 |
+
param.requires_grad = False
|
109 |
+
for (name, param) in self.layers.named_parameters():
|
110 |
+
param.requires_grad = False
|
111 |
+
for (name, param) in self.norm.named_parameters():
|
112 |
+
param.requires_grad = False
|
113 |
+
for (name, param) in self.rotary_emb.named_parameters():
|
114 |
+
param.requires_grad = False
|
115 |
+
for (name, param) in self.out_fnn.named_parameters():
|
116 |
+
param.requires_grad = False
|
117 |
+
|
118 |
+
if self.encoder_criterion == 'ce':
|
119 |
+
self.criterion = CrossEntropyLoss(ignore_index=self.vocab_size + 3)
|
120 |
+
|
121 |
+
def init_kv_cache_prefix(self, config):
|
122 |
+
self.layers_prefix = nn.ModuleList(
|
123 |
+
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
124 |
+
)
|
125 |
+
self.rotary_emb_prefix = LlamaRotaryEmbedding(config=config)
|
126 |
+
|
127 |
+
def kv_cache_prefix_forward(self, prefix, prefix_lens, past_key_values):
|
128 |
+
inputs_embeds = prefix
|
129 |
+
past_seen_tokens = 0
|
130 |
+
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + \
|
131 |
+
inputs_embeds.shape[1], device=inputs_embeds.device)
|
132 |
+
position_ids = cache_position.unsqueeze(0)
|
133 |
+
hidden_states = inputs_embeds
|
134 |
+
position_embeddings = self.rotary_emb_prefix(hidden_states, position_ids)
|
135 |
+
next_decoder_cache = None
|
136 |
+
batch_size, max_len, _ = prefix.size()
|
137 |
+
input_mask = torch.zeros(batch_size, max_len, max_len, dtype=torch.bool, device=prefix.device)
|
138 |
+
for i in range(batch_size):
|
139 |
+
input_mask[i, :prefix_lens[i], :prefix_lens[i]] = True
|
140 |
+
attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
|
141 |
+
for decoder_layer in self.layers_prefix:
|
142 |
+
layer_outputs = decoder_layer(
|
143 |
+
hidden_states,
|
144 |
+
attention_mask=attention_mask,
|
145 |
+
position_ids=position_ids,
|
146 |
+
past_key_value=past_key_values,
|
147 |
+
output_attentions=False,
|
148 |
+
use_cache=True,
|
149 |
+
cache_position=None,
|
150 |
+
position_embeddings=position_embeddings,
|
151 |
+
)
|
152 |
+
hidden_states = layer_outputs[0]
|
153 |
+
next_decoder_cache = layer_outputs[1]
|
154 |
+
past_key_values = next_decoder_cache
|
155 |
+
|
156 |
+
def init_pre_nn(self, config):
|
157 |
+
self.layers_pre_nn = nn.ModuleList(
|
158 |
+
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers // 2)]
|
159 |
+
)
|
160 |
+
self.rotary_emb_pre_nn = LlamaRotaryEmbedding(config=config)
|
161 |
+
|
162 |
+
def pre_nn_forward(self, hidden, hidden_lens):
|
163 |
+
inputs_embeds = hidden
|
164 |
+
past_seen_tokens = 0
|
165 |
+
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + \
|
166 |
+
inputs_embeds.shape[1], device=inputs_embeds.device)
|
167 |
+
position_ids = cache_position.unsqueeze(0)
|
168 |
+
hidden_states = inputs_embeds
|
169 |
+
position_embeddings = self.rotary_emb_pre_nn(hidden_states, position_ids)
|
170 |
+
next_decoder_cache = None
|
171 |
+
batch_size, max_len, _ = hidden.size()
|
172 |
+
input_mask = torch.zeros(batch_size, max_len, max_len, dtype=torch.bool, device=hidden.device)
|
173 |
+
for i in range(batch_size):
|
174 |
+
input_mask[i, :hidden_lens[i], :hidden_lens[i]] = True
|
175 |
+
attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
|
176 |
+
for decoder_layer in self.layers_pre_nn:
|
177 |
+
layer_outputs = decoder_layer(
|
178 |
+
hidden_states,
|
179 |
+
attention_mask=attention_mask,
|
180 |
+
position_ids=position_ids,
|
181 |
+
past_key_value=None,
|
182 |
+
output_attentions=False,
|
183 |
+
use_cache=False,
|
184 |
+
cache_position=None,
|
185 |
+
position_embeddings=position_embeddings,
|
186 |
+
)
|
187 |
+
hidden_states = layer_outputs[0]
|
188 |
+
return hidden_states
|
189 |
+
|
190 |
+
def forward(self, batch):
|
191 |
+
llm_hidden = batch['x']
|
192 |
+
llm_hidden_lens = batch['x_lens']
|
193 |
+
y = batch['y']
|
194 |
+
y[y == IGNORE_ID] = self.vocab_size + 3
|
195 |
+
y_lens = batch['y_lens']
|
196 |
+
past_key_values = DynamicCache.from_legacy_cache(None)
|
197 |
+
|
198 |
+
if self.kv_cache_prefix_finetune:
|
199 |
+
self.kv_cache_prefix_forward(batch['x_prefix'], batch['x_prefix_lens'], past_key_values)
|
200 |
+
|
201 |
+
# text_ids: (batch_size, max_len)
|
202 |
+
batch_size, max_len = y.size()
|
203 |
+
|
204 |
+
# Create bos, sos and eos tokens
|
205 |
+
bos_token = torch.full((batch_size, 1), self.vocab_size, dtype=torch.long, device=y.device)
|
206 |
+
sos_token = torch.full((batch_size, 1), self.vocab_size + 1, dtype=torch.long, device=y.device)
|
207 |
+
eos_token = torch.full((batch_size, 1), self.vocab_size + 2, dtype=torch.long, device=y.device)
|
208 |
+
padding_token = torch.full((batch_size, 1), self.vocab_size + 3, dtype=torch.long, device=y.device)
|
209 |
+
|
210 |
+
# Pass through pre_nn
|
211 |
+
llm_hidden = self.pre_nn_forward(llm_hidden, llm_hidden_lens)
|
212 |
+
|
213 |
+
# Concat bos embedding
|
214 |
+
bos_emb = self.embedding(bos_token)
|
215 |
+
llm_hidden = torch.cat([bos_emb, llm_hidden], dim=1)
|
216 |
+
llm_hidden_lens = llm_hidden_lens + 1
|
217 |
+
|
218 |
+
# Create input x with sos token at the beginning
|
219 |
+
x = torch.cat([sos_token, y], dim=1) # (batch_size, max_len + 1)
|
220 |
+
|
221 |
+
# Create output y with eos token at the end
|
222 |
+
y = torch.cat([y, padding_token], dim=1)
|
223 |
+
eos_positions = torch.arange(max_len + 1, device=y.device).expand(batch_size, max_len + 1) \
|
224 |
+
== y_lens.unsqueeze(1)
|
225 |
+
y = y.masked_scatter(eos_positions, eos_token.expand_as(y)[eos_positions])
|
226 |
+
|
227 |
+
# Embed the input sequence
|
228 |
+
x_emb = self.embedding(x) # (batch_size, max_len + 1, d_model)
|
229 |
+
|
230 |
+
# compute masks
|
231 |
+
if self.kv_cache_prefix_finetune:
|
232 |
+
x_prefix = batch['x_prefix']
|
233 |
+
x_prefix_lens = batch['x_prefix_lens']
|
234 |
+
input_lens = llm_hidden.size(1) + max_len + 1
|
235 |
+
input_mask = torch.zeros(batch_size, input_lens, x_prefix.size(1) + input_lens, \
|
236 |
+
dtype=torch.bool, device=x_emb.device)
|
237 |
+
for i in range(batch_size):
|
238 |
+
input_mask[i, :llm_hidden_lens[i], :x_prefix_lens[i]] = True
|
239 |
+
input_mask[i, :llm_hidden_lens[i], x_prefix.size(1): x_prefix.size(1) + llm_hidden_lens[i]] = True
|
240 |
+
input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, :x_prefix_lens[i]] = True
|
241 |
+
input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, \
|
242 |
+
x_prefix.size(1): x_prefix.size(1) + llm_hidden_lens[i]] = True
|
243 |
+
input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, \
|
244 |
+
x_prefix.size(1) + llm_hidden.size(1): x_prefix.size(1) + \
|
245 |
+
llm_hidden.size(1) + y_lens[i] + 1] \
|
246 |
+
= subsequent_mask(y_lens[i] + 1, x_emb.device)
|
247 |
+
else:
|
248 |
+
input_lens = llm_hidden.size(1) + max_len + 1
|
249 |
+
input_mask = torch.zeros(batch_size, input_lens, input_lens, dtype=torch.bool, device=x_emb.device)
|
250 |
+
for i in range(batch_size):
|
251 |
+
input_mask[i, :llm_hidden_lens[i], :llm_hidden_lens[i]] = True
|
252 |
+
input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, :llm_hidden_lens[i]] = True
|
253 |
+
input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, \
|
254 |
+
llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1] \
|
255 |
+
= subsequent_mask(y_lens[i] + 1, x_emb.device)
|
256 |
+
|
257 |
+
# Pass through the transformer
|
258 |
+
inputs_embeds = torch.cat([llm_hidden, x_emb], 1)
|
259 |
+
llm_hidden = self.dropout(llm_hidden)
|
260 |
+
past_seen_tokens = 0
|
261 |
+
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], \
|
262 |
+
device=inputs_embeds.device)
|
263 |
+
position_ids = cache_position.unsqueeze(0)
|
264 |
+
hidden_states = inputs_embeds
|
265 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
266 |
+
attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
|
267 |
+
for decoder_layer in self.layers:
|
268 |
+
layer_outputs = decoder_layer(
|
269 |
+
hidden_states,
|
270 |
+
attention_mask=attention_mask,
|
271 |
+
position_ids=position_ids,
|
272 |
+
past_key_value=past_key_values,
|
273 |
+
output_attentions=False,
|
274 |
+
use_cache=True,
|
275 |
+
cache_position=None,
|
276 |
+
position_embeddings=position_embeddings,
|
277 |
+
)
|
278 |
+
hidden_states = layer_outputs[0]
|
279 |
+
hidden_states = self.norm(hidden_states)
|
280 |
+
|
281 |
+
encoder_out = hidden_states[:, llm_hidden.size(1):]
|
282 |
+
|
283 |
+
# Project to vocabulary size
|
284 |
+
logits = self.out_fnn(encoder_out)
|
285 |
+
|
286 |
+
if self.encoder_criterion == 'ce':
|
287 |
+
loss = self.criterion(logits, y)
|
288 |
+
|
289 |
+
if self.training:
|
290 |
+
self.reporter.log_loss('loss', float(loss))
|
291 |
+
|
292 |
+
return loss
|
293 |
+
|
294 |
+
def transformer_infer(self, inputs_embeds, cache_position, past_key_values):
|
295 |
+
position_ids = cache_position.unsqueeze(0)
|
296 |
+
hidden_states = inputs_embeds
|
297 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
298 |
+
next_decoder_cache = None
|
299 |
+
for decoder_layer in self.layers:
|
300 |
+
layer_outputs = decoder_layer(
|
301 |
+
hidden_states,
|
302 |
+
attention_mask=None,
|
303 |
+
position_ids=position_ids,
|
304 |
+
past_key_value=past_key_values,
|
305 |
+
output_attentions=False,
|
306 |
+
use_cache=True,
|
307 |
+
cache_position=None,
|
308 |
+
position_embeddings=position_embeddings,
|
309 |
+
)
|
310 |
+
hidden_states = layer_outputs[0]
|
311 |
+
next_decoder_cache = layer_outputs[1]
|
312 |
+
return hidden_states
|
313 |
+
|
314 |
+
def infer(self, hidden, top_k, prefix, penalty_window_size, penalty, max_tokens=1000):
|
315 |
+
# Pass through pre_nn
|
316 |
+
hidden = self.pre_nn_forward(hidden, [hidden.size(1)])
|
317 |
+
# Concat bos embedding
|
318 |
+
bos_emb = self.embedding(torch.full((1, 1), self.vocab_size, dtype=torch.long, device=hidden.device))
|
319 |
+
hidden = torch.cat([bos_emb, hidden], dim=1)
|
320 |
+
# init past key values
|
321 |
+
past_key_values = DynamicCache.from_legacy_cache(None)
|
322 |
+
# Pass through the prefix nar decoder
|
323 |
+
if prefix is not None and self.kv_cache_prefix_finetune:
|
324 |
+
self.kv_cache_prefix_forward(prefix, [prefix.size(1)], past_key_values)
|
325 |
+
inputs_embeds = hidden
|
326 |
+
past_seen_tokens = 0
|
327 |
+
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], \
|
328 |
+
device=inputs_embeds.device)
|
329 |
+
hidden_states = self.transformer_infer(inputs_embeds, cache_position, past_key_values)
|
330 |
+
|
331 |
+
# init generated tokens
|
332 |
+
cur_token = torch.full((1, 1), self.vocab_size + 1, dtype=torch.long, device=hidden.device)
|
333 |
+
generated_tokens = torch.full((1, 1), self.vocab_size + 1, dtype=torch.long, device=hidden.device)
|
334 |
+
# generate tokens
|
335 |
+
for i in range(max_tokens):
|
336 |
+
inputs_embeds = self.embedding(cur_token)
|
337 |
+
past_seen_tokens = past_key_values.get_seq_length()
|
338 |
+
if prefix is not None:
|
339 |
+
past_seen_tokens = past_seen_tokens - prefix.size(1)
|
340 |
+
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], \
|
341 |
+
device=inputs_embeds.device)
|
342 |
+
hidden_states = self.transformer_infer(inputs_embeds, cache_position, past_key_values)
|
343 |
+
hidden_states = self.norm(hidden_states)
|
344 |
+
|
345 |
+
# Project to vocabulary size
|
346 |
+
logits = self.out_fnn(hidden_states)
|
347 |
+
|
348 |
+
# apply penalty
|
349 |
+
if penalty_window_size > 0:
|
350 |
+
for token in set(generated_tokens[0][-penalty_window_size:]):
|
351 |
+
logits[:, :, token] /= penalty
|
352 |
+
|
353 |
+
# top k sampling
|
354 |
+
output = logits.squeeze(0).squeeze(0)
|
355 |
+
probs = torch.nn.functional.softmax(output, dim=-1)
|
356 |
+
top_k_probs, top_k_indices = torch.topk(probs, top_k)
|
357 |
+
probs = torch.zeros_like(probs).scatter_(0, top_k_indices, top_k_probs)
|
358 |
+
probs = probs / probs.sum()
|
359 |
+
next_token_id = torch.multinomial(probs, 1).unsqueeze(0)
|
360 |
+
|
361 |
+
generated_tokens = torch.cat([generated_tokens, next_token_id], dim=-1)
|
362 |
+
cur_token = next_token_id
|
363 |
+
|
364 |
+
# eos
|
365 |
+
if next_token_id == self.vocab_size + 2:
|
366 |
+
break
|
367 |
+
yield next_token_id
|
vita/model/vita_tts/decoder/llm2tts.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import copy
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
import random
|
7 |
+
import argparse
|
8 |
+
import subprocess
|
9 |
+
import numpy as np
|
10 |
+
import soundfile as sf
|
11 |
+
import subprocess
|
12 |
+
import concurrent.futures
|
13 |
+
|
14 |
+
from vita.model.vita_tts.decoder.decoder import LLM2TTSCodecAR
|
15 |
+
from vita.model.vita_tts.decoder.ticodec.vqvae_tester import VqvaeTester
|
16 |
+
|
17 |
+
class llm2TTS():
|
18 |
+
def __init__(self, model_path):
|
19 |
+
self.model = self.get_model(model_path).cuda().to(
|
20 |
+
torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
21 |
+
)
|
22 |
+
self.infer = self.model.infer
|
23 |
+
|
24 |
+
self.codec_model = VqvaeTester(config_path=model_path + "/codec/model.json",
|
25 |
+
model_path=model_path + "/codec/final.pt",
|
26 |
+
sample_rate=24000)
|
27 |
+
self.codec_model = self.codec_model.cuda()
|
28 |
+
self.codec_model.vqvae.generator.remove_weight_norm()
|
29 |
+
self.codec_model.vqvae.encoder.remove_weight_norm()
|
30 |
+
self.codec_model.eval()
|
31 |
+
|
32 |
+
def get_model_conf(self, model_path):
|
33 |
+
model_conf = model_path + "/decoder/model.json"
|
34 |
+
with open(model_conf, "rb") as f:
|
35 |
+
print('reading a config file from ' + model_conf)
|
36 |
+
confs = json.load(f)
|
37 |
+
# for asr, tts, mt
|
38 |
+
idim, odim, args = confs
|
39 |
+
return argparse.Namespace(**args)
|
40 |
+
|
41 |
+
def get_model(self, model_path):
|
42 |
+
args_load = self.get_model_conf(model_path)
|
43 |
+
args_load = vars(args_load)
|
44 |
+
args = argparse.Namespace(**args_load)
|
45 |
+
odim = args.odim
|
46 |
+
idim = args.idim
|
47 |
+
model = LLM2TTSCodecAR(idim, odim, args)
|
48 |
+
|
49 |
+
# Resume from a snapshot
|
50 |
+
snapshot_dict = torch.load(model_path + "/decoder/final.pt", map_location=lambda storage, loc: storage)
|
51 |
+
if 'model' in snapshot_dict.keys():
|
52 |
+
resume_model_dict = snapshot_dict['model']
|
53 |
+
else:
|
54 |
+
resume_model_dict = snapshot_dict
|
55 |
+
|
56 |
+
model_dict = model.state_dict()
|
57 |
+
for key in model_dict.keys():
|
58 |
+
if key in resume_model_dict.keys():
|
59 |
+
if model_dict[key].shape == resume_model_dict[key].shape:
|
60 |
+
model_dict[key] = resume_model_dict[key]
|
61 |
+
else:
|
62 |
+
print('Key {} has different shape, {} VS {}'.format(key, model_dict[key].shape,
|
63 |
+
resume_model_dict[key].shape))
|
64 |
+
else:
|
65 |
+
print('Key {} has not in resume model'.format(key))
|
66 |
+
model.load_state_dict(model_dict)
|
67 |
+
model.eval()
|
68 |
+
return model
|
69 |
+
|
70 |
+
def find_min_sum_index(self, buffer, syn, N, threshold):
|
71 |
+
"""
|
72 |
+
Find the index with the minimum sum of a sliding window in the given audio segment
|
73 |
+
and perform operations based on this index.
|
74 |
+
|
75 |
+
Parameters:
|
76 |
+
- buffer (torch.Tensor): The buffer containing previously processed audio segments.
|
77 |
+
- syn (torch.Tensor): The current audio segment to be processed.
|
78 |
+
- N (int): The size of the sliding window used to calculate the sum.
|
79 |
+
- threshold (float): Threshold value to determine whether to concatenate buffer and current segment or not.
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
- tuple: A tuple containing the updated buffer and the processed audio segment.
|
83 |
+
|
84 |
+
"""
|
85 |
+
arr = syn[0, 0, :]
|
86 |
+
L = len(arr)
|
87 |
+
mid = L // 2
|
88 |
+
|
89 |
+
kernel = torch.ones(N).to(arr.device)
|
90 |
+
window_sums = torch.nn.functional.conv1d(arr.abs().view(1, 1, -1), kernel.view(1, 1, -1), padding=0).squeeze()
|
91 |
+
|
92 |
+
start_index = mid - (N // 2)
|
93 |
+
min_sum, min_index = torch.min(window_sums[start_index:], dim=0)
|
94 |
+
|
95 |
+
# get the start and end index of the window
|
96 |
+
start_index = max(0, min_index.item() + start_index)
|
97 |
+
end_index = min(L, min_index.item() + N + start_index)
|
98 |
+
|
99 |
+
# calculate the real min_sum and min_index
|
100 |
+
min_sum_real, min_index_real = torch.min(arr[start_index: end_index].abs(), dim=0)
|
101 |
+
min_index = min_index_real.item() + start_index
|
102 |
+
|
103 |
+
min_sum = min_sum / N
|
104 |
+
syn_clone = syn.clone()
|
105 |
+
|
106 |
+
if min_sum < threshold:
|
107 |
+
syn = torch.cat([buffer.clone(), syn[:, :, :min_index]], dim=-1)
|
108 |
+
buffer = syn_clone[:, :, min_index:]
|
109 |
+
else:
|
110 |
+
buffer = torch.cat([buffer, syn_clone], dim=-1)
|
111 |
+
syn = None
|
112 |
+
return buffer, syn
|
113 |
+
|
114 |
+
def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10,
|
115 |
+
penalty_window_size=-1, penalty=1.1, N=2401, seg_threshold=0.01):
|
116 |
+
"""
|
117 |
+
Run the speech decoder process.
|
118 |
+
|
119 |
+
Parameters:
|
120 |
+
- hidden (torch.Tensor): The output for embedding layer of the language model.
|
121 |
+
- top_k (int): The number of top-k tokens to consider during inference.
|
122 |
+
- prefix (str, optional): The hidden state from the language model.
|
123 |
+
- codec_chunk_size (int, default=40): The size of each chunk to process in the codec model.
|
124 |
+
- codec_padding_size (int, default=10): The amount of padding to add on each side of the codec chunk.
|
125 |
+
- penalty_window_size (int, default=20): The window size for applying penalties during decoding.
|
126 |
+
- penalty (float, default=1.1): The penalty factor.
|
127 |
+
|
128 |
+
Yields:
|
129 |
+
- torch.Tensor: Intermediate audio segments generated by the codec model.
|
130 |
+
|
131 |
+
"""
|
132 |
+
codec_upsample_rate = 600
|
133 |
+
left_padding = 0
|
134 |
+
right_padding = codec_padding_size
|
135 |
+
prefix = None
|
136 |
+
buffer = torch.zeros([1, 1, 0]).to(hidden.device)
|
137 |
+
with torch.no_grad():
|
138 |
+
with torch.autocast(device_type="cuda",
|
139 |
+
dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32):
|
140 |
+
print("Starting TTS...")
|
141 |
+
token = torch.full((1, 0), self.model.vocab_size, dtype=torch.long, device=hidden.device)
|
142 |
+
for next_token_id in self.infer(hidden, top_k, prefix, penalty_window_size, penalty):
|
143 |
+
token = torch.cat([token, next_token_id], dim=-1)
|
144 |
+
if token.size(1) == left_padding + codec_chunk_size + right_padding:
|
145 |
+
syn = self.codec_model.vqvae(token.unsqueeze(-1),
|
146 |
+
torch.tensor(self.codec_model.vqvae.h.global_tokens,
|
147 |
+
device=token.device).unsqueeze(0).unsqueeze(0))
|
148 |
+
print("Codec Done")
|
149 |
+
syn = syn[:, :, left_padding * codec_upsample_rate: -right_padding * codec_upsample_rate]
|
150 |
+
left_padding = codec_padding_size
|
151 |
+
token = token[:, -(left_padding + right_padding):]
|
152 |
+
buffer, syn = self.find_min_sum_index(buffer, syn, N, seg_threshold)
|
153 |
+
if syn is not None:
|
154 |
+
yield syn
|
155 |
+
if token.size(1) > 0:
|
156 |
+
print("Codec Done")
|
157 |
+
syn = self.codec_model.vqvae(token.unsqueeze(-1),
|
158 |
+
torch.tensor(self.codec_model.vqvae.h.global_tokens,
|
159 |
+
device=token.device).unsqueeze(0).unsqueeze(0))
|
160 |
+
syn = syn[:, :, left_padding * codec_upsample_rate:]
|
161 |
+
yield torch.cat([buffer, syn], dim=-1)
|
vita/model/vita_tts/decoder/ticodec/models.py
ADDED
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn import AvgPool1d
|
5 |
+
from torch.nn import Conv1d
|
6 |
+
from torch.nn import Conv2d
|
7 |
+
from torch.nn import ConvTranspose1d
|
8 |
+
from torch.nn.utils import remove_weight_norm
|
9 |
+
from torch.nn.utils import spectral_norm
|
10 |
+
from torch.nn.utils import weight_norm
|
11 |
+
|
12 |
+
LRELU_SLOPE = 0.1
|
13 |
+
|
14 |
+
def get_padding(kernel_size, dilation=1):
|
15 |
+
return int((kernel_size * dilation - dilation) / 2)
|
16 |
+
|
17 |
+
def init_weights(m, mean=0.0, std=0.01):
|
18 |
+
classname = m.__class__.__name__
|
19 |
+
if classname.find("Conv") != -1:
|
20 |
+
m.weight.data.normal_(mean, std)
|
21 |
+
|
22 |
+
class GlobalTokenEncoder(nn.Module):
|
23 |
+
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size=3, stride=1):
|
24 |
+
super().__init__()
|
25 |
+
self.pad = (kernel_size - stride) // 2
|
26 |
+
self.conv = nn.Sequential(
|
27 |
+
nn.Conv1d(in_channels, hidden_channels, kernel_size, stride, self.pad, bias=False),
|
28 |
+
nn.LeakyReLU(LRELU_SLOPE),
|
29 |
+
nn.Conv1d(hidden_channels, hidden_channels, kernel_size, stride, self.pad, bias=False),
|
30 |
+
nn.LeakyReLU(LRELU_SLOPE),
|
31 |
+
nn.Conv1d(hidden_channels, out_channels, kernel_size, stride, self.pad, bias=False),
|
32 |
+
nn.LeakyReLU(LRELU_SLOPE),
|
33 |
+
)
|
34 |
+
self.fn = nn.Sequential(
|
35 |
+
# # 2 layers
|
36 |
+
# nn.Linear(out_channels, hidden_channels),
|
37 |
+
# nn.LeakyReLU(LRELU_SLOPE),
|
38 |
+
# nn.Linear(hidden_channels, out_channels),
|
39 |
+
# nn.LeakyReLU(LRELU_SLOPE),
|
40 |
+
# 1 layer
|
41 |
+
nn.Linear(out_channels, out_channels),
|
42 |
+
nn.LeakyReLU(LRELU_SLOPE),
|
43 |
+
nn.BatchNorm1d(out_channels),
|
44 |
+
)
|
45 |
+
def forward(self, x):
|
46 |
+
"""
|
47 |
+
x --- [B, in_channels, T]
|
48 |
+
out -- [B, out_channels]
|
49 |
+
"""
|
50 |
+
# x_mask = torch.unsqueeze(sequence_mask(
|
51 |
+
# x_lengths, x.size(2)), 1).to(x.dtype)
|
52 |
+
# x = self.conv(x) * x_mask
|
53 |
+
x = self.conv(x)
|
54 |
+
# x = torch.sum(x, dim=2) / torch.sum(x_mask, dim=2) # [B, out_channels]
|
55 |
+
x = torch.mean(x, dim=2) # [B, out_channels]
|
56 |
+
x = self.fn(x)
|
57 |
+
return x
|
58 |
+
|
59 |
+
class ResBlock1(torch.nn.Module):
|
60 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
61 |
+
super(ResBlock1, self).__init__()
|
62 |
+
self.h = h
|
63 |
+
self.convs1 = nn.ModuleList([
|
64 |
+
weight_norm(
|
65 |
+
Conv1d(
|
66 |
+
channels,
|
67 |
+
channels,
|
68 |
+
kernel_size,
|
69 |
+
1,
|
70 |
+
dilation=dilation[0],
|
71 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
72 |
+
weight_norm(
|
73 |
+
Conv1d(
|
74 |
+
channels,
|
75 |
+
channels,
|
76 |
+
kernel_size,
|
77 |
+
1,
|
78 |
+
dilation=dilation[1],
|
79 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
80 |
+
weight_norm(
|
81 |
+
Conv1d(
|
82 |
+
channels,
|
83 |
+
channels,
|
84 |
+
kernel_size,
|
85 |
+
1,
|
86 |
+
dilation=dilation[2],
|
87 |
+
padding=get_padding(kernel_size, dilation[2])))
|
88 |
+
])
|
89 |
+
self.convs1.apply(init_weights)
|
90 |
+
|
91 |
+
self.convs2 = nn.ModuleList([
|
92 |
+
weight_norm(
|
93 |
+
Conv1d(
|
94 |
+
channels,
|
95 |
+
channels,
|
96 |
+
kernel_size,
|
97 |
+
1,
|
98 |
+
dilation=1,
|
99 |
+
padding=get_padding(kernel_size, 1))), weight_norm(
|
100 |
+
Conv1d(
|
101 |
+
channels,
|
102 |
+
channels,
|
103 |
+
kernel_size,
|
104 |
+
1,
|
105 |
+
dilation=1,
|
106 |
+
padding=get_padding(kernel_size, 1))), weight_norm(
|
107 |
+
Conv1d(
|
108 |
+
channels,
|
109 |
+
channels,
|
110 |
+
kernel_size,
|
111 |
+
1,
|
112 |
+
dilation=1,
|
113 |
+
padding=get_padding(kernel_size, 1)))
|
114 |
+
])
|
115 |
+
self.convs2.apply(init_weights)
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
119 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
120 |
+
xt = c1(xt)
|
121 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
122 |
+
xt = c2(xt)
|
123 |
+
x = xt + x
|
124 |
+
return x
|
125 |
+
|
126 |
+
def remove_weight_norm(self):
|
127 |
+
for l in self.convs1:
|
128 |
+
remove_weight_norm(l)
|
129 |
+
for l in self.convs2:
|
130 |
+
remove_weight_norm(l)
|
131 |
+
|
132 |
+
|
133 |
+
class ResBlock2(torch.nn.Module):
|
134 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
135 |
+
super(ResBlock2, self).__init__()
|
136 |
+
self.h = h
|
137 |
+
self.convs = nn.ModuleList([
|
138 |
+
weight_norm(
|
139 |
+
Conv1d(
|
140 |
+
channels,
|
141 |
+
channels,
|
142 |
+
kernel_size,
|
143 |
+
1,
|
144 |
+
dilation=dilation[0],
|
145 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
146 |
+
weight_norm(
|
147 |
+
Conv1d(
|
148 |
+
channels,
|
149 |
+
channels,
|
150 |
+
kernel_size,
|
151 |
+
1,
|
152 |
+
dilation=dilation[1],
|
153 |
+
padding=get_padding(kernel_size, dilation[1])))
|
154 |
+
])
|
155 |
+
self.convs.apply(init_weights)
|
156 |
+
|
157 |
+
def forward(self, x):
|
158 |
+
for c in self.convs:
|
159 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
160 |
+
xt = c(xt)
|
161 |
+
x = xt + x
|
162 |
+
return x
|
163 |
+
|
164 |
+
def remove_weight_norm(self):
|
165 |
+
for l in self.convs:
|
166 |
+
remove_weight_norm(l)
|
167 |
+
|
168 |
+
|
169 |
+
class Generator(torch.nn.Module):
|
170 |
+
def __init__(self, h):
|
171 |
+
"""
|
172 |
+
Initializes the Generator module.
|
173 |
+
|
174 |
+
Parameters:
|
175 |
+
- h (object): Configuration object containing hyperparameters for the generator.
|
176 |
+
"""
|
177 |
+
super(Generator, self).__init__()
|
178 |
+
self.h = h
|
179 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
180 |
+
self.num_upsamples = len(h.upsample_rates)
|
181 |
+
self.conv_pre = weight_norm(
|
182 |
+
Conv1d(512, h.upsample_initial_channel, 7, 1, padding=3))
|
183 |
+
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
|
184 |
+
|
185 |
+
self.ups = nn.ModuleList()
|
186 |
+
for i, (u,
|
187 |
+
k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
188 |
+
self.ups.append(
|
189 |
+
weight_norm(
|
190 |
+
ConvTranspose1d(
|
191 |
+
h.upsample_initial_channel // (2**i),
|
192 |
+
h.upsample_initial_channel // (2**(i + 1)),
|
193 |
+
k,
|
194 |
+
u,
|
195 |
+
# padding=(u//2 + u%2),
|
196 |
+
padding=(k - u) // 2,
|
197 |
+
# output_padding=u%2
|
198 |
+
)))
|
199 |
+
|
200 |
+
self.resblocks = nn.ModuleList()
|
201 |
+
for i in range(len(self.ups)):
|
202 |
+
ch = h.upsample_initial_channel // (2**(i + 1))
|
203 |
+
for j, (k, d) in enumerate(
|
204 |
+
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
205 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
206 |
+
|
207 |
+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
208 |
+
self.ups.apply(init_weights)
|
209 |
+
self.conv_post.apply(init_weights)
|
210 |
+
|
211 |
+
def forward(self, x, global_features):
|
212 |
+
"""
|
213 |
+
Forward pass of the Generator module.
|
214 |
+
|
215 |
+
Parameters:
|
216 |
+
- x (torch.Tensor): Input tensor of shape [B, C, T], where B is the batch size,
|
217 |
+
C is the number of channels, and T is the sequence length.
|
218 |
+
- global_features (torch.Tensor): Global features tensor of shape [B, 128].
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
- torch.Tensor: Output tensor of shape [B, 1, T],
|
222 |
+
where B is the batch size, and T is the sequence length.
|
223 |
+
"""
|
224 |
+
x = self.conv_pre(x)
|
225 |
+
for i in range(self.num_upsamples):
|
226 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
227 |
+
x = self.ups[i](x)
|
228 |
+
xs = None
|
229 |
+
for j in range(self.num_kernels):
|
230 |
+
if xs is None:
|
231 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
232 |
+
else:
|
233 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
234 |
+
x = xs / self.num_kernels
|
235 |
+
# if i == self.num_upsamples//2 - 1:
|
236 |
+
if x.shape[-2] == global_features.shape[-1]:
|
237 |
+
x += global_features.unsqueeze(-1).repeat(1, 1, x.shape[-1])
|
238 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
239 |
+
x = self.conv_post(x)
|
240 |
+
x = torch.tanh(x)
|
241 |
+
|
242 |
+
return x
|
243 |
+
|
244 |
+
def remove_weight_norm(self):
|
245 |
+
"""
|
246 |
+
Removes weight normalization from all layers in the Generator module.
|
247 |
+
"""
|
248 |
+
print('Removing weight norm...')
|
249 |
+
for l in self.ups:
|
250 |
+
remove_weight_norm(l)
|
251 |
+
for l in self.resblocks:
|
252 |
+
l.remove_weight_norm()
|
253 |
+
remove_weight_norm(self.conv_pre)
|
254 |
+
remove_weight_norm(self.conv_post)
|
255 |
+
|
256 |
+
|
257 |
+
class DiscriminatorP(torch.nn.Module):
|
258 |
+
def __init__(self, period, kernel_size=5, stride=3,
|
259 |
+
use_spectral_norm=False):
|
260 |
+
super(DiscriminatorP, self).__init__()
|
261 |
+
self.period = period
|
262 |
+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
263 |
+
self.convs = nn.ModuleList([
|
264 |
+
norm_f(
|
265 |
+
Conv2d(
|
266 |
+
1,
|
267 |
+
32, (kernel_size, 1), (stride, 1),
|
268 |
+
padding=(get_padding(5, 1), 0))),
|
269 |
+
norm_f(
|
270 |
+
Conv2d(
|
271 |
+
32,
|
272 |
+
128, (kernel_size, 1), (stride, 1),
|
273 |
+
padding=(get_padding(5, 1), 0))),
|
274 |
+
norm_f(
|
275 |
+
Conv2d(
|
276 |
+
128,
|
277 |
+
512, (kernel_size, 1), (stride, 1),
|
278 |
+
padding=(get_padding(5, 1), 0))),
|
279 |
+
norm_f(
|
280 |
+
Conv2d(
|
281 |
+
512,
|
282 |
+
1024, (kernel_size, 1), (stride, 1),
|
283 |
+
padding=(get_padding(5, 1), 0))),
|
284 |
+
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
285 |
+
])
|
286 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
287 |
+
|
288 |
+
def forward(self, x):
|
289 |
+
fmap = []
|
290 |
+
|
291 |
+
# 1d to 2d
|
292 |
+
b, c, t = x.shape
|
293 |
+
if t % self.period != 0: # pad first
|
294 |
+
n_pad = self.period - (t % self.period)
|
295 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
296 |
+
t = t + n_pad
|
297 |
+
x = x.view(b, c, t // self.period, self.period)
|
298 |
+
|
299 |
+
for l in self.convs:
|
300 |
+
x = l(x)
|
301 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
302 |
+
fmap.append(x)
|
303 |
+
x = self.conv_post(x)
|
304 |
+
fmap.append(x)
|
305 |
+
x = torch.flatten(x, 1, -1)
|
306 |
+
|
307 |
+
return x, fmap
|
308 |
+
|
309 |
+
|
310 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
311 |
+
def __init__(self):
|
312 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
313 |
+
self.discriminators = nn.ModuleList([
|
314 |
+
DiscriminatorP(2),
|
315 |
+
DiscriminatorP(3),
|
316 |
+
DiscriminatorP(5),
|
317 |
+
DiscriminatorP(7),
|
318 |
+
DiscriminatorP(11),
|
319 |
+
])
|
320 |
+
|
321 |
+
def forward(self, y, y_hat):
|
322 |
+
y_d_rs = []
|
323 |
+
y_d_gs = []
|
324 |
+
fmap_rs = []
|
325 |
+
fmap_gs = []
|
326 |
+
for i, d in enumerate(self.discriminators):
|
327 |
+
y_d_r, fmap_r = d(y)
|
328 |
+
y_d_g, fmap_g = d(y_hat)
|
329 |
+
y_d_rs.append(y_d_r)
|
330 |
+
fmap_rs.append(fmap_r)
|
331 |
+
y_d_gs.append(y_d_g)
|
332 |
+
fmap_gs.append(fmap_g)
|
333 |
+
|
334 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
335 |
+
|
336 |
+
|
337 |
+
class DiscriminatorS(torch.nn.Module):
|
338 |
+
def __init__(self, use_spectral_norm=False):
|
339 |
+
super(DiscriminatorS, self).__init__()
|
340 |
+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
341 |
+
self.convs = nn.ModuleList([
|
342 |
+
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
|
343 |
+
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
344 |
+
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
345 |
+
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
346 |
+
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
347 |
+
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
348 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
349 |
+
])
|
350 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
351 |
+
|
352 |
+
def forward(self, x):
|
353 |
+
fmap = []
|
354 |
+
for l in self.convs:
|
355 |
+
x = l(x)
|
356 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
357 |
+
fmap.append(x)
|
358 |
+
x = self.conv_post(x)
|
359 |
+
fmap.append(x)
|
360 |
+
x = torch.flatten(x, 1, -1)
|
361 |
+
|
362 |
+
return x, fmap
|
363 |
+
|
364 |
+
|
365 |
+
class MultiScaleDiscriminator(torch.nn.Module):
|
366 |
+
def __init__(self):
|
367 |
+
super(MultiScaleDiscriminator, self).__init__()
|
368 |
+
self.discriminators = nn.ModuleList([
|
369 |
+
DiscriminatorS(use_spectral_norm=True),
|
370 |
+
DiscriminatorS(),
|
371 |
+
DiscriminatorS(),
|
372 |
+
])
|
373 |
+
self.meanpools = nn.ModuleList(
|
374 |
+
[AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
|
375 |
+
|
376 |
+
def forward(self, y, y_hat):
|
377 |
+
y_d_rs = []
|
378 |
+
y_d_gs = []
|
379 |
+
fmap_rs = []
|
380 |
+
fmap_gs = []
|
381 |
+
for i, d in enumerate(self.discriminators):
|
382 |
+
if i != 0:
|
383 |
+
y = self.meanpools[i - 1](y)
|
384 |
+
y_hat = self.meanpools[i - 1](y_hat)
|
385 |
+
y_d_r, fmap_r = d(y)
|
386 |
+
y_d_g, fmap_g = d(y_hat)
|
387 |
+
y_d_rs.append(y_d_r)
|
388 |
+
fmap_rs.append(fmap_r)
|
389 |
+
y_d_gs.append(y_d_g)
|
390 |
+
fmap_gs.append(fmap_g)
|
391 |
+
|
392 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
393 |
+
|
394 |
+
|
395 |
+
def feature_loss(fmap_r, fmap_g):
|
396 |
+
loss = 0
|
397 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
398 |
+
for rl, gl in zip(dr, dg):
|
399 |
+
loss += torch.mean(torch.abs(rl - gl))
|
400 |
+
|
401 |
+
return loss * 2
|
402 |
+
|
403 |
+
|
404 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
405 |
+
loss = 0
|
406 |
+
r_losses = []
|
407 |
+
g_losses = []
|
408 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
409 |
+
r_loss = torch.mean((1 - dr)**2)
|
410 |
+
g_loss = torch.mean(dg**2)
|
411 |
+
loss += (r_loss + g_loss)
|
412 |
+
r_losses.append(r_loss.item())
|
413 |
+
g_losses.append(g_loss.item())
|
414 |
+
|
415 |
+
return loss, r_losses, g_losses
|
416 |
+
|
417 |
+
|
418 |
+
def generator_loss(disc_outputs):
|
419 |
+
loss = 0
|
420 |
+
gen_losses = []
|
421 |
+
for dg in disc_outputs:
|
422 |
+
l = torch.mean((1 - dg)**2)
|
423 |
+
gen_losses.append(l)
|
424 |
+
loss += l
|
425 |
+
|
426 |
+
return loss, gen_losses
|
427 |
+
|
428 |
+
|
429 |
+
class Encoder(torch.nn.Module):
|
430 |
+
def __init__(self, h):
|
431 |
+
super(Encoder, self).__init__()
|
432 |
+
self.h = h
|
433 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
434 |
+
self.num_upsamples = len(h.upsample_rates)
|
435 |
+
self.conv_pre = weight_norm(Conv1d(1, 32, 7, 1, padding=3))
|
436 |
+
self.normalize = nn.ModuleList()
|
437 |
+
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
|
438 |
+
|
439 |
+
self.ups = nn.ModuleList()
|
440 |
+
for i, (u, k) in enumerate(
|
441 |
+
list(
|
442 |
+
reversed(
|
443 |
+
list(zip(h.upsample_rates, h.upsample_kernel_sizes))))):
|
444 |
+
self.ups.append(
|
445 |
+
weight_norm(
|
446 |
+
Conv1d(
|
447 |
+
32 * (2**i),
|
448 |
+
32 * (2**(i + 1)),
|
449 |
+
k,
|
450 |
+
u,
|
451 |
+
padding=((k - u) // 2)
|
452 |
+
# padding=(u//2 + u%2)
|
453 |
+
)))
|
454 |
+
self.resblocks = nn.ModuleList()
|
455 |
+
for i in range(len(self.ups)):
|
456 |
+
ch = 32 * (2**(i + 1))
|
457 |
+
for j, (k, d) in enumerate(
|
458 |
+
zip(
|
459 |
+
list(reversed(h.resblock_kernel_sizes)),
|
460 |
+
list(reversed(h.resblock_dilation_sizes)))):
|
461 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
462 |
+
self.normalize.append(
|
463 |
+
torch.nn.GroupNorm(ch // 16, ch, eps=1e-6, affine=True))
|
464 |
+
self.conv_post = Conv1d(512, 512, 3, 1, padding=1)
|
465 |
+
self.ups.apply(init_weights)
|
466 |
+
self.conv_post.apply(init_weights)
|
467 |
+
self.linear = nn.Sequential(
|
468 |
+
nn.Linear(128, 128),
|
469 |
+
nn.LeakyReLU(LRELU_SLOPE)
|
470 |
+
)
|
471 |
+
self.gfc = h.global_feature_conv
|
472 |
+
self.GlobalTokenEncoder = GlobalTokenEncoder(self.gfc[0], self.gfc[1], self.gfc[2], self.gfc[3], self.gfc[4])
|
473 |
+
self.GlobalTokenEncoder.apply(init_weights)
|
474 |
+
|
475 |
+
def forward(self, x, xx=None):
|
476 |
+
x = self.conv_pre(x)
|
477 |
+
global_features = None
|
478 |
+
for i in range(self.num_upsamples):
|
479 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
480 |
+
x = self.ups[i](x)
|
481 |
+
xs = None
|
482 |
+
for j in range(self.num_kernels):
|
483 |
+
if xs is None:
|
484 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
485 |
+
xs = self.normalize[i * self.num_kernels + j](xs)
|
486 |
+
else:
|
487 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
488 |
+
xs = self.normalize[i * self.num_kernels + j](xs)
|
489 |
+
x = xs / self.num_kernels
|
490 |
+
if i == self.num_upsamples//2 - 1:
|
491 |
+
mid_features = x
|
492 |
+
global_features = self.GlobalTokenEncoder(x)
|
493 |
+
x = F.leaky_relu(x)
|
494 |
+
x = self.conv_post(x)
|
495 |
+
if xx is not None:
|
496 |
+
xx = self.conv_pre(xx)
|
497 |
+
global_features2 = None
|
498 |
+
for i in range(self.num_upsamples//2):
|
499 |
+
xx = F.leaky_relu(xx, LRELU_SLOPE)
|
500 |
+
xx = self.ups[i](xx)
|
501 |
+
xxs = None
|
502 |
+
for j in range(self.num_kernels):
|
503 |
+
if xxs is None:
|
504 |
+
xxs = self.resblocks[i * self.num_kernels + j](xx)
|
505 |
+
xxs = self.normalize[i * self.num_kernels + j](xxs)
|
506 |
+
else:
|
507 |
+
xxs += self.resblocks[i * self.num_kernels + j](xx)
|
508 |
+
xxs = self.normalize[i * self.num_kernels + j](xxs)
|
509 |
+
xx = xxs / self.num_kernels
|
510 |
+
mid_features2 = xx
|
511 |
+
global_features2 = self.GlobalTokenEncoder(xx)
|
512 |
+
global_features2 = global_features2.detach()
|
513 |
+
return x, global_features, global_features2
|
514 |
+
return x, global_features
|
515 |
+
|
516 |
+
def remove_weight_norm(self):
|
517 |
+
print('Removing weight norm...')
|
518 |
+
for l in self.ups:
|
519 |
+
remove_weight_norm(l)
|
520 |
+
for l in self.resblocks:
|
521 |
+
l.remove_weight_norm()
|
522 |
+
remove_weight_norm(self.conv_pre)
|
523 |
+
|
524 |
+
|
525 |
+
class Quantizer_module(torch.nn.Module):
|
526 |
+
def __init__(self, n_e, e_dim):
|
527 |
+
super(Quantizer_module, self).__init__()
|
528 |
+
self.embedding = nn.Embedding(n_e, e_dim)
|
529 |
+
self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
|
530 |
+
|
531 |
+
def forward(self, x):
|
532 |
+
# compute Euclidean distance
|
533 |
+
d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) \
|
534 |
+
- 2 * torch.matmul(x, self.embedding.weight.T)
|
535 |
+
min_indicies = torch.argmin(d, 1)
|
536 |
+
z_q = self.embedding(min_indicies)
|
537 |
+
return z_q, min_indicies
|
538 |
+
|
539 |
+
|
540 |
+
class Quantizer(torch.nn.Module):
|
541 |
+
def __init__(self, h):
|
542 |
+
super(Quantizer, self).__init__()
|
543 |
+
assert 512 % h.n_code_groups == 0
|
544 |
+
self.quantizer_modules = nn.ModuleList([
|
545 |
+
Quantizer_module(h.n_codes, 512 // h.n_code_groups)
|
546 |
+
for _ in range(h.n_code_groups)
|
547 |
+
])
|
548 |
+
self.residul_layer = h.residul_layer
|
549 |
+
if h.residul_layer == 2:
|
550 |
+
self.quantizer_modules2 = nn.ModuleList([
|
551 |
+
Quantizer_module(h.n_codes, 512 // h.n_code_groups)
|
552 |
+
for _ in range(h.n_code_groups)
|
553 |
+
])
|
554 |
+
if h.residul_layer == 4:
|
555 |
+
self.quantizer_modules2 = nn.ModuleList([
|
556 |
+
Quantizer_module(h.n_codes, 512 // h.n_code_groups)
|
557 |
+
for _ in range(h.n_code_groups)
|
558 |
+
])
|
559 |
+
self.quantizer_modules3 = nn.ModuleList([
|
560 |
+
Quantizer_module(h.n_codes, 512 // h.n_code_groups)
|
561 |
+
for _ in range(h.n_code_groups)
|
562 |
+
])
|
563 |
+
self.quantizer_modules4 = nn.ModuleList([
|
564 |
+
Quantizer_module(h.n_codes, 512 // h.n_code_groups)
|
565 |
+
for _ in range(h.n_code_groups)
|
566 |
+
])
|
567 |
+
|
568 |
+
self.quantizer_modules_globaltokens = nn.ModuleList([
|
569 |
+
Quantizer_module(h.n_codes, 128//h.global_code_num)
|
570 |
+
for _ in range(h.global_code_num)
|
571 |
+
])
|
572 |
+
# self.quantizer_modules3 = nn.ModuleList([
|
573 |
+
# Quantizer_module(h.n_codes, 128//h.global_code_num)
|
574 |
+
# for _ in range(h.global_code_num)
|
575 |
+
# ])
|
576 |
+
self.h = h
|
577 |
+
self.codebook_loss_lambda = self.h.codebook_loss_lambda # e.g., 1
|
578 |
+
self.commitment_loss_lambda = self.h.commitment_loss_lambda # e.g., 0.25
|
579 |
+
# self.residul_layer = 2
|
580 |
+
self.n_code_groups = h.n_code_groups
|
581 |
+
self.global_code_num = h.global_code_num
|
582 |
+
|
583 |
+
def for_one_step(self, xin, idx):
|
584 |
+
xin = xin.transpose(1, 2)
|
585 |
+
x = xin.reshape(-1, 512)
|
586 |
+
x = torch.split(x, 512 // self.h.n_code_groups, dim=-1)
|
587 |
+
min_indicies = []
|
588 |
+
z_q = []
|
589 |
+
if idx == 0:
|
590 |
+
for _x, m in zip(x, self.quantizer_modules):
|
591 |
+
_z_q, _min_indicies = m(_x)
|
592 |
+
z_q.append(_z_q)
|
593 |
+
min_indicies.append(_min_indicies) #B * T,
|
594 |
+
elif idx == 1:
|
595 |
+
for _x, m in zip(x, self.quantizer_modules2):
|
596 |
+
_z_q, _min_indicies = m(_x)
|
597 |
+
z_q.append(_z_q)
|
598 |
+
min_indicies.append(_min_indicies) #B * T,
|
599 |
+
elif idx == 2:
|
600 |
+
for _x, m in zip(x, self.quantizer_modules3):
|
601 |
+
_z_q, _min_indicies = m(_x)
|
602 |
+
z_q.append(_z_q)
|
603 |
+
min_indicies.append(_min_indicies)
|
604 |
+
elif idx == 3:
|
605 |
+
for _x, m in zip(x, self.quantizer_modules4):
|
606 |
+
_z_q, _min_indicies = m(_x)
|
607 |
+
z_q.append(_z_q)
|
608 |
+
min_indicies.append(_min_indicies)
|
609 |
+
z_q = torch.cat(z_q, -1).reshape(xin.shape)
|
610 |
+
# loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
|
611 |
+
loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \
|
612 |
+
+ self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2)
|
613 |
+
z_q = xin + (z_q - xin).detach()
|
614 |
+
z_q = z_q.transpose(1, 2)
|
615 |
+
return z_q, loss, min_indicies
|
616 |
+
|
617 |
+
def for_one_step_gst(self, xin):
|
618 |
+
# xin = xin.transpose(1, 2)
|
619 |
+
x = xin.reshape(-1, 128) #B * 1, 128
|
620 |
+
x = torch.split(x, 128 // self.global_code_num, dim=-1)
|
621 |
+
min_indicies = []
|
622 |
+
z_q = []
|
623 |
+
for _x, m in zip(x, self.quantizer_modules_globaltokens):
|
624 |
+
_z_q, _min_indicies = m(_x)
|
625 |
+
z_q.append(_z_q)
|
626 |
+
min_indicies.append(_min_indicies)
|
627 |
+
# for _x, m in zip(x, self.quantizer_modules3):
|
628 |
+
# _z_q, _min_indicies = m(_x)
|
629 |
+
# z_q.append(_z_q)
|
630 |
+
# min_indicies.append(_min_indicies)
|
631 |
+
z_q = torch.cat(z_q, -1).reshape(xin.shape)
|
632 |
+
# loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
|
633 |
+
loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \
|
634 |
+
+ self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2)
|
635 |
+
z_q = xin + (z_q - xin).detach()
|
636 |
+
z_q = z_q.squeeze(1)
|
637 |
+
return z_q, loss, min_indicies
|
638 |
+
|
639 |
+
def forward(self, xin, global_style):
|
640 |
+
#B, C, T
|
641 |
+
quantized_out = 0.0
|
642 |
+
residual = xin
|
643 |
+
all_losses = []
|
644 |
+
all_indices = []
|
645 |
+
for i in range(self.residul_layer):
|
646 |
+
quantized, loss, indices = self.for_one_step(residual, i) #
|
647 |
+
residual = residual - quantized
|
648 |
+
quantized_out = quantized_out + quantized
|
649 |
+
all_indices.extend(indices) #
|
650 |
+
all_losses.append(loss)
|
651 |
+
all_losses = torch.stack(all_losses)
|
652 |
+
loss = torch.mean(all_losses)
|
653 |
+
global_style_quantized, loss_gst_vq, global_style_tokens= self.for_one_step_gst(global_style)
|
654 |
+
loss += loss_gst_vq
|
655 |
+
# global_style_quantized = global_style
|
656 |
+
# global_style_tokens = global_style
|
657 |
+
# global_style_quantized = global_style_quantized.squeeze(1)
|
658 |
+
# global_style_tokens = global_style_tokens.squeeze(1)
|
659 |
+
return quantized_out, loss, all_indices, global_style_quantized, global_style_tokens
|
660 |
+
|
661 |
+
def embed(self, x):
|
662 |
+
#idx: N, T, 4
|
663 |
+
#print('x ', x.shape)
|
664 |
+
quantized_out = torch.tensor(0.0, device=x.device)
|
665 |
+
x = torch.split(x, 1, 2) # split, 将最后一个维度分开, 每个属于一个index group
|
666 |
+
#print('x.shape ', len(x),x[0].shape)
|
667 |
+
for i in range(self.residul_layer):
|
668 |
+
ret = []
|
669 |
+
if i == 0:
|
670 |
+
for j in range(self.n_code_groups):
|
671 |
+
q = x[j]
|
672 |
+
embed = self.quantizer_modules[j]
|
673 |
+
q = embed.embedding(q.squeeze(-1))
|
674 |
+
ret.append(q)
|
675 |
+
ret = torch.cat(ret, -1)
|
676 |
+
#print(ret.shape)
|
677 |
+
quantized_out = quantized_out + ret
|
678 |
+
elif i == 1:
|
679 |
+
for j in range(self.n_code_groups):
|
680 |
+
q = x[j + self.n_code_groups]
|
681 |
+
embed = self.quantizer_modules2[j]
|
682 |
+
q = embed.embedding(q.squeeze(-1))
|
683 |
+
ret.append(q)
|
684 |
+
ret = torch.cat(ret, -1)
|
685 |
+
quantized_out = quantized_out + ret
|
686 |
+
elif i == 2:
|
687 |
+
for j in range(self.n_code_groups):
|
688 |
+
q = x[j + self.n_code_groups * 2]
|
689 |
+
embed = self.quantizer_modules3[j]
|
690 |
+
q = embed.embedding(q.squeeze(-1))
|
691 |
+
ret.append(q)
|
692 |
+
ret = torch.cat(ret, -1)
|
693 |
+
quantized_out = quantized_out + ret
|
694 |
+
elif i == 3:
|
695 |
+
for j in range(self.n_code_groups):
|
696 |
+
q = x[j + self.n_code_groups * 3]
|
697 |
+
embed = self.quantizer_modules4[j]
|
698 |
+
q = embed.embedding(q.squeeze(-1))
|
699 |
+
ret.append(q)
|
700 |
+
ret = torch.cat(ret, -1)
|
701 |
+
quantized_out = quantized_out + ret
|
702 |
+
return quantized_out.transpose(1, 2) #N, C, T
|
703 |
+
def embed_gst(self, x):
|
704 |
+
quantized_out = torch.tensor(0.0, device=x.device)
|
705 |
+
ret = []
|
706 |
+
x = torch.split(x, 1, 2)
|
707 |
+
for j in range(self.global_code_num):
|
708 |
+
q = x[j]
|
709 |
+
embed = self.quantizer_modules_globaltokens[j]
|
710 |
+
# embed = self.quantizer_modules3[j]
|
711 |
+
q = embed.embedding(q.squeeze(-1))
|
712 |
+
ret.append(q)
|
713 |
+
ret = torch.cat(ret, -1)
|
714 |
+
quantized_out = quantized_out + ret
|
715 |
+
return quantized_out.transpose(1, 2)
|
716 |
+
# return x
|
vita/model/vita_tts/decoder/ticodec/vqvae.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from vita.model.vita_tts.decoder.ticodec.models import Encoder
|
7 |
+
from vita.model.vita_tts.decoder.ticodec.models import Generator
|
8 |
+
from vita.model.vita_tts.decoder.ticodec.models import Quantizer
|
9 |
+
|
10 |
+
class AttrDict(dict):
|
11 |
+
def __init__(self, *args, **kwargs):
|
12 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
13 |
+
self.__dict__ = self
|
14 |
+
|
15 |
+
class VQVAE(nn.Module):
|
16 |
+
def __init__(self,
|
17 |
+
config_path,
|
18 |
+
ckpt_path,
|
19 |
+
with_encoder=False):
|
20 |
+
super(VQVAE, self).__init__()
|
21 |
+
ckpt = torch.load(ckpt_path)
|
22 |
+
with open(config_path) as f:
|
23 |
+
data = f.read()
|
24 |
+
json_config = json.loads(data)
|
25 |
+
self.h = AttrDict(json_config)
|
26 |
+
# self.gst = GST()
|
27 |
+
# self.gst = Proposed(n_specs=128, token_num=10, E=128, n_layers=4)
|
28 |
+
self.quantizer = Quantizer(self.h)
|
29 |
+
self.generator = Generator(self.h)
|
30 |
+
self.generator.load_state_dict(ckpt['generator'])
|
31 |
+
self.quantizer.load_state_dict(ckpt['quantizer'])
|
32 |
+
# self.gst.load_state_dict(ckpt['gst'])
|
33 |
+
if with_encoder:
|
34 |
+
self.encoder = Encoder(self.h)
|
35 |
+
self.encoder.load_state_dict(ckpt['encoder'])
|
36 |
+
|
37 |
+
def forward(self, x, global_style_token):
|
38 |
+
# x is the codebook
|
39 |
+
# x.shape (B, T, Nq)
|
40 |
+
quant_emb = self.quantizer.embed(x)
|
41 |
+
global_style_quantized_emb = self.quantizer.embed_gst(global_style_token).squeeze(-1)
|
42 |
+
return self.generator(quant_emb, global_style_quantized_emb)
|
43 |
+
|
44 |
+
def encode(self, x):
|
45 |
+
batch_size = x.size(0)
|
46 |
+
if len(x.shape) == 3 and x.shape[-1] == 1:
|
47 |
+
x = x.squeeze(-1)
|
48 |
+
# print(x.shape)
|
49 |
+
|
50 |
+
c, global_features = self.encoder(x.unsqueeze(1))
|
51 |
+
# mid = mid.transpose(1, 2).unsqueeze(1)
|
52 |
+
# global_style = self.gst(mid)
|
53 |
+
q, loss_q, local_token, g, global_style_token = self.quantizer(c, global_features)
|
54 |
+
local_token = [code.reshape(batch_size, -1) for code in local_token]
|
55 |
+
global_style_token = torch.stack(global_style_token, -1).unsqueeze(1)
|
56 |
+
# shape: [N, T, 4]
|
57 |
+
return torch.stack(local_token, -1), global_style_token
|
vita/model/vita_tts/decoder/ticodec/vqvae_tester.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import librosa
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from vita.model.vita_tts.decoder.ticodec.vqvae import VQVAE
|
8 |
+
|
9 |
+
class VqvaeTester(nn.Module):
|
10 |
+
def __init__(self, config_path, model_path, sample_rate=24000):
|
11 |
+
super().__init__()
|
12 |
+
self.vqvae = VQVAE(config_path, model_path, with_encoder=True)
|
13 |
+
self.sample_rate = sample_rate
|
14 |
+
|
15 |
+
@torch.no_grad()
|
16 |
+
def forward(self, wav_path):
|
17 |
+
# 单声道
|
18 |
+
# wav.shape (T, ), 按照模型的 sr 读取
|
19 |
+
wav, sr = librosa.load(wav_path, sr=self.sample_rate)
|
20 |
+
fid = os.path.basename(wav_path)[:-4]
|
21 |
+
wav = torch.tensor(wav).unsqueeze(0)
|
22 |
+
wav = wav.cuda()
|
23 |
+
# vq_codes is acoustic token
|
24 |
+
vq_codes, global_token = self.vqvae.encode(wav)
|
25 |
+
import pdb; pdb.set_trace()
|
26 |
+
syn = self.vqvae(vq_codes, global_token)
|
27 |
+
return fid, syn
|
28 |
+
|
29 |
+
@torch.no_grad()
|
30 |
+
def vq(self, wav_path):
|
31 |
+
wav, sr = librosa.load(wav_path, sr=self.sample_rate)
|
32 |
+
fid = os.path.basename(wav_path)[:-4]
|
33 |
+
wav = torch.tensor(wav).unsqueeze(0)
|
34 |
+
wav = wav.cuda()
|
35 |
+
# vq_codes is acoustic token
|
36 |
+
vq_codes, global_token = self.vqvae.encode(wav)
|
37 |
+
return fid, vq_codes, global_token
|
vita/model/vita_tts/encoder/attention.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
import numpy
|
5 |
+
import pdb
|
6 |
+
|
7 |
+
class PositionalEncoding(torch.nn.Module):
|
8 |
+
"""Positional encoding.
|
9 |
+
:param int d_model: embedding dim
|
10 |
+
:param float dropout_rate: dropout rate
|
11 |
+
:param int max_len: maximum input length
|
12 |
+
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
|
13 |
+
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
|
14 |
+
"""
|
15 |
+
def __init__(self,
|
16 |
+
d_model: int,
|
17 |
+
dropout_rate: float,
|
18 |
+
max_len: int = 1500,
|
19 |
+
reverse: bool = False):
|
20 |
+
"""Construct an PositionalEncoding object."""
|
21 |
+
super().__init__()
|
22 |
+
self.d_model = d_model
|
23 |
+
self.xscale = math.sqrt(self.d_model)
|
24 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
25 |
+
self.max_len = max_len
|
26 |
+
|
27 |
+
self.pe = torch.zeros(self.max_len, self.d_model)
|
28 |
+
position = torch.arange(0, self.max_len,
|
29 |
+
dtype=torch.float32).unsqueeze(1)
|
30 |
+
div_term = torch.exp(
|
31 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32) *
|
32 |
+
-(math.log(10000.0) / self.d_model))
|
33 |
+
self.pe[:, 0::2] = torch.sin(position * div_term)
|
34 |
+
self.pe[:, 1::2] = torch.cos(position * div_term)
|
35 |
+
self.pe = self.pe.unsqueeze(0)
|
36 |
+
|
37 |
+
def forward(self,
|
38 |
+
x: torch.Tensor,
|
39 |
+
offset: int = 0):
|
40 |
+
"""Add positional encoding.
|
41 |
+
Args:
|
42 |
+
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
43 |
+
offset (int): position offset
|
44 |
+
Returns:
|
45 |
+
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
46 |
+
torch.Tensor: for compatibility to RelPositionalEncoding
|
47 |
+
"""
|
48 |
+
assert offset + x.size(1) < self.max_len
|
49 |
+
self.pe = self.pe.to(x.device)
|
50 |
+
pos_emb = self.pe[:, offset:offset + x.size(1)]
|
51 |
+
x = x * self.xscale + pos_emb
|
52 |
+
return self.dropout(x), self.dropout(pos_emb)
|
53 |
+
|
54 |
+
def position_encoding(self, offset: int, size: int):
|
55 |
+
""" For getting encoding in a streaming fashion
|
56 |
+
Attention!!!!!
|
57 |
+
we apply dropout only once at the whole utterance level in a none
|
58 |
+
streaming way, but will call this function several times with
|
59 |
+
increasing input size in a streaming scenario, so the dropout will
|
60 |
+
be applied several times.
|
61 |
+
Args:
|
62 |
+
offset (int): start offset
|
63 |
+
size (int): requried size of position encoding
|
64 |
+
Returns:
|
65 |
+
torch.Tensor: Corresponding encoding
|
66 |
+
"""
|
67 |
+
assert offset + size < self.max_len
|
68 |
+
return self.dropout(self.pe[:, offset:offset + size])
|
69 |
+
|
70 |
+
class RelPositionalEncoding(PositionalEncoding):
|
71 |
+
"""Relative positional encoding module.
|
72 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
73 |
+
Args:
|
74 |
+
d_model (int): Embedding dimension.
|
75 |
+
dropout_rate (float): Dropout rate.
|
76 |
+
max_len (int): Maximum input length.
|
77 |
+
"""
|
78 |
+
def __init__(self, d_model: int, dropout_rate: float, chunk_size: int, left_chunks: int, max_len: int = 5000):
|
79 |
+
"""Initialize class."""
|
80 |
+
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
81 |
+
self.chunk_size = chunk_size
|
82 |
+
self.left_chunks = left_chunks
|
83 |
+
self.full_chunk_size = (self.left_chunks + 1) * self.chunk_size
|
84 |
+
|
85 |
+
self.div_term = torch.exp(
|
86 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32) *
|
87 |
+
-(math.log(10000.0) / self.d_model))
|
88 |
+
self.max_len = self.chunk_size * (max_len // self.chunk_size) - self.full_chunk_size
|
89 |
+
|
90 |
+
def forward(self,
|
91 |
+
x: torch.Tensor,
|
92 |
+
offset: int = 0):
|
93 |
+
"""Compute positional encoding.
|
94 |
+
Args:
|
95 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
96 |
+
Returns:
|
97 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
98 |
+
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
99 |
+
"""
|
100 |
+
self.pe = self.pe.to(x.device)
|
101 |
+
x = x * self.xscale
|
102 |
+
pos_emb = self.pe[:, offset:offset + x.size(1)]
|
103 |
+
return self.dropout(x), self.dropout(pos_emb)
|
104 |
+
|
105 |
+
def infer(self, xs, pe_index, pe_length):
|
106 |
+
# type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
|
107 |
+
pe_index = pe_index % self.max_len
|
108 |
+
xs = xs * self.xscale
|
109 |
+
|
110 |
+
# pe = torch.zeros(self.full_chunk_size, self.d_model)
|
111 |
+
pe = torch.zeros(pe_length, self.d_model)
|
112 |
+
position = torch.arange(max(0, pe_index-self.full_chunk_size),
|
113 |
+
max(0, pe_index-self.full_chunk_size)
|
114 |
+
+ pe_length, # self.full_chunk_size,
|
115 |
+
dtype=torch.float32).unsqueeze(1)
|
116 |
+
pe[:, 0::2] = torch.sin(position * self.div_term)
|
117 |
+
pe[:, 1::2] = torch.cos(position * self.div_term)
|
118 |
+
pos_emb = pe.unsqueeze(0)
|
119 |
+
|
120 |
+
pe_index = pe_index + self.chunk_size
|
121 |
+
return xs, pos_emb, pe_index
|
122 |
+
|
123 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
124 |
+
"""Positionwise feed forward layer.
|
125 |
+
:param int idim: input dimenstion
|
126 |
+
:param int hidden_units: number of hidden units
|
127 |
+
:param float dropout_rate: dropout rate
|
128 |
+
"""
|
129 |
+
|
130 |
+
def __init__(self, idim, hidden_units, dropout_rate):
|
131 |
+
"""Construct an PositionwiseFeedForward object."""
|
132 |
+
super(PositionwiseFeedForward, self).__init__()
|
133 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
134 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
135 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
"""Forward funciton."""
|
139 |
+
return self.w_2(self.dropout(torch.relu(self.w_1(x))))
|
140 |
+
|
141 |
+
def infer(self, xs, buffer, buffer_index, buffer_out):
|
142 |
+
# type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
|
143 |
+
return self.w_2(torch.relu(self.w_1(xs))), buffer, buffer_index, buffer_out
|
144 |
+
|
145 |
+
class MultiLayeredConv1d(torch.nn.Module):
|
146 |
+
"""Multi-layered conv1d for Transformer block.
|
147 |
+
|
148 |
+
This is a module of multi-leyered conv1d designed
|
149 |
+
to replace positionwise feed-forward network
|
150 |
+
in Transformer block, which is introduced in
|
151 |
+
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
152 |
+
|
153 |
+
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
154 |
+
https://arxiv.org/pdf/1905.09263.pdf
|
155 |
+
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
159 |
+
"""Initialize MultiLayeredConv1d module.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
in_chans (int): Number of input channels.
|
163 |
+
hidden_chans (int): Number of hidden channels.
|
164 |
+
kernel_size (int): Kernel size of conv1d.
|
165 |
+
dropout_rate (float): Dropout rate.
|
166 |
+
|
167 |
+
"""
|
168 |
+
super(MultiLayeredConv1d, self).__init__()
|
169 |
+
self.w_1 = torch.nn.Conv1d(
|
170 |
+
in_chans,
|
171 |
+
hidden_chans,
|
172 |
+
kernel_size,
|
173 |
+
stride=1,
|
174 |
+
padding=(kernel_size - 1) // 2,
|
175 |
+
)
|
176 |
+
self.w_2 = torch.nn.Conv1d(
|
177 |
+
hidden_chans,
|
178 |
+
in_chans,
|
179 |
+
kernel_size,
|
180 |
+
stride=1,
|
181 |
+
padding=(kernel_size - 1) // 2,
|
182 |
+
)
|
183 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
184 |
+
|
185 |
+
def forward(self, x):
|
186 |
+
"""Calculate forward propagation.
|
187 |
+
|
188 |
+
Args:
|
189 |
+
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
Tensor: Batch of output tensors (B, ..., hidden_chans).
|
193 |
+
|
194 |
+
"""
|
195 |
+
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
196 |
+
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
|
197 |
+
|
198 |
+
class Conv1dLinear(torch.nn.Module):
|
199 |
+
"""Conv1D + Linear for Transformer block.
|
200 |
+
|
201 |
+
A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
|
202 |
+
|
203 |
+
"""
|
204 |
+
|
205 |
+
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
206 |
+
"""Initialize Conv1dLinear module.
|
207 |
+
|
208 |
+
Args:
|
209 |
+
in_chans (int): Number of input channels.
|
210 |
+
hidden_chans (int): Number of hidden channels.
|
211 |
+
kernel_size (int): Kernel size of conv1d.
|
212 |
+
dropout_rate (float): Dropout rate.
|
213 |
+
|
214 |
+
"""
|
215 |
+
super(Conv1dLinear, self).__init__()
|
216 |
+
self.lorder = (kernel_size - 1)
|
217 |
+
self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
|
218 |
+
self.w_1 = torch.nn.Sequential(
|
219 |
+
torch.nn.Conv1d(
|
220 |
+
in_chans,
|
221 |
+
in_chans,
|
222 |
+
kernel_size,
|
223 |
+
stride=1,
|
224 |
+
padding=0,
|
225 |
+
groups=in_chans
|
226 |
+
),
|
227 |
+
torch.nn.Conv1d(
|
228 |
+
in_chans,
|
229 |
+
hidden_chans,
|
230 |
+
1,
|
231 |
+
padding=0
|
232 |
+
)
|
233 |
+
)
|
234 |
+
self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
|
235 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
236 |
+
self.in_chans = in_chans
|
237 |
+
|
238 |
+
# cnn_buffer = 1, in_chans, self.lorder
|
239 |
+
self.buffer_size = 1 * self.in_chans * self.lorder
|
240 |
+
|
241 |
+
def forward(self, x):
|
242 |
+
"""Calculate forward propagation.
|
243 |
+
|
244 |
+
Args:
|
245 |
+
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
Tensor: Batch of output tensors (B, ..., hidden_chans).
|
249 |
+
|
250 |
+
"""
|
251 |
+
x = torch.relu(self.w_1(self.left_padding(x.transpose(-1, 1)))).transpose(-1, 1)
|
252 |
+
return self.w_2(self.dropout(x))
|
253 |
+
|
254 |
+
def infer(self, x, buffer, buffer_index, buffer_out):
|
255 |
+
# type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
|
256 |
+
x = x.transpose(-1, 1)
|
257 |
+
|
258 |
+
cnn_buffer = buffer[buffer_index: buffer_index + self.buffer_size].reshape([1, self.in_chans, self.lorder])
|
259 |
+
x = torch.cat([cnn_buffer, x], dim=2)
|
260 |
+
buffer_out.append(x[:, :, -self.lorder:].reshape(-1))
|
261 |
+
buffer_index = buffer_index + self.buffer_size
|
262 |
+
|
263 |
+
x = self.w_1(x)
|
264 |
+
x = torch.relu(x).transpose(-1, 1)
|
265 |
+
x = self.w_2(x)
|
266 |
+
return x, buffer, buffer_index, buffer_out
|
267 |
+
|
268 |
+
class MultiHeadedAttention(nn.Module):
|
269 |
+
"""Multi-Head Attention layer.
|
270 |
+
|
271 |
+
:param int n_head: the number of head s
|
272 |
+
:param int n_feat: the number of features
|
273 |
+
:param float dropout_rate: dropout rate
|
274 |
+
|
275 |
+
"""
|
276 |
+
def __init__(self, n_head, n_feat, dropout_rate, chunk_size, left_chunks, pos_enc_class):
|
277 |
+
"""Construct an MultiHeadedAttention object."""
|
278 |
+
super(MultiHeadedAttention, self).__init__()
|
279 |
+
assert n_feat % n_head == 0
|
280 |
+
# We assume d_v always equals d_k
|
281 |
+
self.d_k = n_feat // n_head
|
282 |
+
self.h = n_head
|
283 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
284 |
+
self.linear_k = nn.Linear(n_feat, n_feat)
|
285 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
286 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
287 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
288 |
+
self.min_value = float(numpy.finfo(torch.tensor(0, dtype=torch.float16).numpy().dtype).min)
|
289 |
+
# chunk par
|
290 |
+
if chunk_size > 0 and left_chunks > 0: #for streaming mode
|
291 |
+
self.buffersize = chunk_size * (left_chunks)
|
292 |
+
self.left_chunk_size = chunk_size * left_chunks
|
293 |
+
else: # for non-streaming mode
|
294 |
+
self.buffersize = 1
|
295 |
+
self.left_chunk_size = 1
|
296 |
+
self.chunk_size = chunk_size
|
297 |
+
|
298 |
+
# encoding setup
|
299 |
+
if pos_enc_class == "rel-enc":
|
300 |
+
self.rel_enc = True
|
301 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
302 |
+
# these two learnable bias are used in matrix c and matrix d
|
303 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
304 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
305 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
306 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
307 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
308 |
+
else:
|
309 |
+
self.rel_enc = False
|
310 |
+
self.linear_pos = nn.Identity()
|
311 |
+
self.pos_bias_u = torch.tensor([0])
|
312 |
+
self.pos_bias_v = torch.tensor([0])
|
313 |
+
|
314 |
+
# buffer
|
315 |
+
# key_buffer = 1, self.h, self.buffersize, self.d_k
|
316 |
+
self.key_buffer_size = 1 * self.h * self.buffersize * self.d_k
|
317 |
+
# value_buffer = 1, self.h, self.buffersize, self.d_k
|
318 |
+
self.value_buffer_size = 1 * self.h * self.buffersize * self.d_k
|
319 |
+
if self.chunk_size > 0:
|
320 |
+
# buffer_mask_size = 1, self.h, self.chunk_size, self.buffersize
|
321 |
+
self.buffer_mask_size = 1 * self.h * self.chunk_size * self.buffersize
|
322 |
+
else:
|
323 |
+
self.buffer_mask = torch.ones([1, self.h, 1, 1], dtype=torch.bool)
|
324 |
+
|
325 |
+
def rel_shift(self, x, zero_triu: bool = False):
|
326 |
+
"""Compute relative positinal encoding.
|
327 |
+
Args:
|
328 |
+
x (torch.Tensor): Input tensor (batch, time, size).
|
329 |
+
zero_triu (bool): If true, return the lower triangular part of
|
330 |
+
the matrix.
|
331 |
+
Returns:
|
332 |
+
torch.Tensor: Output tensor.
|
333 |
+
"""
|
334 |
+
|
335 |
+
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
336 |
+
device=x.device,
|
337 |
+
dtype=x.dtype)
|
338 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
339 |
+
|
340 |
+
x_padded = x_padded.view(x.size()[0],
|
341 |
+
x.size()[1],
|
342 |
+
x.size(3) + 1, x.size(2))
|
343 |
+
x = x_padded[:, :, 1:].view_as(x)
|
344 |
+
|
345 |
+
if zero_triu:
|
346 |
+
ones = torch.ones((x.size(2), x.size(3)))
|
347 |
+
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
348 |
+
return x
|
349 |
+
|
350 |
+
def forward(self, query, key, value, mask=None, pos_emb=torch.tensor(1.0)):
|
351 |
+
# type: (Tensor, Tensor, Tensor, Optional[Tensor], Tensor) -> Tensor
|
352 |
+
"""Compute 'Scaled Dot Product Attention'.
|
353 |
+
|
354 |
+
:param torch.Tensor query: (batch, time1, size)
|
355 |
+
:param torch.Tensor key: (batch, time2, size)
|
356 |
+
:param torch.Tensor value: (batch, time2, size)
|
357 |
+
:param torch.Tensor mask: (batch, time1, time2)
|
358 |
+
:param torch.nn.Dropout dropout:
|
359 |
+
:return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
|
360 |
+
weighted by the query dot key attention (batch, head, time1, time2)
|
361 |
+
"""
|
362 |
+
n_batch = query.size(0)
|
363 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
364 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
365 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
366 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
367 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
368 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
369 |
+
|
370 |
+
if self.rel_enc:
|
371 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
372 |
+
n_batch_pos = pos_emb.size(0)
|
373 |
+
p = self.linear_pos(pos_emb.to(query.dtype)).view(n_batch_pos, -1, self.h, self.d_k)
|
374 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
375 |
+
# (batch, head, time1, d_k)
|
376 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
377 |
+
# (batch, head, time1, d_k)
|
378 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
379 |
+
# compute attention score
|
380 |
+
# first compute matrix a and matrix c
|
381 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
382 |
+
# (batch, head, time1, time2)
|
383 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
384 |
+
# compute matrix b and matrix d
|
385 |
+
# (batch, head, time1, time2)
|
386 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
387 |
+
# Remove rel_shift since it is useless in speech recognition,
|
388 |
+
# and it requires special attention for streaming.
|
389 |
+
# matrix_bd = self.rel_shift(matrix_bd)
|
390 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
|
391 |
+
else:
|
392 |
+
scores = torch.matmul(q, k.transpose(-2, -1) ) / math.sqrt(self.d_k) # (batch, head, time1, time2)
|
393 |
+
|
394 |
+
if mask is not None:
|
395 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
|
396 |
+
scores = scores.masked_fill(mask, self.min_value)
|
397 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
|
398 |
+
else:
|
399 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
400 |
+
|
401 |
+
p_attn = self.dropout(attn)
|
402 |
+
|
403 |
+
x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
|
404 |
+
x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
|
405 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
406 |
+
|
407 |
+
def infer(self, query, key, value, pos_emb, buffer, buffer_index, buffer_out):
|
408 |
+
# type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
|
409 |
+
n_batch = query.size(0)
|
410 |
+
|
411 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2) # (batch, head, len_q, d_k)
|
412 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2) # (batch, head, len_k, d_k)
|
413 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2) # (batch, head, len_v, d_k)
|
414 |
+
|
415 |
+
key_value_buffer = buffer[buffer_index]
|
416 |
+
if buffer[buffer_index] is None:
|
417 |
+
buffer[buffer_index] = [None, None]
|
418 |
+
key_buffer = k
|
419 |
+
value_buffer = v
|
420 |
+
else:
|
421 |
+
key_buffer = torch.cat([key_value_buffer[0], k], dim=2)
|
422 |
+
value_buffer = torch.cat([key_value_buffer[1], v], dim=2)
|
423 |
+
if key_buffer.size(2) > self.buffersize:
|
424 |
+
buffer[buffer_index][0] = key_buffer[:, :, -self.buffersize:, :]
|
425 |
+
buffer[buffer_index][1] = value_buffer[:, :, -self.buffersize:, :]
|
426 |
+
else:
|
427 |
+
buffer[buffer_index] = [key_buffer, value_buffer]
|
428 |
+
buffer_index += 1
|
429 |
+
|
430 |
+
if self.rel_enc:
|
431 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
432 |
+
n_batch_pos = pos_emb.size(0)
|
433 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
434 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
435 |
+
# (batch, head, time1, d_k)
|
436 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
437 |
+
# (batch, head, time1, d_k)
|
438 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
439 |
+
# compute attention score
|
440 |
+
# first compute matrix a and matrix c
|
441 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
442 |
+
# (batch, head, time1, time2)
|
443 |
+
matrix_ac = torch.matmul(q_with_bias_u, key_buffer.transpose(-2, -1))
|
444 |
+
# compute matrix b and matrix d
|
445 |
+
# (batch, head, time1, time2)
|
446 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
447 |
+
# Remove rel_shift since it is useless in speech recognition,
|
448 |
+
# and it requires special attention for streaming.
|
449 |
+
# matrix_bd = self.rel_shift(matrix_bd)
|
450 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
|
451 |
+
else:
|
452 |
+
# (batch, head, len_q, buffersize)
|
453 |
+
scores = torch.matmul(q, key_buffer.transpose(-2, -1) ) / math.sqrt(self.d_k)
|
454 |
+
|
455 |
+
attn = torch.softmax(scores, dim=-1)
|
456 |
+
|
457 |
+
x = torch.matmul(attn, value_buffer) # (batch, head, len_q, d_k)
|
458 |
+
x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
|
459 |
+
return self.linear_out(x), buffer, buffer_index, buffer_out # (batch, time1, d_model)
|
vita/model/vita_tts/encoder/cmvn.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import json
|
3 |
+
import math
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
class GlobalCMVN(torch.nn.Module):
|
8 |
+
def __init__(self,
|
9 |
+
mean: torch.Tensor,
|
10 |
+
istd: torch.Tensor,
|
11 |
+
norm_var: bool = True):
|
12 |
+
"""
|
13 |
+
Args:
|
14 |
+
mean (torch.Tensor): mean stats
|
15 |
+
istd (torch.Tensor): inverse std, std which is 1.0 / std
|
16 |
+
"""
|
17 |
+
super().__init__()
|
18 |
+
assert mean.shape == istd.shape
|
19 |
+
self.norm_var = norm_var
|
20 |
+
# The buffer can be accessed from this module using self.mean
|
21 |
+
self.register_buffer("mean", mean)
|
22 |
+
self.register_buffer("istd", istd)
|
23 |
+
|
24 |
+
def forward(self, x: torch.Tensor):
|
25 |
+
"""
|
26 |
+
Args:
|
27 |
+
x (torch.Tensor): (batch, max_len, feat_dim)
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
(torch.Tensor): normalized feature
|
31 |
+
"""
|
32 |
+
x = x - self.mean
|
33 |
+
if self.norm_var:
|
34 |
+
x = x * self.istd
|
35 |
+
return x
|
36 |
+
|
37 |
+
def _load_json_cmvn(json_cmvn_file):
|
38 |
+
""" Load the json format cmvn stats file and calculate cmvn
|
39 |
+
|
40 |
+
Args:
|
41 |
+
json_cmvn_file: cmvn stats file in json format
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
a numpy array of [means, vars]
|
45 |
+
"""
|
46 |
+
with open(json_cmvn_file) as f:
|
47 |
+
cmvn_stats = json.load(f)
|
48 |
+
|
49 |
+
means = cmvn_stats['mean_stat']
|
50 |
+
variance = cmvn_stats['var_stat']
|
51 |
+
count = cmvn_stats['frame_num']
|
52 |
+
for i in range(len(means)):
|
53 |
+
means[i] /= count
|
54 |
+
variance[i] = variance[i] / count - means[i] * means[i]
|
55 |
+
if variance[i] < 1.0e-20:
|
56 |
+
variance[i] = 1.0e-20
|
57 |
+
variance[i] = 1.0 / math.sqrt(variance[i])
|
58 |
+
cmvn = np.array([means, variance])
|
59 |
+
return cmvn
|
60 |
+
|
61 |
+
def _load_kaldi_cmvn(kaldi_cmvn_file):
|
62 |
+
""" Load the kaldi format cmvn stats file and calculate cmvn
|
63 |
+
|
64 |
+
Args:
|
65 |
+
kaldi_cmvn_file: kaldi text style global cmvn file, which
|
66 |
+
is generated by:
|
67 |
+
compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
a numpy array of [means, vars]
|
71 |
+
"""
|
72 |
+
means = []
|
73 |
+
variance = []
|
74 |
+
with open(kaldi_cmvn_file, 'r') as fid:
|
75 |
+
# kaldi binary file start with '\0B'
|
76 |
+
if fid.read(2) == '\0B':
|
77 |
+
print('kaldi cmvn binary file is not supported, please '
|
78 |
+
'recompute it by: compute-cmvn-stats --binary=false '
|
79 |
+
' scp:feats.scp global_cmvn')
|
80 |
+
sys.exit(1)
|
81 |
+
fid.seek(0)
|
82 |
+
arr = fid.read().split()
|
83 |
+
assert (arr[0] == '[')
|
84 |
+
assert (arr[-2] == '0')
|
85 |
+
assert (arr[-1] == ']')
|
86 |
+
feat_dim = int((len(arr) - 2 - 2) / 2)
|
87 |
+
for i in range(1, feat_dim + 1):
|
88 |
+
means.append(float(arr[i]))
|
89 |
+
count = float(arr[feat_dim + 1])
|
90 |
+
for i in range(feat_dim + 2, 2 * feat_dim + 2):
|
91 |
+
variance.append(float(arr[i]))
|
92 |
+
|
93 |
+
for i in range(len(means)):
|
94 |
+
means[i] /= count
|
95 |
+
variance[i] = variance[i] / count - means[i] * means[i]
|
96 |
+
if variance[i] < 1.0e-20:
|
97 |
+
variance[i] = 1.0e-20
|
98 |
+
variance[i] = 1.0 / math.sqrt(variance[i])
|
99 |
+
cmvn = np.array([means, variance])
|
100 |
+
return cmvn
|
101 |
+
|
102 |
+
def load_cmvn(cmvn_file, is_json):
|
103 |
+
if is_json:
|
104 |
+
cmvn = _load_json_cmvn(cmvn_file)
|
105 |
+
else:
|
106 |
+
cmvn = _load_kaldi_cmvn(cmvn_file)
|
107 |
+
return cmvn[0], cmvn[1]
|
vita/model/vita_tts/encoder/encoder.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
from typing import Tuple, Dict, Optional
|
8 |
+
|
9 |
+
from vita.model.vita_tts.encoder.transformer import Transformer
|
10 |
+
from vita.model.vita_tts.encoder.subsampling import Subsampling
|
11 |
+
|
12 |
+
def add_encoder_args(group):
|
13 |
+
"""Add Encoder common arguments."""
|
14 |
+
group.add_argument(
|
15 |
+
"--encoder-layer-config",
|
16 |
+
type=str,
|
17 |
+
default="tdnn-dtc",
|
18 |
+
help="Layer config of encoder. Format layername-layername-...",
|
19 |
+
)
|
20 |
+
group.add_argument(
|
21 |
+
"--encoder-input-dim",
|
22 |
+
type=int,
|
23 |
+
default=256,
|
24 |
+
help="Input dim of encoder. Must equal to the input dim of the first Component (default=40)"
|
25 |
+
)
|
26 |
+
group.add_argument(
|
27 |
+
"--encoder-output-dim",
|
28 |
+
type=int,
|
29 |
+
default=256,
|
30 |
+
help="Output dim of encoder. Must enqual to the output dim of the last Component ! (default=256)"
|
31 |
+
)
|
32 |
+
group = Transformer.add_arguments(group)
|
33 |
+
group = Subsampling.add_arguments(group)
|
34 |
+
return group
|
35 |
+
|
36 |
+
def assign_args_from_dict(args, dict, prefix_key=None):
|
37 |
+
if prefix_key is not None:
|
38 |
+
dict = dict[prefix_key]
|
39 |
+
for k, v in dict.items():
|
40 |
+
k_args = k.replace('-', '_')
|
41 |
+
if hasattr(args, k_args):
|
42 |
+
setattr(args, k_args, dict[k])
|
43 |
+
return args
|
44 |
+
|
45 |
+
class speechEncoder(torch.nn.Module):
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
input_dim,
|
49 |
+
overview_conf = None,
|
50 |
+
para_conf = None,
|
51 |
+
global_cmvn = None):
|
52 |
+
super(speechEncoder, self).__init__()
|
53 |
+
|
54 |
+
parser = argparse.ArgumentParser()
|
55 |
+
add_encoder_args(parser)
|
56 |
+
args, _ = parser.parse_known_args()
|
57 |
+
assign_args_from_dict(args, overview_conf)
|
58 |
+
|
59 |
+
self.config = args.encoder_layer_config.split('-')
|
60 |
+
encoder_input_dim = args.encoder_input_dim
|
61 |
+
encoder_output_dim = args.encoder_output_dim
|
62 |
+
prev_output_dim = encoder_input_dim
|
63 |
+
prev_component_name = "encoder"
|
64 |
+
|
65 |
+
self.global_cmvn = global_cmvn
|
66 |
+
self.enc = torch.nn.ModuleList([])
|
67 |
+
for name in self.config:
|
68 |
+
assign_args_from_dict(args, para_conf[name])
|
69 |
+
if len(name.split('_')) == 2:
|
70 |
+
name = name.split('_')[0]
|
71 |
+
elif len(name.split('_')) == 1:
|
72 |
+
name = name
|
73 |
+
else:
|
74 |
+
print("WRONG CONFIG! {} is not valid".format("encoder", name))
|
75 |
+
sys.exit()
|
76 |
+
if name == "transformer":
|
77 |
+
self.enc.append(Transformer(args))
|
78 |
+
elif name == "subsampling":
|
79 |
+
self.enc.append(Subsampling(args))
|
80 |
+
else:
|
81 |
+
print("{} is not supported now!".format(name))
|
82 |
+
return NotImplemented
|
83 |
+
component_input_dim = getattr(args, name + "_input_dim")
|
84 |
+
if component_input_dim != prev_output_dim:
|
85 |
+
print("WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-input-dim ({})"
|
86 |
+
.format(prev_component_name, prev_output_dim, name, component_input_dim))
|
87 |
+
sys.exit()
|
88 |
+
prev_output_dim = getattr(args, name + "_output_dim")
|
89 |
+
prev_component_name = name
|
90 |
+
|
91 |
+
if (prev_output_dim != encoder_output_dim):
|
92 |
+
print("WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-output-dim ({}, the last component)"
|
93 |
+
.format("encoder", encoder_output_dim, name, prev_output_dim))
|
94 |
+
sys.exit()
|
95 |
+
|
96 |
+
self._output_size = encoder_output_dim
|
97 |
+
|
98 |
+
num_params = sum(p.numel() for p in self.parameters())
|
99 |
+
print('the number of speech encoder params: {}M'.format(num_params/1024/1024))
|
100 |
+
|
101 |
+
def output_size(self) -> int:
|
102 |
+
return self._output_size
|
103 |
+
|
104 |
+
def forward(self, xs, ilens, decoding_chunk_size=None, num_decoding_left_chunks=None):
|
105 |
+
"""
|
106 |
+
Forward pass through the encoder.
|
107 |
+
|
108 |
+
Parameters:
|
109 |
+
- xs: torch.Tensor, shape (batch_size, sequence_length, input_dim)
|
110 |
+
The input tensor containing the sequence of input vectors.
|
111 |
+
- batch_size: The number of sequences in the batch.
|
112 |
+
- sequence_length: The length of each sequence.
|
113 |
+
- input_dim: The dimensionality of each input vector.
|
114 |
+
|
115 |
+
- ilens: torch.Tensor, shape (batch_size,)
|
116 |
+
The lengths of each sequence in the batch, used for padding masks.
|
117 |
+
|
118 |
+
- decoding_chunk_size: int, optional (default=None)
|
119 |
+
The size of chunks to use for decoding
|
120 |
+
|
121 |
+
- num_decoding_left_chunks: int, optional (default=None)
|
122 |
+
The number of left chunks to use for decoding
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
- xs: torch.Tensor, shape (batch_size, sequence_length, encoded_dim)
|
126 |
+
The encoded output tensor, where encoded_dim is the dimensionality of the encoded representation.
|
127 |
+
|
128 |
+
- masks: torch.Tensor, shape (batch_size, 1, sequence_length)
|
129 |
+
The padding mask tensor, where True indicates valid elements and False indicates padded elements.
|
130 |
+
"""
|
131 |
+
if decoding_chunk_size is not None and num_decoding_left_chunks is not None:
|
132 |
+
for layer in self.enc:
|
133 |
+
if hasattr(layer, "chunk_size"):
|
134 |
+
layer.chunk_size = decoding_chunk_size
|
135 |
+
if hasattr(layer, "left_chunks"):
|
136 |
+
layer.left_chunks = num_decoding_left_chunks
|
137 |
+
if hasattr(layer, "transformer_dynamic_chunks"):
|
138 |
+
layer.transformer_dynamic_chunks = False
|
139 |
+
|
140 |
+
assert(len(xs.shape)) == 3
|
141 |
+
T = xs.size(1)
|
142 |
+
masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T)
|
143 |
+
if self.global_cmvn is not None:
|
144 |
+
xs = self.global_cmvn(xs)
|
145 |
+
for module in self.enc:
|
146 |
+
xs, ilens, masks = module(xs, ilens, masks)
|
147 |
+
return xs, masks
|
148 |
+
|
149 |
+
def infer(self, xs_pad, buffer, buffer_index, buffer_out, pe_index):
|
150 |
+
if self.global_cmvn is not None:
|
151 |
+
xs_pad = self.global_cmvn(xs_pad)
|
152 |
+
for module in self.enc:
|
153 |
+
xs_pad, buffer, buffer_index, buffer_out, pe_index = module.infer(xs_pad,
|
154 |
+
buffer, buffer_index, buffer_out, pe_index)
|
155 |
+
return xs_pad, buffer, buffer_index, buffer_out, pe_index
|
vita/model/vita_tts/encoder/subsampling.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
class BaseSubsampling(torch.nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
super().__init__()
|
8 |
+
self.right_context = 0
|
9 |
+
self.subsampling_rate = 1
|
10 |
+
|
11 |
+
def position_encoding(self, offset: Union[int, torch.Tensor],
|
12 |
+
size: int) -> torch.Tensor:
|
13 |
+
return self.pos_enc.position_encoding(offset, size)
|
14 |
+
|
15 |
+
class Conv2dSubsampling4(BaseSubsampling):
|
16 |
+
"""Convolutional 2D subsampling (to 1/4 length).
|
17 |
+
|
18 |
+
Args:
|
19 |
+
idim (int): Input dimension.
|
20 |
+
odim (int): Output dimension.
|
21 |
+
dropout_rate (float): Dropout rate.
|
22 |
+
|
23 |
+
"""
|
24 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float):
|
25 |
+
"""Construct an Conv2dSubsampling4 object."""
|
26 |
+
super().__init__()
|
27 |
+
self.conv = torch.nn.Sequential(
|
28 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
29 |
+
torch.nn.ReLU(),
|
30 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
31 |
+
torch.nn.ReLU(),
|
32 |
+
)
|
33 |
+
self.out = torch.nn.Sequential(
|
34 |
+
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
|
35 |
+
# The right context for every conv layer is computed by:
|
36 |
+
# (kernel_size - 1) * frame_rate_of_this_layer
|
37 |
+
self.subsampling_rate = 4
|
38 |
+
# 6 = (3 - 1) * 1 + (3 - 1) * 2
|
39 |
+
self.right_context = 6
|
40 |
+
|
41 |
+
def forward(
|
42 |
+
self,
|
43 |
+
x: torch.Tensor,
|
44 |
+
x_mask: torch.Tensor
|
45 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
46 |
+
"""Subsample x.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
50 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
54 |
+
where time' = time // 4.
|
55 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
56 |
+
where time' = time // 4.
|
57 |
+
torch.Tensor: positional encoding
|
58 |
+
|
59 |
+
"""
|
60 |
+
x = x.unsqueeze(1) # (b, c=1, t, f)
|
61 |
+
x = self.conv(x)
|
62 |
+
b, c, t, f = x.size()
|
63 |
+
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
64 |
+
|
65 |
+
return x, x_mask[:, :, 2::2][:, :, 2::2]
|
66 |
+
|
67 |
+
def infer(self, x, buffer, buffer_index, buffer_out):
|
68 |
+
x = x.unsqueeze(1) # (b, c=1, t, f)
|
69 |
+
x = self.conv(x)
|
70 |
+
b, c, t, f = x.size()
|
71 |
+
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
72 |
+
|
73 |
+
return x, buffer, buffer_index, buffer_out
|
74 |
+
|
75 |
+
class Subsampling(torch.nn.Module):
|
76 |
+
@staticmethod
|
77 |
+
def add_arguments(group):
|
78 |
+
"""Add Subsampling common arguments."""
|
79 |
+
group.add_argument('--subsampling-rate', default=4, type=int)
|
80 |
+
group.add_argument('--subsampling-input-dim', default=256, type=int)
|
81 |
+
group.add_argument('--subsampling-output-dim', default=256, type=int)
|
82 |
+
group.add_argument('--subsampling-dropout-rate', default=0.1, type=float)
|
83 |
+
|
84 |
+
return group
|
85 |
+
|
86 |
+
def __init__(self, args):
|
87 |
+
super().__init__()
|
88 |
+
self.subsampling_rate = args.subsampling_rate
|
89 |
+
self.subsampling_input_dim = args.subsampling_input_dim
|
90 |
+
self.subsampling_output_dim = args.subsampling_output_dim
|
91 |
+
self.subsampling_dropout_rate = args.subsampling_dropout_rate
|
92 |
+
|
93 |
+
if self.subsampling_rate == 4:
|
94 |
+
self.core = Conv2dSubsampling4(self.subsampling_input_dim,
|
95 |
+
self.subsampling_output_dim,
|
96 |
+
self.subsampling_dropout_rate)
|
97 |
+
|
98 |
+
def forward(self, xs, ilens, masks):
|
99 |
+
xs, masks = self.core(xs, masks)
|
100 |
+
ilens = masks.squeeze(1).sum(1)
|
101 |
+
return xs, ilens, masks
|
102 |
+
|
103 |
+
def infer(self, x, buffer, buffer_index, buffer_out, pe_index):
|
104 |
+
x, buffer, buffer_index, buffer_out = self.core.infer(x,
|
105 |
+
buffer, buffer_index, buffer_out)
|
106 |
+
return x, buffer, buffer_index, buffer_out, pe_index
|
vita/model/vita_tts/encoder/transformer.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
import pdb
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from vita.model.vita_tts.encoder.attention import *
|
10 |
+
from vita.model.vita_tts.masks import *
|
11 |
+
|
12 |
+
IGNORE_ID = -1
|
13 |
+
|
14 |
+
def strtobool(x):
|
15 |
+
return bool(dist_strtobool(x))
|
16 |
+
|
17 |
+
def repeat(N, fn):
|
18 |
+
"""Repeat module N times.
|
19 |
+
|
20 |
+
:param int N: repeat time
|
21 |
+
:param function fn: function to generate module
|
22 |
+
:return: repeated modules
|
23 |
+
:rtype: MultiSequential
|
24 |
+
"""
|
25 |
+
return MultiSequential(*[fn(n) for n in range(N)])
|
26 |
+
|
27 |
+
class MultiSequential(torch.nn.Sequential):
|
28 |
+
"""Multi-input multi-output torch.nn.Sequential."""
|
29 |
+
def forward(self, x, masks, pos_emb):
|
30 |
+
|
31 |
+
"""Repeat."""
|
32 |
+
for m in self:
|
33 |
+
x, masks, pos_emb = m(x, masks, pos_emb)
|
34 |
+
return x, masks, pos_emb
|
35 |
+
|
36 |
+
@torch.jit.export
|
37 |
+
def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
|
38 |
+
# type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
|
39 |
+
"""Repeat."""
|
40 |
+
for m in self:
|
41 |
+
x, pos_emb, buffer, buffer_index, buffer_out = m.infer(x, pos_emb, buffer, buffer_index, buffer_out)
|
42 |
+
return x, pos_emb, buffer, buffer_index, buffer_out
|
43 |
+
|
44 |
+
class TransformerLayer(nn.Module):
|
45 |
+
"""Transformer layer module.
|
46 |
+
|
47 |
+
:param int size: input dim
|
48 |
+
:param self_attn: self attention module
|
49 |
+
:param feed_forward: feed forward module
|
50 |
+
:param float dropout_rate: dropout rate
|
51 |
+
:param bool normalize_before: whether to use layer_norm before the first block
|
52 |
+
:param bool concat_after: whether to concat attention layer's input and output
|
53 |
+
if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
|
54 |
+
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
55 |
+
|
56 |
+
"""
|
57 |
+
def __init__(self, size, self_attn, feed_forward, dropout_rate,
|
58 |
+
normalize_before=True, concat_after=False):
|
59 |
+
"""Construct an TransformerLayer object."""
|
60 |
+
super(TransformerLayer, self).__init__()
|
61 |
+
self.self_attn = self_attn
|
62 |
+
self.feed_forward = feed_forward
|
63 |
+
self.norm1 = torch.nn.LayerNorm(size)
|
64 |
+
self.norm2 = torch.nn.LayerNorm(size)
|
65 |
+
self.dropout = nn.Dropout(dropout_rate)
|
66 |
+
self.size = size
|
67 |
+
self.normalize_before = normalize_before
|
68 |
+
self.concat_after = concat_after
|
69 |
+
if self.concat_after:
|
70 |
+
self.concat_linear = nn.Linear(size + size, size)
|
71 |
+
else:
|
72 |
+
self.concat_linear = nn.Identity()
|
73 |
+
|
74 |
+
@torch.jit.unused
|
75 |
+
def forward(self, x, mask, pos_emb):
|
76 |
+
"""Compute encoded features.
|
77 |
+
|
78 |
+
:param torch.Tensor x: encoded source features (batch, max_time_in, size)
|
79 |
+
:param torch.Tensor mask: mask for x (batch, max_time_in)
|
80 |
+
:rtype: Tuple[torch.Tensor, torch.Tensor]
|
81 |
+
"""
|
82 |
+
residual = x
|
83 |
+
if self.normalize_before:
|
84 |
+
x = self.norm1(x)
|
85 |
+
if self.concat_after:
|
86 |
+
x_concat = torch.cat((x, self.self_attn(x, x, x, mask, pos_emb)), dim=-1)
|
87 |
+
x = residual + self.concat_linear(x_concat)
|
88 |
+
else:
|
89 |
+
x = residual + self.dropout(self.self_attn(x, x, x, mask, pos_emb))
|
90 |
+
if not self.normalize_before:
|
91 |
+
x = self.norm1(x)
|
92 |
+
|
93 |
+
residual = x
|
94 |
+
if self.normalize_before:
|
95 |
+
x = self.norm2(x)
|
96 |
+
x = residual + self.dropout(self.feed_forward(x))
|
97 |
+
if not self.normalize_before:
|
98 |
+
x = self.norm2(x)
|
99 |
+
|
100 |
+
return x, mask, pos_emb
|
101 |
+
|
102 |
+
@torch.jit.export
|
103 |
+
def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
|
104 |
+
# type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
|
105 |
+
residual = x.clone()
|
106 |
+
if self.normalize_before:
|
107 |
+
x = self.norm1(x)
|
108 |
+
if self.concat_after:
|
109 |
+
x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(x, x, x,
|
110 |
+
pos_emb, buffer,
|
111 |
+
buffer_index, buffer_out)
|
112 |
+
x_concat = torch.cat((x, x_att), dim=-1)
|
113 |
+
x = residual + self.concat_linear(x_concat)
|
114 |
+
else:
|
115 |
+
x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(x, x, x,
|
116 |
+
pos_emb, buffer,
|
117 |
+
buffer_index, buffer_out)
|
118 |
+
x = residual + x_att
|
119 |
+
if not self.normalize_before:
|
120 |
+
x = self.norm1(x)
|
121 |
+
|
122 |
+
residual = x.clone()
|
123 |
+
if self.normalize_before:
|
124 |
+
x = self.norm2(x)
|
125 |
+
x_feed, buffer, buffer_index, buffer_out = self.feed_forward.infer(x, buffer, buffer_index, buffer_out)
|
126 |
+
x = residual + x_feed
|
127 |
+
if not self.normalize_before:
|
128 |
+
x = self.norm2(x)
|
129 |
+
|
130 |
+
return x, pos_emb, buffer, buffer_index, buffer_out
|
131 |
+
|
132 |
+
class Transformer(torch.nn.Module):
|
133 |
+
@staticmethod
|
134 |
+
def add_arguments(group):
|
135 |
+
"""Add TDNN common arguments."""
|
136 |
+
group.add_argument('--transformer-input-dim', default=256, type=int)
|
137 |
+
group.add_argument('--transformer-output-dim', default=4, type=int)
|
138 |
+
group.add_argument('--transformer-attention-dim', default=256, type=int)
|
139 |
+
group.add_argument('--transformer-attention-heads', default=4, type=int)
|
140 |
+
group.add_argument('--transformer-linear-units', default=1024, type=int)
|
141 |
+
group.add_argument('--transformer-num-blocks', default=6, type=int)
|
142 |
+
group.add_argument('--transformer-dropout-rate', default=0.1, type=float)
|
143 |
+
group.add_argument('--transformer-attention-dropout-rate', default=0.0, type=float)
|
144 |
+
group.add_argument('--transformer-positional-dropout-rate', default=0.1, type=float)
|
145 |
+
group.add_argument('--transformer-input-layer', default='linear', type=str)
|
146 |
+
group.add_argument('--transformer-pos-enc-class', default='abs-enc', type=str)
|
147 |
+
group.add_argument('--transformer-normalize-before', default=True, type=strtobool)
|
148 |
+
group.add_argument('--transformer-concat-after', default=False, type=strtobool)
|
149 |
+
group.add_argument('--transformer-positionwise-layer-type', default='linear', type=str)
|
150 |
+
group.add_argument('--transformer-positionwise-conv-kernel_size', default=1, type=int)
|
151 |
+
group.add_argument('--transformer-chunk_size', default=-1, type=int)
|
152 |
+
group.add_argument('--transformer-left_chunks', default=-1, type=int)
|
153 |
+
group.add_argument('--transformer-dynamic-chunks', default=True, type=strtobool)
|
154 |
+
return group
|
155 |
+
|
156 |
+
def __init__(self, args):
|
157 |
+
"""Construct an Encoder object."""
|
158 |
+
super(Transformer, self).__init__()
|
159 |
+
|
160 |
+
self.input_dim = args.transformer_input_dim
|
161 |
+
self.output_dim = args.transformer_output_dim
|
162 |
+
self.attention_dim = args.transformer_attention_dim
|
163 |
+
self.attention_heads = args.transformer_attention_heads
|
164 |
+
self.linear_units = args.transformer_linear_units
|
165 |
+
self.num_blocks = args.transformer_num_blocks
|
166 |
+
self.dropout_rate = args.transformer_dropout_rate
|
167 |
+
self.positional_dropout_rate = args.transformer_positional_dropout_rate
|
168 |
+
self.attention_dropout_rate = args.transformer_attention_dropout_rate
|
169 |
+
self.input_layer = args.transformer_input_layer
|
170 |
+
self.pos_enc_class = args.transformer_pos_enc_class
|
171 |
+
self.normalize_before = args.transformer_normalize_before
|
172 |
+
self.concat_after = args.transformer_concat_after
|
173 |
+
self.positionwise_layer_type = args.transformer_positionwise_layer_type
|
174 |
+
self.positionwise_conv_kernel_size = args.transformer_positionwise_conv_kernel_size
|
175 |
+
self.chunk_size = args.transformer_chunk_size
|
176 |
+
self.left_chunks = args.transformer_left_chunks
|
177 |
+
self.transformer_dynamic_chunks = args.transformer_dynamic_chunks
|
178 |
+
|
179 |
+
if self.pos_enc_class == "abs-enc":
|
180 |
+
pos_enc_args = (self.attention_dim, self.positional_dropout_rate)
|
181 |
+
pos_enc_class = PositionalEncoding
|
182 |
+
elif self.pos_enc_class == "rel-enc":
|
183 |
+
pos_enc_args = (self.attention_dim, self.positional_dropout_rate, self.chunk_size, self.left_chunks)
|
184 |
+
pos_enc_class = RelPositionalEncoding
|
185 |
+
|
186 |
+
if self.input_layer == "linear":
|
187 |
+
self.embed = torch.nn.Sequential(
|
188 |
+
torch.nn.Linear(self.input_dim, self.attention_dim),
|
189 |
+
torch.nn.LayerNorm(self.attention_dim),
|
190 |
+
torch.nn.Dropout(self.dropout_rate),
|
191 |
+
torch.nn.ReLU()
|
192 |
+
)
|
193 |
+
elif self.input_layer == "embed":
|
194 |
+
self.embed = torch.nn.Sequential(
|
195 |
+
torch.nn.Embedding(self.input_dim, self.attention_dim, padding_idx=IGNORE_ID)
|
196 |
+
)
|
197 |
+
elif self.input_layer == "none":
|
198 |
+
self.embed = torch.nn.Sequential(
|
199 |
+
torch.nn.Identity()
|
200 |
+
)
|
201 |
+
else:
|
202 |
+
raise ValueError("unknown input_layer: " + self.input_layer)
|
203 |
+
self.pe = pos_enc_class(*pos_enc_args)
|
204 |
+
self.embed_layer_num = len(self.embed)
|
205 |
+
|
206 |
+
if self.positionwise_layer_type == "linear":
|
207 |
+
positionwise_layer = PositionwiseFeedForward
|
208 |
+
positionwise_layer_args = (self.attention_dim, self.linear_units, self.dropout_rate)
|
209 |
+
elif self.positionwise_layer_type == "conv1d":
|
210 |
+
positionwise_layer = MultiLayeredConv1d
|
211 |
+
positionwise_layer_args = (self.attention_dim, self.linear_units,
|
212 |
+
self.positionwise_conv_kernel_size, self.dropout_rate)
|
213 |
+
elif self.positionwise_layer_type == "conv1d-linear":
|
214 |
+
positionwise_layer = Conv1dLinear
|
215 |
+
positionwise_layer_args = (self.attention_dim, self.linear_units,
|
216 |
+
self.positionwise_conv_kernel_size, self.dropout_rate)
|
217 |
+
else:
|
218 |
+
raise NotImplementedError("Support only linear or conv1d.")
|
219 |
+
|
220 |
+
self.encoders = repeat(
|
221 |
+
self.num_blocks,
|
222 |
+
lambda lnum: TransformerLayer(
|
223 |
+
self.attention_dim,
|
224 |
+
MultiHeadedAttention(self.attention_heads, self.attention_dim,
|
225 |
+
self.attention_dropout_rate, self.chunk_size,
|
226 |
+
self.left_chunks, self.pos_enc_class),
|
227 |
+
positionwise_layer(*positionwise_layer_args),
|
228 |
+
self.dropout_rate,
|
229 |
+
self.normalize_before,
|
230 |
+
self.concat_after
|
231 |
+
)
|
232 |
+
)
|
233 |
+
if self.normalize_before:
|
234 |
+
self.after_norm = torch.nn.LayerNorm(self.attention_dim)
|
235 |
+
|
236 |
+
@torch.jit.unused
|
237 |
+
def forward(self, xs, ilens=None, masks=None):
|
238 |
+
"""Embed positions in tensor.
|
239 |
+
|
240 |
+
:param torch.Tensor xs: input tensor
|
241 |
+
:param torch.Tensor masks: input mask
|
242 |
+
:return: position embedded tensor and mask
|
243 |
+
:rtype Tuple[torch.Tensor, torch.Tensor]:
|
244 |
+
"""
|
245 |
+
if self.transformer_dynamic_chunks == True: # and self.training:
|
246 |
+
chunk_masks = add_optional_chunk_mask(xs, masks,
|
247 |
+
True,
|
248 |
+
True,
|
249 |
+
0,
|
250 |
+
0,
|
251 |
+
-1)
|
252 |
+
else:
|
253 |
+
chunk_masks = add_optional_chunk_mask(xs, masks,
|
254 |
+
False,
|
255 |
+
False,
|
256 |
+
self.chunk_size,
|
257 |
+
self.chunk_size,
|
258 |
+
self.left_chunks).to(xs.device)
|
259 |
+
xs = self.embed(xs)
|
260 |
+
xs, pos_emb = self.pe(xs)
|
261 |
+
xs, chunk_masks, pos_emb = self.encoders(xs, chunk_masks, pos_emb)
|
262 |
+
if self.normalize_before:
|
263 |
+
xs = self.after_norm(xs)
|
264 |
+
return xs, ilens, masks
|
265 |
+
|
266 |
+
@torch.jit.export
|
267 |
+
def infer(self, xs, buffer, buffer_index, buffer_out, pe_index):
|
268 |
+
xs = self.embed(xs)
|
269 |
+
|
270 |
+
# pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
|
271 |
+
# xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
|
272 |
+
# buffer_out.append(pe_index.reshape(-1).to(torch.float32))
|
273 |
+
# buffer_index = buffer_index + 1
|
274 |
+
if buffer[0] is None:
|
275 |
+
pe_length = xs.size(1)
|
276 |
+
else:
|
277 |
+
pe_length = buffer[0][0].size(2) + xs.size(1)
|
278 |
+
xs, pos_emb, pe_index = self.pe.infer(xs, pe_index, pe_length)
|
279 |
+
pos_emb = pos_emb.to('cuda')
|
280 |
+
xs, pos_emb, buffer, buffer_index, buffer_out = self.encoders.infer(xs, pos_emb,
|
281 |
+
buffer, buffer_index, buffer_out)
|
282 |
+
|
283 |
+
if self.normalize_before:
|
284 |
+
xs = self.after_norm(xs)
|
285 |
+
return xs, buffer, buffer_index, buffer_out, pe_index
|
vita/model/vita_tts/masks.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def casual_chunk_mask(ilens, chunk_size, left_chunks=1):
|
4 |
+
# type: (List[int], int, int) -> Tensor
|
5 |
+
# param ilens: list, (B, )
|
6 |
+
# param chunk_size: int
|
7 |
+
# return chunk_mask: torch.Tensor, (B, T, T)
|
8 |
+
B = len(ilens)
|
9 |
+
T = max(ilens)
|
10 |
+
chunk_mask = torch.zeros(B, T, T)
|
11 |
+
for b in range(0, B):
|
12 |
+
if chunk_size == -1 :
|
13 |
+
chunk_mask[b, 0:ilens[b], 0:ilens[b]] = 1
|
14 |
+
else:
|
15 |
+
for t in range(0, ilens[b], chunk_size):
|
16 |
+
ty_start = t
|
17 |
+
ty_end = min(t + chunk_size, ilens[b])
|
18 |
+
tx_start = max(t - chunk_size * left_chunks, 0)
|
19 |
+
tx_end = min(t + chunk_size, ilens[b])
|
20 |
+
chunk_mask[b, ty_start:ty_end, tx_start:tx_end] = 1
|
21 |
+
return chunk_mask
|
22 |
+
|
23 |
+
def subsequent_chunk_mask(
|
24 |
+
size: int,
|
25 |
+
chunk_size: int,
|
26 |
+
num_left_chunks: int = -1
|
27 |
+
) -> torch.Tensor:
|
28 |
+
"""Create mask for subsequent steps (size, size) with chunk size,
|
29 |
+
this is for streaming encoder
|
30 |
+
|
31 |
+
Args:
|
32 |
+
size (int): size of mask
|
33 |
+
chunk_size (int): size of chunk
|
34 |
+
num_left_chunks (int): number of left chunks
|
35 |
+
<0: use full chunk
|
36 |
+
>=0: use num_left_chunks
|
37 |
+
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: mask
|
41 |
+
|
42 |
+
Examples:
|
43 |
+
>>> subsequent_chunk_mask(4, 2)
|
44 |
+
[[1, 1, 0, 0],
|
45 |
+
[1, 1, 0, 0],
|
46 |
+
[1, 1, 1, 1],
|
47 |
+
[1, 1, 1, 1]]
|
48 |
+
"""
|
49 |
+
ret = torch.zeros(size, size, dtype=torch.bool)
|
50 |
+
for i in range(size):
|
51 |
+
if num_left_chunks < 0:
|
52 |
+
start = 0
|
53 |
+
else:
|
54 |
+
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
|
55 |
+
ending = min((i // chunk_size + 1) * chunk_size, size)
|
56 |
+
ret[i, start:ending] = torch.ones(ending-start, dtype=torch.bool)
|
57 |
+
return ret
|
58 |
+
|
59 |
+
def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
|
60 |
+
use_dynamic_chunk: bool,
|
61 |
+
use_dynamic_left_chunk: bool,
|
62 |
+
decoding_chunk_size: int, static_chunk_size: int,
|
63 |
+
num_decoding_left_chunks: int):
|
64 |
+
""" Apply optional mask for encoder.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
xs (torch.Tensor): padded input, (B, L, D), L for max length
|
68 |
+
mask (torch.Tensor): mask for xs, (B, 1, L)
|
69 |
+
use_dynamic_chunk (bool): whether to use dynamic chunk or not
|
70 |
+
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
|
71 |
+
training.
|
72 |
+
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
|
73 |
+
0: default for training, use random dynamic chunk.
|
74 |
+
<0: for decoding, use full chunk.
|
75 |
+
>0: for decoding, use fixed chunk size as set.
|
76 |
+
static_chunk_size (int): chunk size for static chunk training/decoding
|
77 |
+
if it's greater than 0, if use_dynamic_chunk is true,
|
78 |
+
this parameter will be ignored
|
79 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
80 |
+
the chunk size is decoding_chunk_size.
|
81 |
+
>=0: use num_decoding_left_chunks
|
82 |
+
<0: use all left chunks
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
torch.Tensor: chunk mask of the input xs.
|
86 |
+
"""
|
87 |
+
# Whether to use chunk mask or not
|
88 |
+
if use_dynamic_chunk:
|
89 |
+
max_len = xs.size(1)
|
90 |
+
if decoding_chunk_size < 0:
|
91 |
+
chunk_size = max_len
|
92 |
+
num_left_chunks = -1
|
93 |
+
elif decoding_chunk_size > 0:
|
94 |
+
chunk_size = decoding_chunk_size
|
95 |
+
num_left_chunks = num_decoding_left_chunks
|
96 |
+
else:
|
97 |
+
# chunk size is either [1, 25] or full context(max_len).
|
98 |
+
# Since we use 4 times subsampling and allow up to 1s(100 frames)
|
99 |
+
# delay, the maximum frame is 100 / 4 = 25.
|
100 |
+
chunk_size = torch.randint(1, max_len, (1, )).item()
|
101 |
+
num_left_chunks = -1
|
102 |
+
if chunk_size > max_len // 2:
|
103 |
+
chunk_size = max_len
|
104 |
+
else:
|
105 |
+
chunk_size = chunk_size % 25 + 1
|
106 |
+
if use_dynamic_left_chunk:
|
107 |
+
max_left_chunks = (max_len - 1) // chunk_size
|
108 |
+
num_left_chunks = torch.randint(0, max_left_chunks,
|
109 |
+
(1, )).item()
|
110 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
|
111 |
+
num_left_chunks) # (L, L)
|
112 |
+
chunk_masks = chunk_masks.to(xs.device)
|
113 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
114 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
115 |
+
elif static_chunk_size > 0:
|
116 |
+
num_left_chunks = num_decoding_left_chunks
|
117 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
|
118 |
+
num_left_chunks) # (L, L)
|
119 |
+
chunk_masks = chunk_masks.unsqueeze(0).to(masks.device) # (1, L, L)
|
120 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
121 |
+
else:
|
122 |
+
chunk_masks = masks
|
123 |
+
return chunk_masks
|
124 |
+
|
125 |
+
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
|
126 |
+
"""Make mask tensor containing indices of padded part.
|
127 |
+
|
128 |
+
See description of make_non_pad_mask.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
lengths (torch.Tensor): Batch of lengths (B,).
|
132 |
+
Returns:
|
133 |
+
torch.Tensor: Mask tensor containing indices of padded part.
|
134 |
+
|
135 |
+
Examples:
|
136 |
+
>>> lengths = [5, 3, 2]
|
137 |
+
>>> make_pad_mask(lengths)
|
138 |
+
masks = [[0, 0, 0, 0 ,0],
|
139 |
+
[0, 0, 0, 1, 1],
|
140 |
+
[0, 0, 1, 1, 1]]
|
141 |
+
"""
|
142 |
+
batch_size = int(lengths.size(0))
|
143 |
+
max_len = int(lengths.max().item())
|
144 |
+
seq_range = torch.arange(0,
|
145 |
+
max_len,
|
146 |
+
dtype=torch.int64,
|
147 |
+
device=lengths.device)
|
148 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
149 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
150 |
+
mask = seq_range_expand >= seq_length_expand
|
151 |
+
return mask
|
152 |
+
|
153 |
+
def subsequent_mask(
|
154 |
+
size: int,
|
155 |
+
device: torch.device = torch.device("cpu"),
|
156 |
+
) -> torch.Tensor:
|
157 |
+
"""Create mask for subsequent steps (size, size).
|
158 |
+
|
159 |
+
This mask is used only in decoder which works in an auto-regressive mode.
|
160 |
+
This means the current step could only do attention with its left steps.
|
161 |
+
|
162 |
+
In encoder, fully attention is used when streaming is not necessary and
|
163 |
+
the sequence is not long. In this case, no attention mask is needed.
|
164 |
+
|
165 |
+
When streaming is need, chunk-based attention is used in encoder. See
|
166 |
+
subsequent_chunk_mask for the chunk-based attention mask.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
size (int): size of mask
|
170 |
+
str device (str): "cpu" or "cuda" or torch.Tensor.device
|
171 |
+
dtype (torch.device): result dtype
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
torch.Tensor: mask
|
175 |
+
|
176 |
+
Examples:
|
177 |
+
>>> subsequent_mask(3)
|
178 |
+
[[1, 0, 0],
|
179 |
+
[1, 1, 0],
|
180 |
+
[1, 1, 1]]
|
181 |
+
"""
|
182 |
+
ret = torch.ones(size, size, device=device, dtype=torch.bool)
|
183 |
+
return torch.tril(ret, out=ret)
|
184 |
+
|
185 |
+
def target_mask(ys_in_pad, ignore_id):
|
186 |
+
# type: (Tensor, int) -> Tensor
|
187 |
+
# Create mask for decoder self-attention.
|
188 |
+
# :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
|
189 |
+
# :param int ignore_id: index of padding
|
190 |
+
# :param torch.dtype dtype: result dtype
|
191 |
+
# :rtype: torch.Tensor
|
192 |
+
|
193 |
+
ys_mask = ys_in_pad != ignore_id
|
194 |
+
m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
|
195 |
+
return ys_mask.unsqueeze(-2) & m
|
vita/model/vita_tts/pipeline.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import yaml
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
|
6 |
+
from vita.model.vita_tts.utils import init_encoder_llm, load_checkpoint
|
7 |
+
|
8 |
+
class inferencePipeline():
|
9 |
+
def __init__(self, args):
|
10 |
+
self.args = args
|
11 |
+
|
12 |
+
with open(self.args.model_path + "/audiollm/train.yaml", 'r') as fin:
|
13 |
+
configs = yaml.safe_load(fin)
|
14 |
+
configs['cmvn_file'] = self.args.model_path + "/audiollm/global_cmvn"
|
15 |
+
configs['model_conf']['llm_path'] = self.args.llm_path
|
16 |
+
|
17 |
+
# Init asr model from configs
|
18 |
+
self.model = init_encoder_llm(configs)
|
19 |
+
|
20 |
+
load_checkpoint(self.model, self.args.model_path + "/audiollm/final.pt")
|
21 |
+
device = torch.device('cuda')
|
22 |
+
self.model = self.model.to(device)
|
23 |
+
self.model.eval()
|
24 |
+
|
25 |
+
def speech_dialogue(self,
|
26 |
+
audio: tuple,
|
27 |
+
role: str=None,
|
28 |
+
stat: str='sl',
|
29 |
+
past_key_values=None,
|
30 |
+
last_id=None,
|
31 |
+
past_tokens=None,
|
32 |
+
adapter_cache=None,
|
33 |
+
encoder_cache=None,
|
34 |
+
pe_index=0):
|
35 |
+
with torch.no_grad():
|
36 |
+
## input fbank
|
37 |
+
feats = audio
|
38 |
+
if feats is not None:
|
39 |
+
feats = feats.to('cuda')
|
40 |
+
feats_lengths = torch.tensor([feats.size(1)]).to('cuda')
|
41 |
+
else:
|
42 |
+
feats_lengths = None
|
43 |
+
|
44 |
+
extra_inputs = {}
|
45 |
+
extra_inputs['top_p'] = self.args.top_p
|
46 |
+
extra_inputs['top_k'] = self.args.top_k
|
47 |
+
extra_inputs['temperature'] = self.args.temperature
|
48 |
+
extra_inputs['past_key_values'] = past_key_values
|
49 |
+
extra_inputs['stat'] = stat
|
50 |
+
extra_inputs['last_id'] = last_id
|
51 |
+
extra_inputs['adapter_cache'] = adapter_cache
|
52 |
+
extra_inputs['encoder_cache'] = encoder_cache
|
53 |
+
extra_inputs['pe_index'] = pe_index
|
54 |
+
if role is not None and past_key_values is None:
|
55 |
+
# add <|im_end|> in chat_prefix
|
56 |
+
extra_inputs['role'] = '<|im_start|>system\n' + role # + '<|im_end|>'
|
57 |
+
|
58 |
+
with torch.autocast(device_type="cuda",
|
59 |
+
dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32):
|
60 |
+
# preprocess system role first
|
61 |
+
if stat == 'pre':
|
62 |
+
past_key_values = self.model.set_system_role(extra_inputs)
|
63 |
+
stat = 'sl'
|
64 |
+
else:
|
65 |
+
(last_id, stat, past_key_values, adapter_cache,
|
66 |
+
encoder_cache, pe_index, hidden_state) = self.model.recognize(
|
67 |
+
feats,
|
68 |
+
feats_lengths,
|
69 |
+
extra_inputs=extra_inputs)
|
70 |
+
|
71 |
+
outputs = dict(
|
72 |
+
past_key_values=past_key_values,
|
73 |
+
stat=stat,
|
74 |
+
last_id=last_id,
|
75 |
+
adapter_cache=adapter_cache,
|
76 |
+
encoder_cache=encoder_cache,
|
77 |
+
pe_index=pe_index,
|
78 |
+
)
|
79 |
+
|
80 |
+
if stat == 'cs':
|
81 |
+
if past_tokens is None:
|
82 |
+
past_tokens = []
|
83 |
+
past_tokens.append(last_id[0][0])
|
84 |
+
text = self.model.tokenizer.decode(past_tokens, skip_special_tokens=True)
|
85 |
+
outputs['hidden_state'] = hidden_state
|
86 |
+
outputs['text'] = text
|
87 |
+
outputs['past_tokens'] = past_tokens
|
88 |
+
|
89 |
+
return outputs
|
90 |
+
|
91 |
+
def post_process(self, text):
|
92 |
+
"""
|
93 |
+
Post-processes the input text to standardize various characters and formatting.
|
94 |
+
|
95 |
+
Parameters:
|
96 |
+
- text (str): The input text string to be post-processed.
|
97 |
+
|
98 |
+
Actions:
|
99 |
+
1. Replaces various Chinese and English punctuation marks with standardized ones.
|
100 |
+
2. Removes newline, tab, and other unwanted whitespace characters.
|
101 |
+
3. Removes special characters like asterisks, underscores, backticks, and tildes.
|
102 |
+
4. Condenses whitespace following periods and colons.
|
103 |
+
5. Adjusts the format of numbered lists to use appropriate separators
|
104 |
+
6. Ensures the text ends with an appropriate punctuation mark
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
- str: The post-processed text string.
|
108 |
+
"""
|
109 |
+
text = text.replace('、', ',')
|
110 |
+
text = text.replace('(', ',')
|
111 |
+
text = text.replace(')', ',')
|
112 |
+
text = text.replace('(', ',')
|
113 |
+
text = text.replace(')', ',')
|
114 |
+
|
115 |
+
text = re.sub(r'[\n\r\t]', '', text)
|
116 |
+
text = re.sub(r'[*_`~]', '', text)
|
117 |
+
|
118 |
+
text = re.sub(r'(\.|\:)\s+', r'\1', text)
|
119 |
+
|
120 |
+
if re.search(r'[\u4e00-\u9fa5]', text):
|
121 |
+
text = re.sub(r'(\d+)\.\s*([\u4e00-\u9fa5A-Za-z])', r'\1:\2', text)
|
122 |
+
else:
|
123 |
+
text = re.sub(r'(\d+)\.\s*([\w])', r'\1:\2', text)
|
124 |
+
|
125 |
+
if text and text[-1] not in ["。", "?", "!", ".", "?", "!"]:
|
126 |
+
if text[-1] in [",", ",", ";", ";", ":", ":", "、"]:
|
127 |
+
text = text[:-1] + "。"
|
128 |
+
else:
|
129 |
+
text += "。"
|
130 |
+
|
131 |
+
return text
|
vita/model/vita_tts/utils.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import re
|
3 |
+
import os
|
4 |
+
|
5 |
+
from vita.model.vita_tts.audioLLM import AudioLLM
|
6 |
+
|
7 |
+
from vita.model.vita_tts.encoder.cmvn import GlobalCMVN, load_cmvn
|
8 |
+
from vita.model.vita_tts.encoder.encoder import speechEncoder
|
9 |
+
|
10 |
+
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
|
11 |
+
if torch.cuda.is_available():
|
12 |
+
print('Checkpoint: loading from checkpoint %s for GPU' % path)
|
13 |
+
checkpoint = torch.load(path)
|
14 |
+
else:
|
15 |
+
print('Checkpoint: loading from checkpoint %s for CPU' % path)
|
16 |
+
checkpoint = torch.load(path, map_location='cpu')
|
17 |
+
|
18 |
+
# load parm from checkpoint
|
19 |
+
model.load_state_dict(checkpoint, strict=False)
|
20 |
+
|
21 |
+
info_path = re.sub('.pt$', '.yaml', path)
|
22 |
+
configs = {}
|
23 |
+
# get configs
|
24 |
+
if os.path.exists(info_path):
|
25 |
+
with open(info_path, 'r') as fin:
|
26 |
+
configs = yaml.safe_load(fin)
|
27 |
+
return configs
|
28 |
+
|
29 |
+
def init_encoder_llm(configs):
|
30 |
+
if configs['cmvn_file'] is not None:
|
31 |
+
# read cmvn
|
32 |
+
mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn'])
|
33 |
+
# init cmvn layer
|
34 |
+
global_cmvn = GlobalCMVN(
|
35 |
+
torch.from_numpy(mean).float(),
|
36 |
+
torch.from_numpy(istd).float())
|
37 |
+
else:
|
38 |
+
global_cmvn = None
|
39 |
+
|
40 |
+
input_dim = configs['input_dim']
|
41 |
+
vocab_size = configs['output_dim']
|
42 |
+
|
43 |
+
# init speech encoder
|
44 |
+
encoder = speechEncoder(input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
|
45 |
+
# init audioLLM
|
46 |
+
model = AudioLLM(encoder=encoder, **configs['model_conf'])
|
47 |
+
|
48 |
+
return model
|