dongyh20 committed
Commit 1938217 · 1 Parent(s): a4679af

update space

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +382 -0
  2. ola/CosyVoice +1 -0
  3. ola/__pycache__/arguments.cpython-310.pyc +0 -0
  4. ola/__pycache__/arguments.cpython-38.pyc +0 -0
  5. ola/__pycache__/constants.cpython-310.pyc +0 -0
  6. ola/__pycache__/constants.cpython-38.pyc +0 -0
  7. ola/__pycache__/conversation.cpython-310.pyc +0 -0
  8. ola/__pycache__/conversation.cpython-38.pyc +0 -0
  9. ola/__pycache__/mm_utils.cpython-310.pyc +0 -0
  10. ola/__pycache__/mm_utils.cpython-38.pyc +0 -0
  11. ola/__pycache__/utils.cpython-310.pyc +0 -0
  12. ola/__pycache__/utils.cpython-38.pyc +0 -0
  13. ola/arguments.py +65 -0
  14. ola/constants.py +14 -0
  15. ola/conversation.py +254 -0
  16. ola/datasets/__init__.py +0 -0
  17. ola/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  18. ola/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  19. ola/datasets/__pycache__/preprocess.cpython-310.pyc +0 -0
  20. ola/datasets/__pycache__/preprocess.cpython-38.pyc +0 -0
  21. ola/datasets/preprocess.py +413 -0
  22. ola/mm_utils.py +272 -0
  23. ola/model/__init__.py +1 -0
  24. ola/model/__pycache__/__init__.cpython-310.pyc +0 -0
  25. ola/model/__pycache__/__init__.cpython-38.pyc +0 -0
  26. ola/model/__pycache__/builder.cpython-310.pyc +0 -0
  27. ola/model/__pycache__/builder.cpython-38.pyc +0 -0
  28. ola/model/__pycache__/ola_arch.cpython-310.pyc +0 -0
  29. ola/model/__pycache__/ola_arch.cpython-38.pyc +0 -0
  30. ola/model/builder.py +91 -0
  31. ola/model/language_model/__pycache__/ola_qwen.cpython-310.pyc +0 -0
  32. ola/model/language_model/__pycache__/ola_qwen.cpython-38.pyc +0 -0
  33. ola/model/language_model/ola_qwen.py +237 -0
  34. ola/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  35. ola/model/multimodal_encoder/__pycache__/builder.cpython-38.pyc +0 -0
  36. ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-310.pyc +0 -0
  37. ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-38.pyc +0 -0
  38. ola/model/multimodal_encoder/builder.py +9 -0
  39. ola/model/multimodal_encoder/oryx_vit.py +1126 -0
  40. ola/model/multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
  41. ola/model/multimodal_projector/__pycache__/builder.cpython-38.pyc +0 -0
  42. ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc +0 -0
  43. ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-38.pyc +0 -0
  44. ola/model/multimodal_projector/builder.py +179 -0
  45. ola/model/multimodal_projector/pooler_projector.py +74 -0
  46. ola/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc +0 -0
  47. ola/model/multimodal_resampler/__pycache__/builder.cpython-38.pyc +0 -0
  48. ola/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc +0 -0
  49. ola/model/multimodal_resampler/__pycache__/perceiver.cpython-38.pyc +0 -0
  50. ola/model/multimodal_resampler/builder.py +24 -0
app.py ADDED
@@ -0,0 +1,382 @@
import os

# These flags are read at import time by ola.mm_utils and the related model code,
# so they must be set before the ola imports below.
os.environ['LOWRES_RESIZE'] = '384x32'
os.environ['HIGHRES_BASE'] = '0x32'
os.environ['VIDEO_RESIZE'] = "0x64"
os.environ['VIDEO_MAXRES'] = "480"
os.environ['VIDEO_MINRES'] = "288"
os.environ['MAXRES'] = '1536'
os.environ['MINRES'] = '0'
os.environ['REGIONAL_POOL'] = '2x'
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
os.environ['LOAD_VISION_EARLY'] = '1'
os.environ['SKIP_LOAD_VIT'] = '1'


import gradio as gr
import torch
import re
from decord import VideoReader, cpu
from PIL import Image
import numpy as np
import transformers
import moviepy.editor as mp
from typing import Dict, Optional, Sequence, List
import librosa
import whisper

# import subprocess
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

import sys
sys.path.append('./ola/CosyVoice/')
from ola.conversation import conv_templates, SeparatorStyle
from ola.model.builder import load_pretrained_model
from ola.utils import disable_torch_init
from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token
from ola.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image_genli
from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
# from ola.CosyVoice.cosyvoice.cli.cosyvoice import CosyVoice

model_path = "/mnt/lzy/ola-model/Ola-7b"
tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None)
model = model.to('cuda').eval()
model = model.bfloat16()

# tts_model = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
# OUTPUT_SPEECH = False

USE_SPEECH = False

title_markdown = """
<div style="display: flex; justify-content: left; align-items: center; text-align: left; background: linear-gradient(45deg, rgba(204,255,231, 0.8), rgba(204,255,231, 0.3)); border-radius: 10px; box-shadow: 0 8px 16px 0 rgba(0,0,0,0.1);"> <a href="https://llava-vl.github.io/blog/2024-04-30-llava-next-video/" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
  <img src="https://ola-omni.github.io/static/images/icon.png" alt="Oryx" style="max-width: 80px; height: auto; border-radius: 10px;">
</a>
<div>
<h2><a href="https://github.com/Ola-Omni/Ola">Ola: Pushing the Frontiers of Omni-Modal Language Model with Progressive Modality Alignment</a></h2>
<h5 style="margin: 0;"><a href="https://ola-omni.github.io/">Project Page</a> | <a href="https://github.com/Ola-Omni/Ola">Github</a> | <a href="https://huggingface.co/THUdyh/Ola-7b">Huggingface</a> | <a href="https://arxiv.org/abs/2502.04328">Paper</a></h5>
</div>
</div>
"""

bibtext = """
### Citation
```
@article{liu2025ola,
  title={Ola: Pushing the Frontiers of Omni-Modal Language Model with Progressive Modality Alignment},
  author={Liu, Zuyan and Dong, Yuhao and Wang, Jiahui and Liu, Ziwei and Hu, Winston and Lu, Jiwen and Rao, Yongming},
  journal={arXiv preprint arXiv:2502.04328},
  year={2025}
}
```
"""
cur_dir = os.path.dirname(os.path.abspath(__file__))


def load_audio(audio_file_name):
    # Load a 16 kHz waveform, split it into 30 s (480000-sample) chunks,
    # and return Whisper log-mel features plus the raw chunks.
    speech_wav, samplerate = librosa.load(audio_file_name, sr=16000)
    if len(speech_wav.shape) > 1:
        speech_wav = speech_wav[:, 0]
    speech_wav = speech_wav.astype(np.float32)
    CHUNK_LIM = 480000
    SAMPLE_RATE = 16000
    speechs = []
    speech_wavs = []

    if len(speech_wav) <= CHUNK_LIM:
        speech = whisper.pad_or_trim(speech_wav)
        speech_wav = whisper.pad_or_trim(speech_wav)
        speechs.append(speech)
        speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0))
    else:
        for i in range(0, len(speech_wav), CHUNK_LIM):
            chunk = speech_wav[i : i + CHUNK_LIM]
            if len(chunk) < CHUNK_LIM:
                chunk = whisper.pad_or_trim(chunk)
            speechs.append(chunk)
            speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0))
    mels = []
    for chunk in speechs:
        chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0)
        mels.append(chunk)

    mels = torch.cat(mels, dim=0)
    speech_wavs = torch.cat(speech_wavs, dim=0)
    if mels.shape[0] > 25:
        mels = mels[:25]
        speech_wavs = speech_wavs[:25]

    speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0])
    speech_chunks = torch.LongTensor([mels.shape[0]])
    return mels, speech_length, speech_chunks, speech_wavs

def extract_audio(videos_file_path):
    my_clip = mp.VideoFileClip(videos_file_path)
    return my_clip.audio

def ola_inference(multimodal, audio_path):
    visual, text = multimodal["files"][0], multimodal["text"]
    if visual.endswith("image2.png"):
        modality = "video"
        visual = f"{cur_dir}/case/case1.mp4"
    if visual.endswith(".mp4"):
        modality = "video"
    else:
        modality = "image"

    # If a standalone audio clip is provided, use it as the speech input;
    # otherwise, for video inputs, parse the audio track of the video.
    if audio_path:
        USE_SPEECH = True
    elif modality == "video":
        USE_SPEECH = True
    else:
        USE_SPEECH = False

    speechs = []
    speech_lengths = []
    speech_wavs = []
    speech_chunks = []
    if modality == "video":
        vr = VideoReader(visual, ctx=cpu(0))
        total_frame_num = len(vr)
        fps = round(vr.get_avg_fps())
        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, 64, dtype=int)
        frame_idx = uniform_sampled_frames.tolist()
        spare_frames = vr.get_batch(frame_idx).asnumpy()
        video = [Image.fromarray(frame) for frame in spare_frames]
    else:
        image = [Image.open(visual)]
        image_sizes = [image[0].size]

    if USE_SPEECH and audio_path:
        speech, speech_length, speech_chunk, speech_wav = load_audio(audio_path)
        speechs.append(speech.bfloat16().to('cuda'))
        speech_lengths.append(speech_length.to('cuda'))
        speech_chunks.append(speech_chunk.to('cuda'))
        speech_wavs.append(speech_wav.to('cuda'))
        print('load audio')
    elif USE_SPEECH and not audio_path:
        # parse the audio track of the video
        audio = extract_audio(visual)
        audio.write_audiofile("./video_audio.wav")
        video_audio_path = './video_audio.wav'
        speech, speech_length, speech_chunk, speech_wav = load_audio(video_audio_path)
        speechs.append(speech.bfloat16().to('cuda'))
        speech_lengths.append(speech_length.to('cuda'))
        speech_chunks.append(speech_chunk.to('cuda'))
        speech_wavs.append(speech_wav.to('cuda'))
    else:
        # no speech input: feed zero-filled placeholders
        speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')]
        speech_lengths = [torch.LongTensor([3000]).to('cuda')]
        speech_wavs = [torch.zeros([1, 480000]).to('cuda')]
        speech_chunks = [torch.LongTensor([1]).to('cuda')]

    conv_mode = "qwen_1_5"
    if text:
        qs = text
    else:
        qs = ''
    if USE_SPEECH and audio_path:
        qs = DEFAULT_IMAGE_TOKEN + "\n" + "User's question in speech: " + DEFAULT_SPEECH_TOKEN + '\n'
    elif USE_SPEECH:
        qs = DEFAULT_SPEECH_TOKEN + DEFAULT_IMAGE_TOKEN + "\n" + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    if USE_SPEECH and audio_path:
        input_ids = tokenizer_speech_question_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
    elif USE_SPEECH:
        input_ids = tokenizer_speech_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
    else:
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')

    if modality == "video":
        video_processed = []
        for idx, frame in enumerate(video):
            image_processor.do_resize = False
            image_processor.do_center_crop = False
            frame = process_anyres_video(frame, image_processor)

            if frame_idx is not None and idx in frame_idx:
                video_processed.append(frame.unsqueeze(0))
            elif frame_idx is None:
                video_processed.append(frame.unsqueeze(0))

        if frame_idx is None:
            frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()

        video_processed = torch.cat(video_processed, dim=0).bfloat16().to("cuda")
        video_processed = (video_processed, video_processed)

        video_data = (video_processed, (384, 384), "video")
    else:
        image_processor.do_resize = False
        image_processor.do_center_crop = False
        image_tensor, image_highres_tensor = [], []
        for visual in image:
            image_tensor_, image_highres_tensor_ = process_anyres_highres_image_genli(visual, image_processor)
            image_tensor.append(image_tensor_)
            image_highres_tensor.append(image_highres_tensor_)
        if all(x.shape == image_tensor[0].shape for x in image_tensor):
            image_tensor = torch.stack(image_tensor, dim=0)
        if all(x.shape == image_highres_tensor[0].shape for x in image_highres_tensor):
            image_highres_tensor = torch.stack(image_highres_tensor, dim=0)
        if type(image_tensor) is list:
            image_tensor = [_image.bfloat16().to("cuda") for _image in image_tensor]
        else:
            image_tensor = image_tensor.bfloat16().to("cuda")
        if type(image_highres_tensor) is list:
            image_highres_tensor = [_image.bfloat16().to("cuda") for _image in image_highres_tensor]
        else:
            image_highres_tensor = image_highres_tensor.bfloat16().to("cuda")

    pad_token_ids = 151643  # Qwen2 "<|endoftext|>" token, used as the pad id

    attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    gen_kwargs = {}

    if "max_new_tokens" not in gen_kwargs:
        gen_kwargs["max_new_tokens"] = 1024
    if "temperature" not in gen_kwargs:
        gen_kwargs["temperature"] = 0.2
    if "top_p" not in gen_kwargs:
        gen_kwargs["top_p"] = None
    if "num_beams" not in gen_kwargs:
        gen_kwargs["num_beams"] = 1

    with torch.inference_mode():
        if modality == "video":
            output_ids = model.generate(
                inputs=input_ids,
                images=video_data[0][0],
                images_highres=video_data[0][1],
                modalities=video_data[2],
                speech=speechs,
                speech_lengths=speech_lengths,
                speech_chunks=speech_chunks,
                speech_wav=speech_wavs,
                attention_mask=attention_masks,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
                do_sample=True if gen_kwargs["temperature"] > 0 else False,
                temperature=gen_kwargs["temperature"],
                top_p=gen_kwargs["top_p"],
                num_beams=gen_kwargs["num_beams"],
                max_new_tokens=gen_kwargs["max_new_tokens"],
            )
        else:
            output_ids = model.generate(
                inputs=input_ids,
                images=image_tensor,
                images_highres=image_highres_tensor,
                image_sizes=image_sizes,
                modalities=['image'],
                speech=speechs,
                speech_lengths=speech_lengths,
                speech_chunks=speech_chunks,
                speech_wav=speech_wavs,
                attention_mask=attention_masks,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
                do_sample=True if gen_kwargs["temperature"] > 0 else False,
                temperature=gen_kwargs["temperature"],
                top_p=gen_kwargs["top_p"],
                num_beams=gen_kwargs["num_beams"],
                max_new_tokens=gen_kwargs["max_new_tokens"],
            )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    outputs = outputs.strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()

    # if OUTPUT_SPEECH:
    #     voice_all = []
    #     for i, j in enumerate(cosyvoice.inference_sft('Visual data comes in various forms, ranging from small icons of just a few pixels to long videos spanning hours. Existing multi-modal LLMs usually standardize these diverse visual inputs to a fixed resolution for visual encoders and yield similar numbers of tokens for LLMs. This approach is non-optimal for multimodal understanding and inefficient for processing inputs with long and short visual contents. To solve the problem, we propose Oryx, a unified multimodal architecture for the spatial-temporal understanding of images, videos, and multi-view 3D scenes. Oryx offers an on-demand solution to seamlessly and efficiently process visual inputs with arbitrary spatial sizes and temporal lengths through two core innovations: 1) a pre-trained OryxViT model that can encode images at any resolution into LLM-friendly visual representations; 2) a dynamic compressor module that supports 1x to 16x compression on visual tokens by request. These design features enable Oryx to accommodate extremely long visual contexts, such as videos, with lower resolution and high compression while maintaining high recognition precision for tasks like document understanding with native resolution and no compression. Beyond the architectural improvements, enhanced data curation and specialized training on long-context retrieval and spatial-aware data help Oryx achieve strong capabilities in image, video, and 3D multimodal understanding simultaneously. ', '英文女', stream=False)):
    #         voice_all.append(j['tts_speech'])
    #     voice_all = torch.cat(voice_all, dim=1)
    #     torchaudio.save('sft.wav', voice_all, 22050)
    #     return outputs, "sft.wav"
    # else:
    return outputs, None

# Define input and output for the Gradio interface
demo = gr.Interface(
    fn=ola_inference,
    inputs=[gr.MultimodalTextbox(file_types=[".mp4", "image"], placeholder="Enter message or upload file..."), gr.Audio(type="filepath")],
    outputs=["text", "audio"],
    # examples=[
    #     {
    #         "files": [f"{cur_dir}/case/image2.png"],
    #         "text": "Describe what is happening in this video in detail.",
    #     },
    #     {
    #         "files": [f"{cur_dir}/case/image.png"],
    #         "text": "Describe this icon.",
    #     },
    # ],
    title="Ola Demo",
    description=title_markdown,
    article=bibtext,
)

# textbox = gr.Textbox(
#     show_label=False, placeholder="Enter text and press ENTER", container=False, max_lines=100
# )
# with gr.Blocks(
#     title="Oryx-7B",
#     theme="finlaymacklon/smooth_slate",
#     css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 50px}",
#     fill_height=True
# ) as demo:
#     html_header = "https://oryx-mllm.github.io/"
#     gr.HTML(html_header)

#     with gr.Row(equal_height=True):
#         with gr.Column(scale=3):
#             with gr.Row():
#                 video = gr.Video(label="Input Video", height=400)
#             cur_dir = os.path.dirname(os.path.abspath(__file__))
#             with gr.Row():
#                 gr.Examples(
#                     examples=[
#                         [
#                             f"{cur_dir}/case/case1.mp4",
#                             "Describe what is happening in this video in detail.",
#                         ],
#                     ],
#                     inputs=[video, textbox],
#                 )

#         with gr.Column(scale=7):
#             chatbot = gr.Chatbot(label="Oryx", bubble_full_width=False, height=660)
#             with gr.Row():
#                 with gr.Column(scale=8):
#                     textbox.render()
#                 with gr.Column(scale=1, min_width=50):
#                     submit_btn = gr.Button(
#                         value="Send", variant="primary", interactive=True
#                     )
#     # with gr.Row(elem_id="buttons") as button_row:
#     #     upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
#     #     downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
#     #     flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
#     #     clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

#     submit_btn.click(
#         oryx_inference,
#         [video, textbox],
#         [chatbot, textbox, video],
#     )

# Launch the Gradio app
demo.launch(server_name="0.0.0.0", server_port=80)
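
For reference, a minimal sketch of how the `ola_inference` entry point above can be exercised directly, bypassing `gr.Interface`. It assumes the checkpoint configured in `model_path` is available locally; the `case/case1.mp4` path and the question text are illustrative placeholders, not part of the commit.

# Illustrative only: call ola_inference() the way the Gradio MultimodalTextbox would.
if __name__ == "__main__":
    request = {
        "files": ["case/case1.mp4"],  # a video (or image) input
        "text": "Describe what is happening in this video in detail.",
    }
    answer, speech_out = ola_inference(request, audio_path=None)
    print(answer)       # decoded text response
    print(speech_out)   # None unless the commented-out CosyVoice branch is re-enabled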
ola/CosyVoice ADDED
@@ -0,0 +1 @@
Subproject commit 027e1ccb82ce59bbc12f35a96e0f92625cf18369
ola/__pycache__/arguments.cpython-310.pyc ADDED
Binary file (2.65 kB).
ola/__pycache__/arguments.cpython-38.pyc ADDED
Binary file (2.64 kB).
ola/__pycache__/constants.cpython-310.pyc ADDED
Binary file (508 Bytes).
ola/__pycache__/constants.cpython-38.pyc ADDED
Binary file (506 Bytes).
ola/__pycache__/conversation.cpython-310.pyc ADDED
Binary file (6.21 kB).
ola/__pycache__/conversation.cpython-38.pyc ADDED
Binary file (6.28 kB).
ola/__pycache__/mm_utils.cpython-310.pyc ADDED
Binary file (6.44 kB).
ola/__pycache__/mm_utils.cpython-38.pyc ADDED
Binary file (6.41 kB).
ola/__pycache__/utils.cpython-310.pyc ADDED
Binary file (7.5 kB).
ola/__pycache__/utils.cpython-38.pyc ADDED
Binary file (7.53 kB).
ola/arguments.py ADDED
@@ -0,0 +1,65 @@
import transformers

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="v0")
    freeze_backbone: bool = field(default=False)
    tune_speech_projector: bool = field(default=False)
    tune_speech_encoder: bool = field(default=False)
    tune_speech_generator_only: bool = field(default=False)
    speech_encoder_type: Optional[str] = field(default=None)
    speech_encoder: Optional[str] = field(default=None)
    pretrain_speech_projector: Optional[str] = field(default=None)
    speech_projector_type: Optional[str] = field(default='linear')
    speech_encoder_ds_rate: int = 5
    speech_encoder_hidden_size: int = 1280


@dataclass
class DataArguments:
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    is_multimodal: bool = False
    input_type: str = field(default="mel")
    speech_normalize: bool = False
    mel_size: int = 128
    has_tgt_units: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    freeze_speech_projector: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    speech_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)
ola/constants.py ADDED
@@ -0,0 +1,14 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
SPEECH_TOKEN_INDEX = -200
DEFAULT_SPEECH_TOKEN = "<speech>"
IMAGE_TOKEN_INDEX = -300
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
ola/conversation.py ADDED
@@ -0,0 +1,254 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Any, Union, Tuple
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+
8
+
9
+ class SeparatorStyle(Enum):
10
+ """Different separator style."""
11
+ TWO = auto()
12
+ PLAIN = auto()
13
+ CHATML = auto()
14
+ LLAMA_2 = auto()
15
+ LLAMA_3 = auto()
16
+ QWEN2 = auto()
17
+
18
+
19
+ @dataclasses.dataclass
20
+ class Conversation:
21
+ """A class that keeps all conversation history."""
22
+ system: str
23
+ roles: List[str]
24
+ messages: List[List[str]]
25
+ offset: int
26
+ sep_style: SeparatorStyle = SeparatorStyle.PLAIN
27
+ sep: str = "###"
28
+ sep2: str = None
29
+ version: str = "Unknown"
30
+
31
+ tokenizer_id: str = ""
32
+ tokenizer: Any = None
33
+ # Stop criteria (the default one is EOS token)
34
+ stop_str: Union[str, List[str]] = None
35
+ # Stops generation if meeting any token in this list
36
+ stop_token_ids: List[int] = None
37
+
38
+ skip_next: bool = False
39
+
40
+ def get_prompt(self):
41
+ messages = self.messages
42
+
43
+ if self.sep_style == SeparatorStyle.TWO:
44
+ seps = [self.sep, self.sep2]
45
+ ret = self.system + seps[0]
46
+ for i, (role, message) in enumerate(messages):
47
+ if message:
48
+ if type(message) is tuple:
49
+ message = message[0]
50
+ ret += role + ": " + message + seps[i % 2]
51
+ else:
52
+ ret += role + ":"
53
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
54
+ wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
55
+ ret = "<|begin_of_text|>" + wrap_sys(self.system)
56
+ for i, (role, message) in enumerate(messages):
57
+ if message:
58
+ if type(message) is tuple:
59
+ message = message[0]
60
+ ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
61
+ ret += message.strip() + self.sep2
62
+ else:
63
+ ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
64
+ return ret
65
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
66
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
67
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
68
+ ret = ""
69
+
70
+ for i, (role, message) in enumerate(messages):
71
+ if i == 0:
72
+ assert message, "first message should not be none"
73
+ assert role == self.roles[0], "first message should come from user"
74
+ if message:
75
+ if type(message) is tuple:
76
+ message, _, _ = message
77
+ if i == 0:
78
+ message = wrap_sys(self.system) + message
79
+ if i % 2 == 0:
80
+ message = wrap_inst(message)
81
+ ret += self.sep + message
82
+ else:
83
+ ret += " " + message + " " + self.sep2
84
+ else:
85
+ ret += ""
86
+ ret = ret.lstrip(self.sep)
87
+ elif self.sep_style == SeparatorStyle.PLAIN:
88
+ seps = [self.sep, self.sep2]
89
+ ret = self.system
90
+ for i, (role, message) in enumerate(messages):
91
+ if message:
92
+ if type(message) is tuple:
93
+ message, _, _ = message
94
+ ret += message + seps[i % 2]
95
+ else:
96
+ ret += ""
97
+
98
+ elif self.sep_style == SeparatorStyle.CHATML:
99
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
100
+ for role, message in messages:
101
+ if message:
102
+ if type(message) is tuple:
103
+ raise ValueError("Tuple not supported in CHATML")
104
+ message, images = message
105
+ message = "<speech>" * len(images) + message
106
+ ret += role + "\n" + message + self.sep + "\n"
107
+ else:
108
+ ret += role + "\n"
109
+ return ret
110
+ elif self.sep_style == SeparatorStyle.QWEN2:
111
+ start = '<|im_start|>'
112
+ end = '<|im_end|>\n'
113
+ ret = start + 'system\n' + self.system + end
114
+ for i, (role, message) in enumerate(messages):
115
+ if message:
116
+ if type(message) is tuple:
117
+ message, _, _ = message
118
+
119
+ if message.endswith('<|endoftext|>'):
120
+ message = message.replace('<|endoftext|>', '')
121
+ ret += start + role + "\n" + message + end + '<|endoftext|>'
122
+ else:
123
+ assert not '<|endoftext|>' in message, f"Invalid message: {message}"
124
+ ret += start + role + "\n" + message + end
125
+ else:
126
+ ret += start + role + "\n"
127
+ else:
128
+ raise ValueError(f"Invalid style: {self.sep_style}")
129
+
130
+ return ret
131
+
132
+ def append_message(self, role, message):
133
+ self.messages.append([role, message])
134
+
135
+ def to_gradio_chatbot(self):
136
+ ret = []
137
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
138
+ if i % 2 == 0:
139
+ if type(msg) is tuple:
140
+ msg, speech = msg
141
+ ret.append([msg, None])
142
+ else:
143
+ ret.append([msg, None])
144
+ else:
145
+ ret[-1][-1] = msg
146
+ return ret
147
+
148
+ def copy(self):
149
+ return Conversation(
150
+ system=self.system,
151
+ roles=self.roles,
152
+ messages=[[x, y] for x, y in self.messages],
153
+ offset=self.offset,
154
+ sep_style=self.sep_style,
155
+ sep=self.sep,
156
+ sep2=self.sep2,
157
+ version=self.version)
158
+
159
+ def dict(self):
160
+ if len(self.get_images()) > 0:
161
+ return {
162
+ "system": self.system,
163
+ "roles": self.roles,
164
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
165
+ "offset": self.offset,
166
+ "sep": self.sep,
167
+ "sep2": self.sep2,
168
+ }
169
+ return {
170
+ "system": self.system,
171
+ "roles": self.roles,
172
+ "messages": self.messages,
173
+ "offset": self.offset,
174
+ "sep": self.sep,
175
+ "sep2": self.sep2,
176
+ }
177
+
178
+ conv_vicuna_v1 = Conversation(
179
+ system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
180
+ roles=("USER", "ASSISTANT"),
181
+ version="v1",
182
+ messages=[],
183
+ offset=0,
184
+ sep_style=SeparatorStyle.TWO,
185
+ sep=" ",
186
+ sep2="</s>",
187
+ )
188
+
189
+ conv_llama_2 = Conversation(
190
+ system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.",
191
+ roles=("USER", "ASSISTANT"),
192
+ version="llama_v2",
193
+ messages=[],
194
+ offset=0,
195
+ sep_style=SeparatorStyle.LLAMA_2,
196
+ sep="<s>",
197
+ sep2="</s>",
198
+ )
199
+
200
+ conv_llama_3 = Conversation(
201
+ system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.",
202
+ roles=("user", "assistant"),
203
+ version="llama_v3",
204
+ messages=[],
205
+ offset=0,
206
+ sep_style=SeparatorStyle.LLAMA_3,
207
+ sep="",
208
+ sep2="<|eot_id|>"
209
+ )
210
+
211
+
212
+ conv_qwen_v1 = Conversation(
213
+ system="You are a helpful assistant.",
214
+ roles=("user", "assistant"),
215
+ version="v1",
216
+ messages=(),
217
+ offset=0,
218
+ sep_style=SeparatorStyle.QWEN2,
219
+ )
220
+
221
+ conv_plain = Conversation(
222
+ system="",
223
+ roles=("", ""),
224
+ messages=(
225
+ ),
226
+ offset=0,
227
+ sep_style=SeparatorStyle.PLAIN,
228
+ sep="</s>",
229
+ )
230
+
231
+ conv_qwen = Conversation(
232
+ system="""<|im_start|>system
233
+ You are a helpful assistant.""",
234
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
235
+ version="qwen",
236
+ messages=[],
237
+ offset=0,
238
+ sep_style=SeparatorStyle.CHATML,
239
+ sep="<|im_end|>",
240
+ )
241
+
242
+ default_conversation = conv_llama_3
243
+ conv_templates = {
244
+ "v1": conv_vicuna_v1,
245
+ "plain": conv_plain,
246
+ "llama_2": conv_llama_2,
247
+ "llama_3": conv_llama_3,
248
+ 'v1_qwen2': conv_qwen_v1,
249
+ "qwen_1_5": conv_qwen,
250
+ }
251
+
252
+
253
+ if __name__ == "__main__":
254
+ print(default_conversation.get_prompt())
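
As a usage note, a minimal sketch of how app.py drives the `qwen_1_5` (CHATML-style) template registered above; the question string is a placeholder.

from ola.conversation import conv_templates

conv = conv_templates["qwen_1_5"].copy()
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open for generation
print(conv.get_prompt())
# Produces (roughly):
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <image>
# What is shown in this image?<|im_end|>
# <|im_start|>assistant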
ola/datasets/__init__.py ADDED
File without changes
ola/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (135 Bytes).
ola/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (133 Bytes).
ola/datasets/__pycache__/preprocess.cpython-310.pyc ADDED
Binary file (10.2 kB).
ola/datasets/__pycache__/preprocess.cpython-38.pyc ADDED
Binary file (10.9 kB).
ola/datasets/preprocess.py ADDED
@@ -0,0 +1,413 @@
1
+ import copy
2
+ import torch
3
+ import transformers
4
+ import tokenizers
5
+
6
+ from typing import Dict, Sequence
7
+
8
+ from ola.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, IMAGE_TOKEN_INDEX
9
+ from ola import conversation as conversation_lib
10
+ from ola.model import *
11
+ from ola.arguments import DataArguments
12
+ from ola.constants import SPEECH_TOKEN_INDEX
13
+
14
+ from packaging import version
15
+
16
+ IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
17
+
18
+
19
+ def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
20
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')]
21
+
22
+ def insert_separator(X, sep):
23
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
24
+
25
+ input_ids = []
26
+ offset = 0
27
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
28
+ offset = 1
29
+ input_ids.append(prompt_chunks[0][0])
30
+
31
+ for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
32
+ input_ids.extend(x[offset:])
33
+
34
+ if return_tensors is not None:
35
+ if return_tensors == 'pt':
36
+ return torch.tensor(input_ids, dtype=torch.long)
37
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
38
+ return input_ids
39
+
40
+
41
+ def preprocess_multimodal(
42
+ sources: Sequence[str],
43
+ data_args: DataArguments
44
+ ) -> Dict:
45
+ is_multimodal = data_args.is_multimodal
46
+ if not is_multimodal:
47
+ return sources
48
+
49
+ for source in sources:
50
+ for sentence in source:
51
+ if DEFAULT_SPEECH_TOKEN in sentence['value']:
52
+ sentence['value'] = sentence['value'].replace(DEFAULT_SPEECH_TOKEN, '').strip()
53
+ sentence['value'] = DEFAULT_SPEECH_TOKEN + '\n' + sentence['value']
54
+ sentence['value'] = sentence['value'].strip()
55
+
56
+ return sources
57
+
58
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
59
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
60
+
61
+ def insert_separator(X, sep):
62
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
63
+
64
+ input_ids = []
65
+ offset = 0
66
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
67
+ offset = 1
68
+ input_ids.append(prompt_chunks[0][0])
69
+
70
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
71
+ input_ids.extend(x[offset:])
72
+
73
+ if return_tensors is not None:
74
+ if return_tensors == 'pt':
75
+ return torch.tensor(input_ids, dtype=torch.long)
76
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
77
+ return input_ids
78
+
79
+ def tokenizer_speech_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
80
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech><image>')]
81
+
82
+ def insert_separator(X, sep):
83
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
84
+
85
+ input_ids = []
86
+ offset = 0
87
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
88
+ offset = 1
89
+ input_ids.append(prompt_chunks[0][0])
90
+
91
+ for x in insert_separator(prompt_chunks, [speech_token_idx, image_token_index] * (offset + 1)):
92
+ input_ids.extend(x[offset:])
93
+
94
+ if return_tensors is not None:
95
+ if return_tensors == 'pt':
96
+ return torch.tensor(input_ids, dtype=torch.long)
97
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
98
+ return input_ids
99
+
100
+ def tokenizer_speech_question_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
101
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>\nUser's question in speech: <speech>\n")]
102
+
103
+ def insert_separator(X, sep):
104
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
105
+
106
+ input_ids = []
107
+ offset = 0
108
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
109
+ offset = 1
110
+ input_ids.append(prompt_chunks[0][0])
111
+
112
+ nl_tokens = tokenizer("\n").input_ids[0]
113
+ special_chunks = [image_token_index, nl_tokens]
114
+ special_chunks.extend(tokenizer("User's question in speech: ").input_ids)
115
+ special_chunks.extend([speech_token_idx, nl_tokens])
116
+
117
+ for x in insert_separator(prompt_chunks, special_chunks):
118
+ input_ids.extend(x[offset:])
119
+
120
+ if return_tensors is not None:
121
+ if return_tensors == 'pt':
122
+ return torch.tensor(input_ids, dtype=torch.long)
123
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
124
+ return input_ids
125
+
126
+ def preprocess_llama_2(
127
+ sources,
128
+ tokenizer: transformers.PreTrainedTokenizer,
129
+ has_speech: bool = False
130
+ ) -> Dict:
131
+ conv = conversation_lib.default_conversation.copy()
132
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
133
+
134
+ # Apply prompt templates
135
+ conversations = []
136
+ for i, source in enumerate(sources):
137
+ if roles[source[0]["from"]] != conv.roles[0]:
138
+ # Skip the first one if it is not from human
139
+ source = source[1:]
140
+
141
+ conv.messages = []
142
+ for j, sentence in enumerate(source):
143
+ role = roles[sentence["from"]]
144
+ assert role == conv.roles[j % 2], f"{i}"
145
+ conv.append_message(role, sentence["value"])
146
+ conversations.append(conv.get_prompt())
147
+
148
+ # Tokenize conversations
149
+
150
+ if has_speech:
151
+ input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
152
+ else:
153
+ input_ids = tokenizer(
154
+ conversations,
155
+ return_tensors="pt",
156
+ padding="longest",
157
+ max_length=tokenizer.model_max_length,
158
+ truncation=True,
159
+ ).input_ids
160
+
161
+ targets = input_ids.clone()
162
+
163
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
164
+
165
+ # Mask targets
166
+ sep = "[/INST] "
167
+ for conversation, target in zip(conversations, targets):
168
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
169
+
170
+ rounds = conversation.split(conv.sep2)
171
+ cur_len = 1
172
+ target[:cur_len] = IGNORE_INDEX
173
+ for i, rou in enumerate(rounds):
174
+ if rou == "":
175
+ break
176
+
177
+ parts = rou.split(sep)
178
+ if len(parts) != 2:
179
+ break
180
+ parts[0] += sep
181
+
182
+ if has_speech:
183
+ round_len = len(tokenizer_speech_token(rou, tokenizer))
184
+ instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
185
+ else:
186
+ round_len = len(tokenizer(rou).input_ids)
187
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
188
+
189
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
190
+
191
+ cur_len += round_len
192
+ target[cur_len:] = IGNORE_INDEX
193
+
194
+ if cur_len < tokenizer.model_max_length:
195
+ if cur_len != total_len:
196
+ target[:] = IGNORE_INDEX
197
+ print(
198
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
199
+ f" (ignored)"
200
+ )
201
+
202
+ return dict(
203
+ input_ids=input_ids,
204
+ labels=targets,
205
+ )
206
+
207
+
208
+ def preprocess_llama_3(
209
+ sources,
210
+ tokenizer: transformers.PreTrainedTokenizer,
211
+ has_speech: bool = False
212
+ ) -> Dict:
213
+ conv = conversation_lib.default_conversation.copy()
214
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
215
+
216
+ # Apply prompt templates
217
+ conversations = []
218
+ for i, source in enumerate(sources):
219
+ if roles[source[0]["from"]] != conv.roles[0]:
220
+ # Skip the first one if it is not from human
221
+ source = source[1:]
222
+
223
+ assert len(source) == 2, "now only support single-turn conversation"
224
+
225
+ conv.messages = []
226
+ for j, sentence in enumerate(source):
227
+ role = roles[sentence["from"]]
228
+ assert role == conv.roles[j % 2], f"{i}"
229
+ conv.append_message(role, sentence["value"])
230
+ conversations.append(conv.get_prompt())
231
+
232
+ # Tokenize conversations
233
+
234
+ if has_speech:
235
+ input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
236
+ else:
237
+ input_ids = tokenizer(
238
+ conversations,
239
+ return_tensors="pt",
240
+ padding="longest",
241
+ max_length=tokenizer.model_max_length,
242
+ truncation=True,
243
+ ).input_ids
244
+
245
+ targets = input_ids.clone()
246
+
247
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3
248
+
249
+ # Mask targets
250
+ sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>\n\n"
251
+ for conversation, target in zip(conversations, targets):
252
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
253
+
254
+ cur_len = 1
255
+ target[:cur_len] = IGNORE_INDEX
256
+ parts = conversation.split(sep)
257
+ parts[0] += sep
258
+
259
+ if has_speech:
260
+ conversation_len = len(tokenizer_speech_token(conversation, tokenizer))
261
+ instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 1
262
+ else:
263
+ conversation_len = len(tokenizer(conversation).input_ids)
264
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
265
+
266
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
267
+ cur_len += conversation_len
268
+ target[cur_len:] = IGNORE_INDEX
269
+
270
+ # if cur_len < tokenizer.model_max_length:
271
+ # if cur_len != total_len:
272
+ # target[:] = IGNORE_INDEX
273
+ # print(
274
+ # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
275
+ # f" (ignored)"
276
+ # )
277
+
278
+ return dict(
279
+ input_ids=input_ids,
280
+ labels=targets,
281
+ )
282
+
283
+
284
+ def preprocess_v1(
285
+ sources,
286
+ tokenizer: transformers.PreTrainedTokenizer,
287
+ has_speech: bool = False
288
+ ) -> Dict:
289
+ conv = conversation_lib.default_conversation.copy()
290
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
291
+
292
+ # Apply prompt templates
293
+ conversations = []
294
+ for i, source in enumerate(sources):
295
+ if roles[source[0]["from"]] != conv.roles[0]:
296
+ # Skip the first one if it is not from human
297
+ source = source[1:]
298
+
299
+ conv.messages = []
300
+ for j, sentence in enumerate(source):
301
+ role = roles[sentence["from"]]
302
+ assert role == conv.roles[j % 2], f"{i}"
303
+ conv.append_message(role, sentence["value"])
304
+ conversations.append(conv.get_prompt())
305
+
306
+ # Tokenize conversations
307
+
308
+ if has_speech:
309
+ input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
310
+ else:
311
+ input_ids = tokenizer(
312
+ conversations,
313
+ return_tensors="pt",
314
+ padding="longest",
315
+ max_length=tokenizer.model_max_length,
316
+ truncation=True,
317
+ ).input_ids
318
+
319
+ targets = input_ids.clone()
320
+
321
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
322
+
323
+ # Mask targets
324
+ sep = conv.sep + conv.roles[1] + ": "
325
+ for conversation, target in zip(conversations, targets):
326
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
327
+
328
+ rounds = conversation.split(conv.sep2)
329
+ cur_len = 1
330
+ target[:cur_len] = IGNORE_INDEX
331
+ for i, rou in enumerate(rounds):
332
+ if rou == "":
333
+ break
334
+
335
+ parts = rou.split(sep)
336
+ if len(parts) != 2:
337
+ break
338
+ parts[0] += sep
339
+
340
+ if has_speech:
341
+ round_len = len(tokenizer_speech_token(rou, tokenizer))
342
+ instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
343
+ else:
344
+ round_len = len(tokenizer(rou).input_ids)
345
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
346
+
347
+ # FIXME: tokenizer bug
348
+ if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
349
+ round_len -= 1
350
+ instruction_len -= 1
351
+
352
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
353
+
354
+ cur_len += round_len
355
+ target[cur_len:] = IGNORE_INDEX
356
+
357
+ if cur_len < tokenizer.model_max_length:
358
+ if cur_len != total_len:
359
+ target[:] = IGNORE_INDEX
360
+ print(
361
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
362
+ f" (ignored)"
363
+ )
364
+
365
+ return dict(
366
+ input_ids=input_ids,
367
+ labels=targets,
368
+ )
369
+
370
+
371
+ def preprocess_plain(
372
+ sources: Sequence[str],
373
+ tokenizer: transformers.PreTrainedTokenizer,
374
+ ) -> Dict:
375
+ # add end signal and concatenate together
376
+ conversations = []
377
+ for source in sources:
378
+ assert len(source) == 2
379
+ assert DEFAULT_SPEECH_TOKEN in source[0]['value']
380
+ source[0]['value'] = DEFAULT_SPEECH_TOKEN
381
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
382
+ conversations.append(conversation)
383
+ # tokenize conversations
384
+ input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
385
+ targets = copy.deepcopy(input_ids)
386
+ for target, source in zip(targets, sources):
387
+ tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer))
388
+ target[:tokenized_len] = IGNORE_INDEX
389
+
390
+ return dict(input_ids=input_ids, labels=targets)
391
+
392
+
393
+ def preprocess(
394
+ sources: Sequence[str],
395
+ tokenizer: transformers.PreTrainedTokenizer,
396
+ has_speech: bool = False
397
+ ) -> Dict:
398
+ """
399
+ Given a list of sources, each is a conversation list. This transform:
400
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
401
+ 2. Concatenate conversations together;
402
+ 3. Tokenize the concatenated conversation;
403
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
404
+ """
405
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
406
+ return preprocess_plain(sources, tokenizer)
407
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
408
+ return preprocess_llama_2(sources, tokenizer, has_speech=has_speech)
409
+ if conversation_lib.default_conversation.version.startswith("v1"):
410
+ return preprocess_v1(sources, tokenizer, has_speech=has_speech)
411
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_3:
412
+ return preprocess_llama_3(sources, tokenizer, has_speech=has_speech)
413
+ raise NotImplementedError
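
For intuition, a small sketch of what `tokenizer_image_token` returns: the prompt is split on `<image>` and the image placeholder index is spliced between the tokenized chunks. The `ToyTokenizer` below is a stand-in for the real Qwen tokenizer, used only to make the splicing visible.

from types import SimpleNamespace
from ola.datasets.preprocess import tokenizer_image_token

class ToyTokenizer:
    bos_token_id = None  # no BOS token, so the offset handling is skipped
    def __call__(self, text):
        # fake "token ids": one id per character
        return SimpleNamespace(input_ids=[ord(c) for c in text])

ids = tokenizer_image_token("<image>\nHi", ToyTokenizer(), image_token_index=-300)
print(ids)  # [-300, 10, 72, 105] -> placeholder id followed by the ids for "\nHi"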
ola/mm_utils.py ADDED
@@ -0,0 +1,272 @@
1
+ from PIL import Image
2
+ import base64
3
+ import math
4
+ import ast
5
+
6
+ import torch
7
+ from transformers import StoppingCriteria
8
+ import os
9
+ import io
10
+
11
+ if 'VIDEO_RESIZE' in os.environ:
12
+ # highresxpatch
13
+ VIDEO_RESIZE = os.environ['VIDEO_RESIZE']
14
+ video_base, video_ps = VIDEO_RESIZE.split('x')
15
+ video_base = int(video_base)
16
+ video_ps = int(video_ps)
17
+ print(f"VIDEO_RESIZE is set as {VIDEO_RESIZE}, {video_base}, {video_ps}")
18
+ else:
19
+ HIGHRES_BASE = None
20
+
21
+ if 'HIGHRES_BASE' in os.environ:
22
+ # highresxpatch
23
+ HIGHRES_BASE = os.environ['HIGHRES_BASE']
24
+ highres_base, highres_ps = HIGHRES_BASE.split('x')
25
+ highres_base = int(highres_base)
26
+ highres_ps = int(highres_ps)
27
+ print(f"HIGHRES_BASE is set as {HIGHRES_BASE}, {highres_base}, {highres_ps}")
28
+ else:
29
+ HIGHRES_BASE = None
30
+
31
+ if 'MAXRES' in os.environ:
32
+ # highresxpatch
33
+ MAXRES = int(os.environ['MAXRES'])
34
+ print(f"MAXRES is set as {MAXRES}")
35
+ else:
36
+ MAXRES = 1536
37
+
38
+ if 'MINRES' in os.environ:
39
+ # highresxpatch
40
+ MINRES = int(os.environ['MINRES'])
41
+ print(f"MINRES is set as {MINRES}")
42
+ else:
43
+ MINRES = 0
44
+
45
+ if 'VIDEO_MAXRES' in os.environ:
46
+ # highresxpatch
47
+ VIDEO_MAXRES = int(os.environ['VIDEO_MAXRES'])
48
+ print(f"VIDEO_MAXRES is set as {VIDEO_MAXRES}")
49
+ else:
50
+ VIDEO_MAXRES = 1536
51
+
52
+ if 'VIDEO_MINRES' in os.environ:
53
+ # highresxpatch
54
+ VIDEO_MINRES = int(os.environ['VIDEO_MINRES'])
55
+ print(f"VIDEO_MINRES is set as {VIDEO_MINRES}")
56
+ else:
57
+ MINRES = 0
58
+
59
+ if 'PAD2STRIDE' in os.environ:
60
+ # highresxpatch
61
+ PAD2STRIDE = True
62
+ print(f"PAD2STRIDE is set")
63
+ else:
64
+ PAD2STRIDE = False
65
+
66
+ if 'LOWRES_RESIZE' in os.environ:
67
+ LOWRES_RESIZE = os.environ['LOWRES_RESIZE']
68
+ print(f"LOWRES_RESIZE is set as {LOWRES_RESIZE}")
69
+ if 'x' in LOWRES_RESIZE:
70
+ size, ps = LOWRES_RESIZE.split('x')
71
+ size = int(size)
72
+ ps = int(ps)
73
+ LOWRES_RESIZE = (size, ps)
74
+ else:
75
+ LOWRES_RESIZE = int(LOWRES_RESIZE)
76
+ else:
77
+ LOWRES_RESIZE = None
78
+
79
+
80
+ def pad_image(image, target_resolution, value=0):
81
+ """
82
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
83
+
84
+ Args:
85
+ image (PIL.Image.Image): The input image.
86
+ target_resolution (tuple): The target resolution (width, height) of the image.
87
+
88
+ Returns:
89
+ PIL.Image.Image: The resized and padded image.
90
+ """
91
+ original_width, original_height = image.size
92
+ target_width, target_height = target_resolution
93
+ # Create a new image with the target size and paste the resized image onto it
94
+ new_image = Image.new('RGB', (target_width, target_height), (value, value, value))
95
+ paste_x = (target_width - original_width) // 2
96
+ paste_y = (target_height - original_height) // 2
97
+ new_image.paste(image, (paste_x, paste_y))
98
+ return new_image
99
+
100
+ def resize_images(image, patch_size=14, base_size=896):
101
+ h, w = image.size
102
+ if base_size == 0:
103
+ if h * w > MAXRES * MAXRES:
104
+ # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
105
+ scale = MAXRES * MAXRES / (h * w)
106
+ scale = math.sqrt(scale)
107
+ elif h * w < MINRES * MINRES:
108
+ # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
109
+ scale = MINRES * MINRES / (h * w)
110
+ scale = math.sqrt(scale)
111
+ else:
112
+ scale = None
113
+ else:
114
+ scale = base_size * base_size / (h * w)
115
+ scale = math.sqrt(scale)
116
+
117
+
118
+ if scale is not None:
119
+ new_h = int(h * scale / patch_size) * patch_size
120
+ new_w = int(w * scale / patch_size) * patch_size
121
+ new_h = max(new_h, patch_size)
122
+ new_w = max(new_w, patch_size)
123
+ image = image.resize((new_h, new_w))
124
+ elif PAD2STRIDE:
125
+ if h % patch_size == 0:
126
+ new_h = h
127
+ else:
128
+ new_h = (h // patch_size + 1) * patch_size
129
+
130
+ if w % patch_size == 0:
131
+ new_w = w
132
+ else:
133
+ new_w = (w // patch_size + 1) * patch_size
134
+ image = pad_image(image, (new_h, new_w), value=127)
135
+ else:
136
+ scale = 1.0
137
+ new_h = int(h * scale / patch_size) * patch_size
138
+ new_w = int(w * scale / patch_size) * patch_size
139
+ new_h = max(new_h, patch_size)
140
+ new_w = max(new_w, patch_size)
141
+ image = image.resize((new_h, new_w))
142
+
143
+ return image
144
+
145
+ def resize_video(image, patch_size=14, base_size=896):
146
+ h, w = image.size
147
+ if base_size == 0:
148
+ if h * w > VIDEO_MAXRES * VIDEO_MAXRES:
149
+ # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
150
+ scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w)
151
+ scale = math.sqrt(scale)
152
+ elif h * w < VIDEO_MINRES * VIDEO_MINRES:
153
+ # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
154
+ scale = VIDEO_MINRES * VIDEO_MINRES / (h * w)
155
+ scale = math.sqrt(scale)
156
+ else:
157
+ scale = None
158
+ else:
159
+ scale = base_size * base_size / (h * w)
160
+ scale = math.sqrt(scale)
161
+
162
+ if scale is not None:
163
+ new_h = int(h * scale / patch_size) * patch_size
164
+ new_w = int(w * scale / patch_size) * patch_size
165
+ image = image.resize((new_h, new_w))
166
+ elif PAD2STRIDE:
167
+ if h % patch_size == 0:
168
+ new_h = h
169
+ else:
170
+ new_h = (h // patch_size + 1) * patch_size
171
+
172
+ if w % patch_size == 0:
173
+ new_w = w
174
+ else:
175
+ new_w = (w // patch_size + 1) * patch_size
176
+ image = pad_image(image, (new_h, new_w), value=127)
177
+ else:
178
+ scale = 1.0
179
+ new_h = int(h * scale / patch_size) * patch_size
180
+ new_w = int(w * scale / patch_size) * patch_size
181
+ image = image.resize((new_h, new_w))
182
+
183
+ return image
184
+
185
+ def process_anyres_video(image, processor):
186
+ if VIDEO_RESIZE is not None:
187
+ image = resize_video(image, patch_size=video_ps, base_size=video_base)
188
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
189
+ return image.unsqueeze(0)
190
+ else:
191
+ raise ValueError("VIDEO_RESIZE is not set")
192
+
193
+ def process_anyres_highres_image_genli(image, processor):
194
+ h, w = image.size
195
+ if h < 32 and w < 32:
196
+ min_size = min(h, w)
197
+ ratio = 64 / min_size
198
+ image = image.resize((int(h * ratio), int(w * ratio)))
199
+ elif h < 32:
200
+ ratio = 64 / h
201
+ image = image.resize((int(h * ratio), int(w * ratio)))
202
+ elif w < 32:
203
+ ratio = 64 / w
204
+ image = image.resize((int(h * ratio), int(w * ratio)))
205
+ if HIGHRES_BASE is not None:
206
+ image = resize_images(image, patch_size=highres_ps, base_size=highres_base)
207
+
208
+ if LOWRES_RESIZE is not None:
209
+ image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0])
210
+ else:
211
+ image_original_resize = image.resize((384, 384))
212
+
213
+ # image_patches = [image_original_resize] + [image_original_resize]
214
+ # image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
215
+ # for image_patch in image_patches]
216
+ image_patches = processor.preprocess(image_original_resize, return_tensors='pt')['pixel_values'][0]
217
+ image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
218
+ # return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0)
219
+ return image_patches.unsqueeze(0), image_padded.unsqueeze(0)
220
+
221
+ def read_image_patch(patch_info):
222
+ if 'img_path' in patch_info.keys():
223
+ image = Image.open(patch_info['img_path']).convert('RGB')
224
+ else:
225
+ if 'image_encoing' in patch_info.keys():
226
+ patch_info['image_encoding'] = patch_info['image_encoing']
227
+ image_file_name = patch_info['patch']
228
+ start_bytes = int(patch_info['start_num'])
229
+ file_size = int(patch_info['size'])
230
+
231
+ with open(image_file_name, 'rb') as f:
232
+ f.seek(start_bytes)
233
+ if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
234
+ image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB")
235
+ else:
236
+ image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB")
237
+ return image
238
+
239
+
240
+ def get_model_name_from_path(model_path):
241
+ model_path = model_path.strip("/")
242
+ model_paths = model_path.split("/")
243
+ if model_paths[-1].startswith('checkpoint-'):
244
+ return model_paths[-2] + "_" + model_paths[-1]
245
+ else:
246
+ return model_paths[-1]
247
+
248
+
249
+ class KeywordsStoppingCriteria(StoppingCriteria):
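+ # Stop generation once any keyword appears, checked both against the trailing
+ # token ids and against the decoded tail of the output.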
250
+ def __init__(self, keywords, tokenizer, input_ids):
251
+ self.keywords = keywords
252
+ self.keyword_ids = []
253
+ for keyword in keywords:
254
+ cur_keyword_ids = tokenizer(keyword).input_ids
255
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
256
+ cur_keyword_ids = cur_keyword_ids[1:]
257
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
258
+ self.tokenizer = tokenizer
259
+ self.start_len = input_ids.shape[1]
260
+
261
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
262
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
263
+ offset = min(output_ids.shape[1] - self.start_len, 3)
264
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
265
+ for keyword_id in self.keyword_ids:
266
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
267
+ return True
268
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
269
+ for keyword in self.keywords:
270
+ if keyword in outputs:
271
+ return True
272
+ return False
ola/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .language_model.ola_qwen import OlaQwenForCausalLM, OlaConfigQwen
ola/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (222 Bytes). View file
 
ola/model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (220 Bytes). View file
 
ola/model/__pycache__/builder.cpython-310.pyc ADDED
Binary file (3.27 kB). View file
 
ola/model/__pycache__/builder.cpython-38.pyc ADDED
Binary file (3.34 kB). View file
 
ola/model/__pycache__/ola_arch.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
ola/model/__pycache__/ola_arch.cpython-38.pyc ADDED
Binary file (12 kB). View file
 
ola/model/builder.py ADDED
@@ -0,0 +1,91 @@
1
+ import os
2
+ import warnings
3
+ import shutil
4
+
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
6
+ import torch
7
+ from ola.model import *
8
+ from ola.model.speech_encoder.builder import build_speech_encoder
9
+
10
+ def load_pretrained_model(model_path, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs):
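+ # Three loading paths: (1) a LoRA checkpoint merged onto model_base, (2) a base
+ # model plus separately saved speech_projector weights, (3) a fully merged
+ # checkpoint. The speech encoder and vision tower are then built and moved to
+ # the target device.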
11
+ if load_8bit:
12
+ kwargs['load_in_8bit'] = True
13
+ elif load_4bit:
14
+ kwargs['load_in_4bit'] = True
15
+ kwargs['quantization_config'] = BitsAndBytesConfig(
16
+ load_in_4bit=True,
17
+ bnb_4bit_compute_dtype=torch.float16,
18
+ bnb_4bit_use_double_quant=True,
19
+ bnb_4bit_quant_type='nf4'
20
+ )
21
+ else:
22
+ kwargs['torch_dtype'] = torch.bfloat16
23
+
24
+ if use_flash_attn:
25
+ kwargs['attn_implementation'] = 'flash_attention_2'
26
+
27
+ model_cls = OlaQwenForCausalLM
28
+
29
+ # Load OmniSpeech model
30
+ if is_lora:
31
+ assert model_base is not None, "model_base is required for LoRA models."
32
+ from ola.model.language_model.ola_qwen import OlaConfigQwen
33
+ lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path)
34
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
35
+ print('Loading OmniSpeech from base model...')
36
+ model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
37
+ print('Loading additional OmniSpeech weights...')
38
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
39
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
40
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
41
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
42
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
43
+ model.load_state_dict(non_lora_trainables, strict=False)
44
+
45
+ from peft import PeftModel
46
+ print('Loading LoRA weights...')
47
+ model = PeftModel.from_pretrained(model, model_path)
48
+ print('Merging LoRA weights...')
49
+ model = model.merge_and_unload()
50
+ print('Model is loaded...')
51
+ elif model_base is not None:
52
+ print('Loading OmniSpeech from base model...')
53
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
54
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
55
+ model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)
56
+
57
+ speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu')
58
+ speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
59
+ model.load_state_dict(speech_projector_weights, strict=False)
60
+ model = model.to(device=device)
61
+ else:
62
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
63
+ model = model_cls.from_pretrained(
64
+ model_path,
65
+ low_cpu_mem_usage=False,
66
+ **kwargs
67
+ )
68
+ model = model.to(device=device)
69
+
70
+ model.get_model().speech_encoder = build_speech_encoder(model.config)
71
+ model.get_model().speech_encoder.to(device=device, dtype=torch.float16)
72
+
73
+ image_processor = None
74
+ model.resize_token_embeddings(len(tokenizer))
75
+ vision_tower = model.get_vision_tower()
76
+ print("Loading vision tower...")
77
+ if not vision_tower.is_loaded:
78
+ vision_tower.load_model(device_map=device)
79
+ if device != "auto":
80
+ vision_tower.to(device="cuda", dtype=torch.bfloat16)
81
+ else:
82
+ vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
83
+ image_processor = vision_tower.image_processor
84
+ print("Loading vision tower succeeded.")
85
+
86
+ if hasattr(model.config, "max_sequence_length"):
87
+ context_len = model.config.max_sequence_length
88
+ else:
89
+ context_len = 16384
90
+
91
+ return tokenizer, model, image_processor, context_len
ola/model/language_model/__pycache__/ola_qwen.cpython-310.pyc ADDED
Binary file (5.31 kB). View file
 
ola/model/language_model/__pycache__/ola_qwen.cpython-38.pyc ADDED
Binary file (5.26 kB). View file
 
ola/model/language_model/ola_qwen.py ADDED
@@ -0,0 +1,237 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ import transformers
7
+ from transformers import AutoConfig, AutoModelForCausalLM
8
+
9
+
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+ from transformers.generation.utils import GenerateOutput
12
+
13
+ from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM
14
+ from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
15
+
16
+
17
+ class OlaConfigQwen(Qwen2Config):
18
+ model_type = "ola_qwen"
19
+
20
+
21
+ class OlaQwenModel(OlaMetaModel, Qwen2Model):
22
+ config_class = OlaConfigQwen
23
+
24
+ def __init__(self, config: Qwen2Config):
25
+ super(OlaQwenModel, self).__init__(config)
26
+
27
+
28
+ class OlaQwenForCausalLM(Qwen2ForCausalLM, OlaMetaForCausalLM):
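+ # Qwen2 causal LM extended with Ola's multimodal path: speech/image/video inputs
+ # are converted to embeddings by prepare_inputs_labels_for_speech_vision_text
+ # before the standard forward / generate.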
29
+ config_class = OlaConfigQwen
30
+
31
+ def __init__(self, config):
32
+ super(Qwen2ForCausalLM, self).__init__(config)
33
+
34
+ config.rope_scaling = None
35
+ self.model = OlaQwenModel(config)
36
+ self.vocab_size = config.vocab_size
37
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
38
+
39
+ # Initialize weights and apply final processing
40
+ self.post_init()
41
+
42
+ def get_model(self):
43
+ return self.model
44
+
45
+ def forward(
46
+ self,
47
+ input_ids: torch.LongTensor = None,
48
+ attention_mask: Optional[torch.Tensor] = None,
49
+ position_ids: Optional[torch.LongTensor] = None,
50
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
51
+ inputs_embeds: Optional[torch.FloatTensor] = None,
52
+ labels: Optional[torch.LongTensor] = None,
53
+ use_cache: Optional[bool] = None,
54
+ output_attentions: Optional[bool] = None,
55
+ output_hidden_states: Optional[bool] = None,
56
+ speech: Optional[torch.FloatTensor] = None,
57
+ speech_lengths: Optional[torch.LongTensor] = None,
58
+ speech_chunks: Optional[torch.LongTensor] = None,
59
+ speech_wav: Optional[torch.FloatTensor] = None,
60
+ images: Optional[torch.FloatTensor] = None,
61
+ images_highres: Optional[List[torch.FloatTensor]] = None,
62
+ image_sizes: Optional[List[List[int]]] = None,
63
+ modalities: Optional[List[str]] = ["image"],
64
+ return_dict: Optional[bool] = None,
65
+ cache_position: Optional[torch.LongTensor] = None,
66
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
67
+
68
+ if inputs_embeds is None:
69
+ (
70
+ input_ids,
71
+ position_ids,
72
+ attention_mask,
73
+ past_key_values,
74
+ inputs_embeds,
75
+ labels
76
+ ) = self.prepare_inputs_labels_for_speech_vision_text(
77
+ input_ids,
78
+ position_ids,
79
+ attention_mask,
80
+ past_key_values,
81
+ labels,
82
+ speech,
83
+ speech_lengths,
84
+ speech_chunks,
85
+ speech_wav,
86
+ images,
87
+ modalities,
88
+ image_sizes,
89
+ images_highres
90
+ )
91
+
92
+ if labels is None:
93
+ return super().forward(
94
+ input_ids=input_ids,
95
+ attention_mask=attention_mask,
96
+ position_ids=position_ids,
97
+ past_key_values=past_key_values,
98
+ inputs_embeds=inputs_embeds,
99
+ use_cache=use_cache,
100
+ output_attentions=output_attentions,
101
+ output_hidden_states=output_hidden_states,
102
+ return_dict=return_dict
103
+ )
104
+ else:
105
+ return self.forward_llm_efficient(
106
+ input_ids=input_ids,
107
+ attention_mask=attention_mask,
108
+ position_ids=position_ids,
109
+ past_key_values=past_key_values,
110
+ inputs_embeds=inputs_embeds,
111
+ labels=labels,
112
+ use_cache=use_cache,
113
+ output_attentions=output_attentions,
114
+ output_hidden_states=output_hidden_states,
115
+ return_dict=return_dict
116
+ )
117
+
118
+
119
+ def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
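+ # Memory-efficient loss: run the backbone once, keep only hidden states whose
+ # shifted label is supervised (label > -1), and apply lm_head to that subset
+ # before computing the cross-entropy loss.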
120
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
121
+ output_hidden_states = (
122
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
123
+ )
124
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
125
+
126
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
127
+ outputs = self.model(
128
+ input_ids=input_ids,
129
+ attention_mask=attention_mask,
130
+ position_ids=position_ids,
131
+ past_key_values=past_key_values,
132
+ inputs_embeds=inputs_embeds,
133
+ use_cache=use_cache,
134
+ output_attentions=output_attentions,
135
+ output_hidden_states=output_hidden_states,
136
+ return_dict=return_dict,
137
+ )
138
+
139
+ hidden_states = outputs[0]
140
+ hidden_dim = hidden_states.size(-1)
141
+ shift_labels = labels[..., 1:].contiguous().reshape(-1)
142
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
143
+ assert shift_labels.size(0) == shift_hidden_states.size(0)
144
+ mask = shift_labels > -1
145
+ assert mask.float().sum() > 0
146
+ shift_labels = shift_labels[mask]
147
+ shift_hidden_states = shift_hidden_states[mask, :]
148
+ logits = self.lm_head(shift_hidden_states)
149
+ logits = logits.float()
150
+ loss_fct = nn.CrossEntropyLoss()
151
+ loss = loss_fct(logits, shift_labels)
152
+
153
+
154
+ if not return_dict:
155
+ output = (logits,) + outputs[1:]
156
+ return (loss,) + output if loss is not None else output
157
+
158
+
159
+ return CausalLMOutputWithPast(
160
+ loss=loss,
161
+ logits=logits,
162
+ past_key_values=outputs.past_key_values,
163
+ hidden_states=outputs.hidden_states,
164
+ attentions=outputs.attentions,
165
+ )
166
+
167
+ @torch.no_grad()
168
+ def generate(
169
+ self,
170
+ inputs: Optional[torch.Tensor] = None,
171
+ speech: Optional[torch.Tensor] = None,
172
+ speech_lengths: Optional[torch.Tensor] = None,
173
+ speech_chunks: Optional[torch.Tensor] = None,
174
+ speech_wav: Optional[torch.FloatTensor] = None,
175
+ images: Optional[torch.Tensor] = None,
176
+ images_highres: Optional[List[torch.FloatTensor]] = None,
177
+ image_sizes: Optional[torch.Tensor] = None,
178
+ modalities: Optional[List[str]] = ["image"],
179
+ **kwargs,
180
+ ) -> Union[GenerateOutput, torch.LongTensor]:
181
+ position_ids = kwargs.pop("position_ids", None)
182
+ attention_mask = kwargs.pop("attention_mask", None)
183
+ if "inputs_embeds" in kwargs:
184
+ raise NotImplementedError("`inputs_embeds` is not supported")
185
+
186
+ (
187
+ inputs,
188
+ position_ids,
189
+ attention_mask,
190
+ _,
191
+ inputs_embeds,
192
+ _
193
+ ) = self.prepare_inputs_labels_for_speech_vision_text(
194
+ inputs,
195
+ position_ids,
196
+ attention_mask,
197
+ None,
198
+ None,
199
+ speech,
200
+ speech_lengths,
201
+ speech_chunks,
202
+ speech_wav,
203
+ images,
204
+ modalities,
205
+ image_sizes,
206
+ images_highres
207
+ )
208
+
209
+ return super().generate(
210
+ position_ids=position_ids,
211
+ attention_mask=attention_mask,
212
+ inputs_embeds=inputs_embeds,
213
+ **kwargs
214
+ )
215
+
216
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
217
+ inputs_embeds=None, **kwargs):
218
+ speech = kwargs.pop("speech", None)
219
+ speech_lengths = kwargs.pop("speech_lengths", None)
220
+ speech_chunks = kwargs.pop("speech_chunks", None)
221
+ images = kwargs.pop("images", None)
222
+ image_sizes = kwargs.pop("image_sizes", None)
223
+ inputs = super().prepare_inputs_for_generation(
224
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
225
+ )
226
+ if speech is not None:
227
+ inputs['speech'] = speech
228
+ inputs['speech_lengths'] = speech_lengths
229
+ inputs['speech_chunks'] = speech_chunks
230
+ if images is not None:
231
+ inputs["images"] = images
232
+ if image_sizes is not None:
233
+ inputs["image_sizes"] = image_sizes
234
+ return inputs
235
+
236
+ AutoConfig.register("ola_qwen", OlaConfigQwen)
237
+ AutoModelForCausalLM.register(OlaConfigQwen, OlaQwenForCausalLM)
ola/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc ADDED
Binary file (579 Bytes). View file
 
ola/model/multimodal_encoder/__pycache__/builder.cpython-38.pyc ADDED
Binary file (577 Bytes). View file
 
ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-310.pyc ADDED
Binary file (28.8 kB). View file
 
ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-38.pyc ADDED
Binary file (28.7 kB). View file
 
ola/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,9 @@
1
+ import os
2
+ from .oryx_vit import SigLIPViTAnysizeWrapper
3
+
4
+ def build_vision_tower(vision_tower_cfg, **kwargs):
5
+ vision_tower = getattr(vision_tower_cfg, 'vision_tower', getattr(vision_tower_cfg, 'mm_vision_tower', None))
6
+ is_absolute_path_exists = os.path.exists(vision_tower)
7
+ print(f"Building OryxViTWrapper from {vision_tower}...")
8
+ # path = vision_tower.split(":")[1]
9
+ return SigLIPViTAnysizeWrapper(vision_tower, path=vision_tower, args=vision_tower_cfg, **kwargs)
ola/model/multimodal_encoder/oryx_vit.py ADDED
@@ -0,0 +1,1126 @@
1
+ import math
2
+ import warnings
3
+ from dataclasses import dataclass
4
+ from functools import partial
5
+ from typing import (
6
+ Callable,
7
+ Dict,
8
+ Final,
9
+ List,
10
+ Literal,
11
+ Optional,
12
+ Sequence,
13
+ Set,
14
+ Tuple,
15
+ Type,
16
+ Union,
17
+ )
18
+
19
+ from torch.utils.checkpoint import checkpoint
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ try:
24
+ from timm.layers import (
25
+ AttentionPoolLatent,
26
+ DropPath,
27
+ LayerType,
28
+ Mlp,
29
+ PatchDropout,
30
+ PatchEmbed,
31
+ resample_abs_pos_embed,
32
+ )
33
+ from timm.models._manipulate import checkpoint_seq, named_apply
34
+ except ImportError:
35
+ print('Wrong timm version')
36
+
37
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
38
+
39
+ from typing import Optional
40
+
41
+ import logging
42
+ import torch
43
+ import torch.nn as nn
44
+ import torch.nn.functional as F
45
+
46
+ import deepspeed
47
+ import os
48
+ if 'LOAD_VISION_EARLY' in os.environ:
49
+ print("LOAD_VISION_EARLY is set")
50
+ LOAD_VISION_EARLY = True
51
+ else:
52
+ LOAD_VISION_EARLY = False
53
+
54
+
55
+ if 'SKIP_LOAD_VIT' in os.environ:
56
+ print("SKIP_LOAD_VIT is set")
57
+ SKIP_LOAD_VIT = True
58
+ else:
59
+ SKIP_LOAD_VIT = False
60
+
61
+ if 'VIT_WITH_GRAD' in os.environ:
62
+ print("VIT_WITH_GRAD is set")
63
+ VIT_WITH_GRAD = True
64
+ else:
65
+ VIT_WITH_GRAD = False
66
+
67
+
68
+ if 'FIX_SIZE' in os.environ:
69
+ print("FIX_SIZE is set")
70
+ FIX_SIZE = True
71
+ else:
72
+ FIX_SIZE = False
73
+
74
+
75
+ if 'ANYRES_SPLIT' in os.environ:
76
+ ANYRES_SPLIT = int(os.environ['ANYRES_SPLIT'])
77
+ print(f"ANYRES_SPLIT is set as {ANYRES_SPLIT}")
78
+ else:
79
+ ANYRES_SPLIT = None
80
+
81
+
82
+ if 'FORCE_NO_DOWNSAMPLE' in os.environ:
83
+ print("FORCE_NO_DOWNSAMPLE is set")
84
+ FORCE_NO_DOWNSAMPLE = True
85
+ else:
86
+ FORCE_NO_DOWNSAMPLE = False
87
+
88
+ if 'EVAL_72B' in os.environ:
89
+ print("EVAL_72B is set")
90
+ EVAL_72B = True
91
+ else:
92
+ EVAL_72B = False
93
+
94
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
95
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
96
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
97
+ def norm_cdf(x):
98
+ # Computes standard normal cumulative distribution function
99
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
100
+
101
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
102
+ warnings.warn(
103
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
104
+ "The distribution of values may be incorrect.",
105
+ stacklevel=2,
106
+ )
107
+
108
+ with torch.no_grad():
109
+ # Values are generated by using a truncated uniform distribution and
110
+ # then using the inverse CDF for the normal distribution.
111
+ # Get upper and lower cdf values
112
+ l = norm_cdf((a - mean) / std) # noqa: E741
113
+ u = norm_cdf((b - mean) / std)
114
+
115
+ # Uniformly fill tensor with values from [l, u], then translate to
116
+ # [2l-1, 2u-1].
117
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
118
+
119
+ # Use inverse cdf transform for normal distribution to get truncated
120
+ # standard normal
121
+ tensor.erfinv_()
122
+
123
+ # Transform to proper mean, std
124
+ tensor.mul_(std * math.sqrt(2.0))
125
+ tensor.add_(mean)
126
+
127
+ # Clamp to ensure it's in the proper range
128
+ tensor.clamp_(min=a, max=b)
129
+ return tensor
130
+
131
+
132
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
133
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
134
+ r"""The original timm.models.layers.weight_init.trunc_normal_ cannot handle bfloat16 yet, so here we first
135
+ convert the tensor to float32, apply trunc_normal_() in float32, and then convert it back to its original dtype.
136
+ Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
137
+ from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
138
+ with values outside :math:`[a, b]` redrawn until they are within
139
+ the bounds. The method used for generating the random values works
140
+ best when :math:`a \leq \text{mean} \leq b`.
141
+ Args:
142
+ tensor: an n-dimensional `torch.Tensor`
143
+ mean: the mean of the normal distribution
144
+ std: the standard deviation of the normal distribution
145
+ a: the minimum cutoff value
146
+ b: the maximum cutoff value
147
+ Examples:
148
+ >>> w = torch.empty(3, 5)
149
+ >>> nn.init.trunc_normal_(w)
150
+ """
151
+
152
+ with torch.no_grad():
153
+ dtype = tensor.dtype
154
+ tensor_fp32 = tensor.float()
155
+ tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
156
+ tensor_dtype = tensor_fp32.to(dtype=dtype)
157
+ tensor.copy_(tensor_dtype)
158
+
159
+
160
+ def init_weights(self):
161
+ if self.pos_embed is not None:
162
+ trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
163
+ trunc_normal_(self.latent, std=self.latent_dim**-0.5)
164
+
165
+
166
+ def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
167
+ """ViT weight initialization, original timm impl (for reproducibility)"""
168
+ if isinstance(module, nn.Linear):
169
+ trunc_normal_(module.weight, std=0.02)
170
+ if module.bias is not None:
171
+ nn.init.zeros_(module.bias)
172
+ elif hasattr(module, "init_weights"):
173
+ module.init_weights()
174
+
175
+
176
+ class Attention(nn.Module):
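+ # Multi-head attention backed by FlashAttention; when cu_slens is provided, the
+ # varlen kernel attends within each packed (variable-length) image sequence separately.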
177
+ fused_attn: Final[bool]
178
+
179
+ def __init__(
180
+ self,
181
+ dim: int,
182
+ num_heads: int = 8,
183
+ qkv_bias: bool = False,
184
+ qk_norm: bool = False,
185
+ attn_drop: float = 0.0,
186
+ proj_drop: float = 0.0,
187
+ norm_layer: nn.Module = nn.LayerNorm,
188
+ ) -> None:
189
+ super().__init__()
190
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
191
+ self.num_heads = num_heads
192
+ self.head_dim = dim // num_heads
193
+ self.scale = self.head_dim**-0.5
194
+ # self.fused_attn = use_fused_attn()
195
+ self.fused_attn = True
196
+
197
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
198
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
199
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
200
+ self.attn_drop = nn.Dropout(attn_drop)
201
+ self.proj = nn.Linear(dim, dim)
202
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
203
+
204
+ def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor:
205
+ B, N, C = x.shape
206
+ qkv = (
207
+ self.qkv(x)
208
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
209
+ .permute(2, 0, 3, 1, 4)
210
+ )
211
+ q, k, v = qkv.unbind(0)
212
+ q, k = self.q_norm(q), self.k_norm(k)
213
+
214
+ if cu_slens is not None:
215
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
216
+ k = k.permute(0, 2, 1, 3)
217
+ v = v.permute(0, 2, 1, 3)
218
+ max_seqlen = torch.max(cu_slens[1:] - cu_slens[:-1]).item()
219
+ x = flash_attn_varlen_func(
220
+ q.squeeze(0),
221
+ k.squeeze(0),
222
+ v.squeeze(0),
223
+ cu_seqlens_q=cu_slens,
224
+ cu_seqlens_k=cu_slens,
225
+ max_seqlen_q=max_seqlen,
226
+ max_seqlen_k=max_seqlen,
227
+ softmax_scale=self.scale,
228
+ causal=False,
229
+ )
230
+
231
+ x = x.reshape(B, N, -1)
232
+ x = self.proj(x)
233
+ x = self.proj_drop(x)
234
+
235
+ else:
236
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
237
+ k = k.permute(0, 2, 1, 3)
238
+ v = v.permute(0, 2, 1, 3)
239
+ x = flash_attn_func(q, k, v, softmax_scale=self.scale) # -> b, n, h, c
240
+
241
+ x = x.reshape(B, N, -1)
242
+ x = self.proj(x)
243
+ x = self.proj_drop(x)
244
+ # if self.fused_attn:
245
+ # x = F.scaled_dot_product_attention(
246
+ # q,
247
+ # k,
248
+ # v,
249
+ # dropout_p=self.attn_drop.p if self.training else 0.0,
250
+ # )
251
+ # else:
252
+ # q = q * self.scale
253
+ # attn = q @ k.transpose(-2, -1)
254
+ # attn = attn.softmax(dim=-1)
255
+ # attn = self.attn_drop(attn)
256
+ # x = attn @ v
257
+
258
+ # x = x.transpose(1, 2).reshape(B, N, C)
259
+ # x = self.proj(x)
260
+ # x = self.proj_drop(x)
261
+ return x
262
+
263
+
264
+ class LayerScale(nn.Module):
265
+ def __init__(
266
+ self,
267
+ dim: int,
268
+ init_values: float = 1e-5,
269
+ inplace: bool = False,
270
+ ) -> None:
271
+ super().__init__()
272
+ self.inplace = inplace
273
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
274
+
275
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
276
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
277
+
278
+
279
+ class Block(nn.Module):
280
+ def __init__(
281
+ self,
282
+ dim: int,
283
+ num_heads: int,
284
+ mlp_ratio: float = 4.0,
285
+ qkv_bias: bool = False,
286
+ qk_norm: bool = False,
287
+ proj_drop: float = 0.0,
288
+ attn_drop: float = 0.0,
289
+ init_values: Optional[float] = None,
290
+ drop_path: float = 0.0,
291
+ act_layer: nn.Module = nn.GELU,
292
+ norm_layer: nn.Module = nn.LayerNorm,
293
+ mlp_layer: nn.Module = Mlp,
294
+ ) -> None:
295
+ super().__init__()
296
+ self.norm1 = norm_layer(dim)
297
+ self.attn = Attention(
298
+ dim,
299
+ num_heads=num_heads,
300
+ qkv_bias=qkv_bias,
301
+ qk_norm=qk_norm,
302
+ attn_drop=attn_drop,
303
+ proj_drop=proj_drop,
304
+ norm_layer=norm_layer,
305
+ )
306
+ self.ls1 = (
307
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
308
+ )
309
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
310
+
311
+ self.norm2 = norm_layer(dim)
312
+ self.mlp = mlp_layer(
313
+ in_features=dim,
314
+ hidden_features=int(dim * mlp_ratio),
315
+ act_layer=act_layer,
316
+ drop=proj_drop,
317
+ )
318
+ self.ls2 = (
319
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
320
+ )
321
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
322
+
323
+ def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor:
324
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_slens=cu_slens)))
325
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
326
+ return x
327
+
328
+
329
+ class VisionTransformer(nn.Module):
330
+ """Vision Transformer
331
+
332
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
333
+ - https://arxiv.org/abs/2010.11929
334
+ """
335
+
336
+ dynamic_img_size: Final[bool]
337
+
338
+ def __init__(
339
+ self,
340
+ img_size: Union[int, Tuple[int, int]] = 224,
341
+ patch_size: Union[int, Tuple[int, int]] = 16,
342
+ in_chans: int = 3,
343
+ num_classes: int = 1000,
344
+ global_pool: Literal["", "avg", "token", "map"] = "token",
345
+ embed_dim: int = 768,
346
+ depth: int = 12,
347
+ num_heads: int = 12,
348
+ mlp_ratio: float = 4.0,
349
+ qkv_bias: bool = True,
350
+ qk_norm: bool = False,
351
+ init_values: Optional[float] = None,
352
+ class_token: bool = True,
353
+ no_embed_class: bool = False,
354
+ reg_tokens: int = 0,
355
+ pre_norm: bool = False,
356
+ fc_norm: Optional[bool] = None,
357
+ dynamic_img_size: bool = False,
358
+ dynamic_img_pad: bool = False,
359
+ drop_rate: float = 0.0,
360
+ pos_drop_rate: float = 0.0,
361
+ patch_drop_rate: float = 0.0,
362
+ proj_drop_rate: float = 0.0,
363
+ attn_drop_rate: float = 0.0,
364
+ drop_path_rate: float = 0.0,
365
+ weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
366
+ embed_layer: Callable = PatchEmbed,
367
+ norm_layer: Optional[LayerType] = None,
368
+ act_layer: Optional[LayerType] = None,
369
+ strict_img_size: bool = False,
370
+ block_fn: Type[nn.Module] = Block,
371
+ mlp_layer: Type[nn.Module] = Mlp,
372
+ ignore_head: bool = False,
373
+ add_patch2x2: bool = False,
374
+ ) -> None:
375
+ """
376
+ Args:
377
+ img_size: Input image size.
378
+ patch_size: Patch size.
379
+ in_chans: Number of image input channels.
380
+ num_classes: Number of classes for the classification head.
381
+ global_pool: Type of global pooling for final sequence (default: 'token').
382
+ embed_dim: Transformer embedding dimension.
383
+ depth: Depth of transformer.
384
+ num_heads: Number of attention heads.
385
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
386
+ qkv_bias: Enable bias for qkv projections if True.
387
+ init_values: Layer-scale init values (layer-scale enabled if not None).
388
+ class_token: Use class token.
389
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
390
+ reg_tokens: Number of register tokens.
391
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
392
+ drop_rate: Head dropout rate.
393
+ pos_drop_rate: Position embedding dropout rate.
394
+ attn_drop_rate: Attention dropout rate.
395
+ drop_path_rate: Stochastic depth rate.
396
+ weight_init: Weight initialization scheme.
397
+ embed_layer: Patch embedding layer.
398
+ norm_layer: Normalization layer.
399
+ act_layer: MLP activation layer.
400
+ block_fn: Transformer block layer.
401
+ """
402
+ super().__init__()
403
+ assert global_pool in ("", "avg", "token", "map")
404
+ assert class_token or global_pool != "token"
405
+ use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
406
+ # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
407
+ # act_layer = get_act_layer(act_layer) or nn.GELU
408
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
409
+ act_layer = nn.GELU
410
+
411
+ self.num_classes = num_classes
412
+ self.global_pool = global_pool
413
+ self.num_features = self.embed_dim = (
414
+ embed_dim # num_features for consistency with other models
415
+ )
416
+ self.num_prefix_tokens = 1 if class_token else 0
417
+ self.num_prefix_tokens += reg_tokens
418
+ self.num_reg_tokens = reg_tokens
419
+ self.has_class_token = class_token
420
+ self.no_embed_class = (
421
+ no_embed_class # don't embed prefix positions (includes reg)
422
+ )
423
+ self.dynamic_img_size = dynamic_img_size
424
+ self.grad_checkpointing = False
425
+ self.ignore_head = ignore_head
426
+
427
+ embed_args = {}
428
+ if dynamic_img_size:
429
+ # flatten deferred until after pos embed
430
+ embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
431
+ self.patch_embed = embed_layer(
432
+ img_size=img_size,
433
+ patch_size=patch_size,
434
+ in_chans=in_chans,
435
+ embed_dim=embed_dim,
436
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
437
+ dynamic_img_pad=dynamic_img_pad,
438
+ strict_img_size=strict_img_size,
439
+ **embed_args,
440
+ )
441
+ num_patches = self.patch_embed.num_patches
442
+
443
+ self.cls_token = (
444
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
445
+ )
446
+ self.reg_token = (
447
+ nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
448
+ )
449
+ embed_len = (
450
+ num_patches if no_embed_class else num_patches + self.num_prefix_tokens
451
+ )
452
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
453
+
454
+
455
+ # deepspeed.zero.register_external_parameter(self, self.pos_embed)
456
+ # deepspeed.zero.register_external_parameter(self, self.patch_embed.proj.weight)
457
+ # deepspeed.zero.register_external_parameter(self, self.patch_embed.proj.bias)
458
+ # print(self.patch_embed.state_dict().keys())
459
+
460
+
461
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
462
+ if patch_drop_rate > 0:
463
+ self.patch_drop = PatchDropout(
464
+ patch_drop_rate,
465
+ num_prefix_tokens=self.num_prefix_tokens,
466
+ )
467
+ else:
468
+ self.patch_drop = nn.Identity()
469
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
470
+
471
+ dpr = [
472
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
473
+ ] # stochastic depth decay rule
474
+ self.blocks = nn.Sequential(
475
+ *[
476
+ block_fn(
477
+ dim=embed_dim,
478
+ num_heads=num_heads,
479
+ mlp_ratio=mlp_ratio,
480
+ qkv_bias=qkv_bias,
481
+ qk_norm=qk_norm,
482
+ init_values=init_values,
483
+ proj_drop=proj_drop_rate,
484
+ attn_drop=attn_drop_rate,
485
+ drop_path=dpr[i],
486
+ norm_layer=norm_layer,
487
+ act_layer=act_layer,
488
+ mlp_layer=mlp_layer,
489
+ )
490
+ for i in range(depth)
491
+ ]
492
+ )
493
+
494
+
495
+ if add_patch2x2:
496
+ if add_patch2x2 == 'v2':
497
+ self.downsample = nn.Sequential(
498
+ nn.Conv2d(embed_dim, embed_dim*2, kernel_size=2, stride=2),
499
+ nn.GELU(),
500
+ nn.Conv2d(embed_dim*2, embed_dim*4, 1)
501
+ )
502
+ else:
503
+ mid_dim = embed_dim * 2
504
+ self.downsample = nn.Sequential(
505
+ nn.Conv2d(embed_dim, mid_dim, kernel_size=2, stride=2),
506
+ nn.GELU(),
507
+ nn.Conv2d(mid_dim, mid_dim, 1)
508
+ )
509
+
510
+ else:
511
+ self.downsample = None
512
+
513
+
514
+ # self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
515
+
516
+ # # Classifier Head
517
+ # if global_pool == "map":
518
+ # AttentionPoolLatent.init_weights = init_weights
519
+ # self.attn_pool = AttentionPoolLatent(
520
+ # self.embed_dim,
521
+ # num_heads=num_heads,
522
+ # mlp_ratio=mlp_ratio,
523
+ # norm_layer=norm_layer,
524
+ # )
525
+ # else:
526
+ # self.attn_pool = None
527
+ # self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
528
+ # self.head_drop = nn.Dropout(drop_rate)
529
+ # self.head = (
530
+ # nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
531
+ # )
532
+
533
+ # if weight_init != "skip":
534
+ # self.init_weights(weight_init)
535
+
536
+ def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
537
+ assert mode in ("jax", "jax_nlhb", "moco", "")
538
+ # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
539
+ trunc_normal_(self.pos_embed, std=0.02)
540
+ if self.cls_token is not None:
541
+ nn.init.normal_(self.cls_token, std=1e-6)
542
+ named_apply(init_weights_vit_timm, self)
543
+
544
+ @torch.jit.ignore
545
+ def no_weight_decay(self) -> Set:
546
+ return {"pos_embed", "cls_token", "dist_token"}
547
+
548
+ @torch.jit.ignore
549
+ def group_matcher(self, coarse: bool = False) -> Dict:
550
+ return dict(
551
+ stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
552
+ blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
553
+ )
554
+
555
+ @torch.jit.ignore
556
+ def set_grad_checkpointing(self, enable: bool = True) -> None:
557
+ self.grad_checkpointing = enable
558
+
559
+ @torch.jit.ignore
560
+ def get_classifier(self) -> nn.Module:
561
+ return self.head
562
+
563
+ def reset_classifier(self, num_classes: int, global_pool=None) -> None:
564
+ self.num_classes = num_classes
565
+ if global_pool is not None:
566
+ assert global_pool in ("", "avg", "token", "map")
567
+ if global_pool == "map" and self.attn_pool is None:
568
+ assert (
569
+ False
570
+ ), "Cannot currently add attention pooling in reset_classifier()."
571
+ elif global_pool != "map" and self.attn_pool is not None:
572
+ self.attn_pool = None # remove attention pooling
573
+ self.global_pool = global_pool
574
+ self.head = (
575
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
576
+ )
577
+
578
+ def rescale_positional_embedding(self, out_size):
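+ # Bilinearly interpolate the square positional-embedding grid to the (h, w)
+ # token grid of the current input; returns the original embedding if sizes match.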
579
+ h, w = out_size
580
+ pos_embed_shape = int((self.pos_embed.shape[1]) ** 0.5)
581
+ if (h, w) == (pos_embed_shape, pos_embed_shape):
582
+ return self.pos_embed
583
+ rescaled_positional_embedding = \
584
+ self.pos_embed.new_zeros(1, h*w, self.pos_embed.shape[2])
585
+ pe_2d = self.pos_embed[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape)
586
+ pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w)
587
+ rescaled_positional_embedding[0] = pe_2d.T.contiguous()
588
+ return rescaled_positional_embedding
589
+
590
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
591
+ if self.dynamic_img_size:
592
+ B, H, W, C = x.shape
593
+ pos_embed = resample_abs_pos_embed(
594
+ self.pos_embed,
595
+ (H, W),
596
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
597
+ )
598
+ x = x.view(B, -1, C)
599
+ else:
600
+ pos_embed = self.pos_embed
601
+
602
+ to_cat = []
603
+ if self.cls_token is not None:
604
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
605
+ if self.reg_token is not None:
606
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
607
+
608
+ if self.no_embed_class:
609
+ # deit-3, updated JAX (big vision)
610
+ # position embedding does not overlap with class token, add then concat
611
+ x = x + pos_embed
612
+ if to_cat:
613
+ x = torch.cat(to_cat + [x], dim=1)
614
+ else:
615
+ # original timm, JAX, and deit vit impl
616
+ # pos_embed has entry for class token, concat then add
617
+ if to_cat:
618
+ x = torch.cat(to_cat + [x], dim=1)
619
+ x = x + pos_embed
620
+
621
+ return self.pos_drop(x)
622
+
623
+ def _intermediate_layers(
624
+ self,
625
+ x: torch.Tensor,
626
+ n: Union[int, Sequence] = 1,
627
+ ) -> List[torch.Tensor]:
628
+ outputs, num_blocks = [], len(self.blocks)
629
+ take_indices = set(
630
+ range(num_blocks - n, num_blocks) if isinstance(n, int) else n
631
+ )
632
+
633
+ # forward pass
634
+ x = self.patch_embed(x)
635
+ x = self._pos_embed(x)
636
+ x = self.patch_drop(x)
637
+ x = self.norm_pre(x)
638
+ for i, blk in enumerate(self.blocks):
639
+ x = blk(x)
640
+ if i in take_indices:
641
+ outputs.append(x)
642
+
643
+ return outputs
644
+
645
+ def get_intermediate_layers(
646
+ self,
647
+ x: torch.Tensor,
648
+ n: Union[int, Sequence] = 1,
649
+ reshape: bool = False,
650
+ return_prefix_tokens: bool = False,
651
+ norm: bool = False,
652
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
653
+ """Intermediate layer accessor (NOTE: This is a WIP experiment).
654
+ Inspired by DINO / DINOv2 interface
655
+ """
656
+ # take the last n blocks if n is an int; if n is a sequence, select blocks by matching indices
657
+ outputs = self._intermediate_layers(x, n)
658
+ if norm:
659
+ outputs = [self.norm(out) for out in outputs]
660
+ prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
661
+ outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
662
+
663
+ if reshape:
664
+ grid_size = self.patch_embed.grid_size
665
+ outputs = [
666
+ out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
667
+ .permute(0, 3, 1, 2)
668
+ .contiguous()
669
+ for out in outputs
670
+ ]
671
+
672
+ if return_prefix_tokens:
673
+ return tuple(zip(outputs, prefix_tokens))
674
+ return tuple(outputs)
675
+
676
+ def forward_features_list(self, x_list):
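+ # Variable-resolution batching: pad each image to a multiple of the patch size,
+ # concatenate all token sequences along the sequence dimension, and attend with
+ # varlen flash attention (cu_seqlens) so images of different sizes share one pass.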
677
+ x_all = []
678
+ image_sizes = []
679
+ for x in x_list:
680
+ if EVAL_72B:
681
+ x = x.to('cuda:0')
682
+ bs, _, h, w = x.shape
683
+
684
+ # fix patch size=14 in datasets
685
+ pad_h = (self.patch_embed.patch_size[0] - h % self.patch_embed.patch_size[0]) % self.patch_embed.patch_size[0]
686
+ pad_w = (self.patch_embed.patch_size[1] - w % self.patch_embed.patch_size[1]) % self.patch_embed.patch_size[1]
687
+ x = F.pad(x, (0, pad_w, 0, pad_h))
688
+
689
+ bs, _, h, w = x.shape
690
+
691
+ h = h // self.patch_embed.patch_size[0]
692
+ w = w // self.patch_embed.patch_size[1]
693
+
694
+ x = self.patch_embed(x)
695
+ # x = self._pos_embed(x)
696
+ x = x + self.rescale_positional_embedding(out_size=(h, w))
697
+ x = self.patch_drop(x)
698
+ x = self.norm_pre(x)
699
+ x_all.append(x)
700
+ image_sizes.append((h, w))
701
+
702
+ slen = [xi.size(1) for xi in x_all]
703
+ x = torch.cat(x_all, dim=1)
704
+
705
+ cu_indices = [0, ]
706
+ for i in slen:
707
+ cu_indices.append(cu_indices[-1] + i)
708
+
709
+ cu_slens = torch.tensor(cu_indices, dtype=torch.int32).to(x.device)
710
+ for idx, blk in enumerate(self.blocks):
711
+ if self.grad_checkpointing and not torch.jit.is_scripting():
712
+ x = checkpoint(blk, x, cu_slens, use_reentrant=True)
713
+ else:
714
+ x = blk(x, cu_slens=cu_slens)
715
+ feats = x.split(slen, dim=1) #[(1, slen, c)]
716
+
717
+ if self.downsample is not None:
718
+ new_feats = []
719
+ new_sizes = []
720
+ for f, s in zip(feats, image_sizes):
721
+ h, w = s
722
+ b, n, c = f.size()
723
+ f = f.reshape(b, h, w, c).permute(0, 3, 1, 2)
724
+ f = self.downsample(f)
725
+ b, c, h, w = f.size()
726
+ f = f.permute(0, 2, 3, 1).reshape(b, h*w, c)
727
+ new_feats.append(f)
728
+ new_sizes.append((h, w))
729
+ return new_feats, new_sizes
730
+
731
+
732
+ return feats, image_sizes
733
+
734
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
735
+ if EVAL_72B:
736
+ x = x.to('cuda:0')
737
+ bs, _, h, w = x.shape
738
+ h = h // self.patch_embed.patch_size[0]
739
+ w = w // self.patch_embed.patch_size[1]
740
+
741
+ x = self.patch_embed(x)
742
+ # x = self._pos_embed(x)
743
+ x = x + self.rescale_positional_embedding(out_size=(h, w))
744
+ x = self.patch_drop(x)
745
+ x = self.norm_pre(x)
746
+ if self.grad_checkpointing and not torch.jit.is_scripting():
747
+ x = checkpoint_seq(self.blocks, x)
748
+ else:
749
+ x = self.blocks(x)
750
+
751
+ if self.downsample is not None:
752
+ b, n, c = x.size()
753
+ x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
754
+ x = self.downsample(x)
755
+ b, c, h, w = x.size()
756
+ x = x.permute(0, 2, 3, 1).reshape(b, h*w, c)
757
+ new_feats = x
758
+ new_sizes = (h, w)
759
+ return new_feats, new_sizes
760
+
761
+ return x, (h, w)
762
+
763
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
764
+ x = self.norm(x)
765
+ if self.attn_pool is not None:
766
+ x = self.attn_pool(x)
767
+ elif self.global_pool == "avg":
768
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
769
+ elif self.global_pool:
770
+ x = x[:, 0] # class token
771
+ x = self.fc_norm(x)
772
+ x = self.head_drop(x)
773
+ return x if pre_logits else self.head(x)
774
+
775
+ def forward(self, x, cal_attn_pool=False):
776
+ if type(x) is list:
777
+ x, image_sizes = self.forward_features_list(x)
778
+ return x, image_sizes, None
779
+ else:
780
+ x, image_sizes = self.forward_features(x)
781
+ return x, image_sizes, None
782
+
783
+ @dataclass
784
+ class SigLIPVisionCfg:
785
+ width: int = 1152
786
+ layers: Union[Tuple[int, int, int, int], int] = 27
787
+ heads: int = 16
788
+ patch_size: int = 14
789
+ image_size: Union[Tuple[int, int], int] = 336
790
+ global_pool: str = "map"
791
+ mlp_ratio: float = 3.7362
792
+ class_token: bool = False
793
+ num_classes: int = 0
794
+ use_checkpoint: bool = False
795
+
796
+
797
+ SigLIP_MODEL_CONFIG = {
798
+ "siglip_so400m_patch14_384": {
799
+ "image_size": 384,
800
+ "patch_size": 14,
801
+ "width": 1152,
802
+ "layers": 27,
803
+ "heads": 16,
804
+ "mlp_ratio": 3.7362,
805
+ "global_pool": "map",
806
+ "use_checkpoint": False,
807
+ },
808
+ "siglip_so400m_patch16_384": {
809
+ "image_size": 384,
810
+ "patch_size": 16,
811
+ "width": 1152,
812
+ "layers": 27,
813
+ "heads": 16,
814
+ "mlp_ratio": 3.7362,
815
+ "global_pool": "map",
816
+ "use_checkpoint": False,
817
+ },
818
+ "siglip_so400m_patch14_224": {
819
+ "image_size": 224,
820
+ "patch_size": 14,
821
+ "width": 1152,
822
+ "layers": 27,
823
+ "heads": 16,
824
+ "mlp_ratio": 3.7362,
825
+ "global_pool": "map",
826
+ "use_checkpoint": False,
827
+ },
828
+ "siglip_large_patch16_384": {
829
+ "image_size": 384,
830
+ "patch_size": 16,
831
+ "width": 1024,
832
+ "layers": 24,
833
+ "heads": 16,
834
+ "mlp_ratio": 4,
835
+ "global_pool": "map",
836
+ "use_checkpoint": False,
837
+ },
838
+ }
839
+
840
+
841
+ def resize_evaclip_pos_embed(model: VisionTransformer, interpolation: str = 'bicubic'):
842
+ # interpolate position embedding
843
+ orig_size = 24
844
+ new_size = 128
845
+ pos_tokens = model.pos_embed
846
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, model.embed_dim).permute(0, 3, 1, 2)
847
+ pos_tokens = torch.nn.functional.interpolate(
848
+ pos_tokens, size=(new_size, new_size), mode=interpolation, align_corners=False)
849
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
850
+ model.pos_embed = nn.Parameter(pos_tokens, requires_grad=True)
851
+ return model
852
+
853
+ def create_siglip_vit(
854
+ model_name: str = "siglip_so400m_patch14_384",
855
+ image_size: int = 384,
856
+ select_layer: int = -1,
857
+ path: str = "",
858
+ gradient_checkpointing: bool = False,
859
+ **kwargs,
860
+ ):
861
+ assert (
862
+ model_name in SigLIP_MODEL_CONFIG.keys()
863
+ ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
864
+
865
+ vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
866
+
867
+ if select_layer <= 0:
868
+ layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
869
+ else:
870
+ layers = min(vision_cfg.layers, select_layer)
871
+
872
+
873
+
874
+ if 'patch2x2' in path or 'patch4x4' in path:
875
+ add_patch2x2 = True
876
+ else:
877
+ add_patch2x2 = False
878
+
879
+ if 'patch4x4pool' in path or 'patch2x2from4x4' in path:
880
+ add_patch2x2 = 'v2'
881
+
882
+ if FORCE_NO_DOWNSAMPLE:
883
+ add_patch2x2 = False
884
+
885
+ model = VisionTransformer(
886
+ img_size=2048,
887
+ patch_size=16,
888
+ embed_dim=vision_cfg.width,
889
+ depth=layers,
890
+ num_heads=vision_cfg.heads,
891
+ mlp_ratio=vision_cfg.mlp_ratio,
892
+ class_token=vision_cfg.class_token,
893
+ global_pool=vision_cfg.global_pool,
894
+ dynamic_img_pad=False,
895
+ strict_img_size=False,
896
+ ignore_head=kwargs.get("ignore_head", False),
897
+ weight_init=kwargs.get("weight_init", "skip"),
898
+ num_classes=0,
899
+ add_patch2x2=add_patch2x2
900
+ )
901
+
902
+ if not SKIP_LOAD_VIT:
903
+ if path is not None and os.path.exists(path):
904
+ ckpt = path
905
+ else:
906
+ raise ValueError(f"Model checkpoint not found at {path}")
907
+ state_dict = torch.load(ckpt, map_location="cpu")
908
+ print('loading vision backbone from', path)
909
+
910
+ if 'genli' in path:
911
+ new_sd = {}
912
+ for k in state_dict.keys():
913
+ if k.startswith('base_model.model.model.vision_tower.vision_tower.'):
914
+ new_k = k.replace('base_model.model.model.vision_tower.vision_tower.', '')
915
+ new_sd[new_k] = state_dict[k]
916
+
917
+ if add_patch2x2:
918
+ if k.startswith('base_model.model.model.mm_projector.proj'):
919
+ new_k = k.replace('base_model.model.model.mm_projector.proj', 'downsample')
920
+ new_sd[new_k] = state_dict[k]
921
+
922
+ elif 'distill' in path:
923
+ new_sd = {}
924
+ state_dict = state_dict['model']
925
+ for k in state_dict.keys():
926
+ if k.startswith('vision_tower.'):
927
+ new_k = k.replace('vision_tower.', '')
928
+ new_sd[new_k] = state_dict[k]
929
+ else:
930
+ raise NotImplementedError
931
+ msg = model.load_state_dict(new_sd, strict=False)
932
+ print(msg)
933
+
934
+ else:
935
+ print("#### Skip loading vision backbone")
936
+
937
+ if gradient_checkpointing:
938
+ model.set_grad_checkpointing(True)
939
+ return model
940
+
941
+ from transformers import CLIPImageProcessor
942
+ import torch.distributed as dist
943
+
944
+ class SigLIPViTAnysizeWrapper(nn.Module):
945
+ def __init__(self, vision_tower, path, args, delay_load=False):
946
+ super().__init__()
947
+
948
+ self.is_loaded = False
949
+
950
+ self.vision_tower_name = vision_tower
951
+ self.args = args
952
+ self.path = path
953
+
954
+ self.select_layer = -1
955
+ if self.select_layer < -1: self.select_layer += 1
956
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
957
+
958
+ self.output_dim = 1152
959
+ if not FORCE_NO_DOWNSAMPLE:
960
+ if 'patch2x2' in path or 'patch4x4' in path:
961
+ self.output_dim = 1152*2
962
+
963
+ if 'patch4x4pool' in path or 'patch2x2from4x4' in path:
964
+ self.output_dim = 1152*4
965
+
966
+ if not delay_load or LOAD_VISION_EARLY:
967
+ self.load_model()
968
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
969
+ # TODO: better detector is needed.
970
+ print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
971
+ self.load_model()
972
+
973
+ def load_model(self, device_map=None):
974
+ if self.is_loaded:
975
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
976
+ return
977
+
978
+ self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
979
+ if self.args.mm_projector_type == "conv_mlp" or self.args.mm_projector_type == "multipath_conv_mlp" or self.args.mm_projector_type == "multipath_conv_mlp_woconv":
980
+ self.image_processor.crop_size['height'] = 384
981
+ self.image_processor.crop_size['width'] = 384
982
+ self.image_processor.size['shortest_edge'] = 384
983
+ print("Resizing CLIP image processor to 384...")
984
+ self.image_processor.image_mean = [0.5, 0.5, 0.5]
985
+ self.image_processor.image_std = [0.5, 0.5, 0.5]
986
+ print("Loading vision model...")
987
+ if VIT_WITH_GRAD:
988
+ self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384',
989
+ gradient_checkpointing=True)
990
+ self.vision_tower.train()
991
+ else:
992
+ self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384',
993
+ gradient_checkpointing=False)
994
+ for p in self.vision_tower.parameters():
995
+ p.requires_grad = False
996
+ self.vision_tower.eval()
997
+ self.is_loaded = True
998
+
999
+ def train(self, mode = True):
1000
+ self.training = mode
1001
+
1002
+ if self.is_loaded and not VIT_WITH_GRAD:
1003
+ self.vision_tower.eval()
1004
+
1005
+ def split_images(self, images, split_res=512, base_size=32):
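+ # Tile images larger than split_res**2 into sub-images aligned to base_size,
+ # recording crop coordinates (in base_size units) so features can be reassembled.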
1006
+ split_images = []
1007
+ sub_images_info = []
1008
+ for image in images:
1009
+ now_sub_images = []
1010
+ _, c, h, w = image.shape
1011
+ if h * w <= split_res * split_res:
1012
+ split_images.append(image)
1013
+ sub_images_info.append(
1014
+ (
1015
+ 1, 1, 1, h // base_size, w // base_size, [(0, h // base_size, 0, w // base_size)]
1016
+ )
1017
+ )
1018
+ continue
1019
+ nsplit_h = math.ceil(h / split_res)
1020
+ nsplit_w = math.ceil(w / split_res)
1021
+ sub_h = int(h / nsplit_h / base_size) * base_size
1022
+ sub_w = int(w / nsplit_w / base_size) * base_size
1023
+ crop_infos = []
1024
+ for i in range(nsplit_h):
1025
+ for j in range(nsplit_w):
1026
+ begin_h = i * sub_h
1027
+ begin_w = j * sub_w
1028
+
1029
+ if i == nsplit_h - 1:
1030
+ end_h = h
1031
+ else:
1032
+ end_h = (i + 1) * sub_h
1033
+
1034
+ if j == nsplit_w - 1:
1035
+ end_w = w
1036
+ else:
1037
+ end_w = (j + 1) * sub_w
1038
+
1039
+ assert (end_h - begin_h) % base_size == 0 and (end_w - begin_w) % base_size == 0
1040
+
1041
+ sub_image = image[:, :, begin_h:end_h, begin_w:end_w]
1042
+ now_sub_images.append(sub_image)
1043
+ crop_infos.append(
1044
+ (begin_h // base_size, end_h // base_size, begin_w // base_size, end_w // base_size)
1045
+ )
1046
+
1047
+ split_images += now_sub_images
1048
+ sub_images_info.append(
1049
+ (
1050
+ len(now_sub_images), nsplit_h, nsplit_w, h // base_size, w // base_size, crop_infos
1051
+ )
1052
+ )
1053
+
1054
+ return split_images, sub_images_info
1055
+
1056
+
1057
+ def unsplit_images(self, features, sizes, sub_images_info):
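+ # Scatter per-tile features back into a single (total_h, total_w) feature map
+ # using the crop coordinates recorded by split_images.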
1058
+ new_features = []
1059
+ for feature, size in zip(features, sizes):
1060
+ h, w = size
1061
+ new_features.append(
1062
+ feature.reshape(1, h, w, -1)
1063
+ )
1064
+
1065
+ fused_images = []
1066
+ images_sizes = []
1067
+ sub_count = 0
1068
+ for n_split, nsplit_h, nsplit_w, total_h, total_w, crop_infos in sub_images_info:
1069
+ sub_features = new_features[sub_count:sub_count+n_split]
1070
+ sub_count += n_split
1071
+
1072
+ total_feature = new_features[0].new_zeros(1, total_h, total_w, self.hidden_size)
1073
+ for feature, (begin_h, end_h, begin_w, end_w) in zip(sub_features, crop_infos):
1074
+ total_feature[:, begin_h:end_h, begin_w:end_w] += feature
1075
+
1076
+ fused_images.append(total_feature.reshape(1, total_h * total_w, self.hidden_size))
1077
+ images_sizes.append((total_h, total_w))
1078
+
1079
+ return fused_images, images_sizes
1080
+
1081
+
1082
+
1083
+ def forward_func(self, images, force_fix_size=False, cal_attn_pool=False):
1084
+ if type(images) is list:
1085
+ xs = [x.to(self.dtype) for x in images]
1086
+ image_features, img_size, cls_token = self.vision_tower(xs, cal_attn_pool=cal_attn_pool)
1087
+ image_features = [x.to(images[0].dtype) for x in image_features]
1088
+
1089
+ else:
1090
+ image_forward_outs, img_size, cls_token = self.vision_tower(images.to(self.dtype), cal_attn_pool=cal_attn_pool)
1091
+ image_features = image_forward_outs.to(images.dtype)
1092
+
1093
+ return image_features, img_size, cls_token
1094
+
1095
+ def forward(self, images, cal_attn_pool=False):
1096
+ if VIT_WITH_GRAD:
1097
+ image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool)
1098
+ return image_features, img_size
1099
+ else:
1100
+ with torch.no_grad():
1101
+ image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool)
1102
+ return image_features, img_size
1103
+
1104
+
1105
+ @property
1106
+ def dummy_feature(self):
1107
+ return torch.zeros(1, 1152, device=self.device, dtype=self.dtype)
1108
+
1109
+ @property
1110
+ def dtype(self):
1111
+ return self.vision_tower.pos_embed.dtype
1112
+
1113
+ @property
1114
+ def device(self):
1115
+ return self.vision_tower.pos_embed.device
1116
+
1117
+ @property
1118
+ def hidden_size(self):
1119
+ return self.output_dim
1120
+
1121
+ @property
1122
+ def config(self):
1123
+ return type('LLaVAConfigWrapper', (), {
1124
+ # 'image_size': 224,
1125
+ 'patch_size': 16,
1126
+ })()
ola/model/multimodal_projector/__pycache__/builder.cpython-310.pyc ADDED
Binary file (4.61 kB). View file
 
ola/model/multimodal_projector/__pycache__/builder.cpython-38.pyc ADDED
Binary file (4.62 kB). View file
 
ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc ADDED
Binary file (2.76 kB). View file
 
ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-38.pyc ADDED
Binary file (2.78 kB). View file
 
ola/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,179 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+ import math
6
+
7
+ from .pooler_projector import NormalizedDwPooler
8
+ import os
9
+ import math
10
+
11
+
12
+ if 'REGIONAL_POOL' in os.environ:
13
+ REGIONAL_POOL = os.environ['REGIONAL_POOL']
14
+ else:
15
+ REGIONAL_POOL = '2x'
16
+ print(f"REGIONAL_POOL is set as {REGIONAL_POOL}")
17
+
18
+ class IdentityMap(nn.Module):
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ def forward(self, x, *args, **kwargs):
23
+ return x
24
+
25
+ @property
26
+ def config(self):
27
+ return {"mm_projector_type": 'identity'}
28
+
29
+
30
+ class SimpleResBlock(nn.Module):
31
+ def __init__(self, channels):
32
+ super().__init__()
33
+ self.pre_norm = nn.LayerNorm(channels)
34
+
35
+ self.proj = nn.Sequential(
36
+ nn.Linear(channels, channels),
37
+ nn.GELU(),
38
+ nn.Linear(channels, channels)
39
+ )
40
+ def forward(self, x):
41
+ x = self.pre_norm(x)
42
+ return x + self.proj(x)
43
+
44
+ class OlaMLP(nn.Module):
45
+ def __init__(self, in_channels, out_channels, twoview=False):
46
+ super().__init__()
47
+
48
+ self.proj1 = nn.Linear(in_channels, out_channels)
49
+ self.proj2 = nn.Linear(out_channels, out_channels)
50
+ self.act = nn.GELU()
51
+ self.pooler = NormalizedDwPooler(out_channels)
52
+
53
+ embed_std = 1 / math.sqrt(out_channels)
54
+ self.image_newline = nn.Parameter(
55
+ torch.randn(out_channels) * embed_std
56
+ )
57
+ self.image_begin = nn.Parameter(
58
+ torch.randn(out_channels) * embed_std
59
+ )
60
+ self.image_end = nn.Parameter(
61
+ torch.randn(out_channels) * embed_std
62
+ )
63
+
64
+ if twoview:
65
+ self.image_sep = nn.Parameter(
66
+ torch.randn(out_channels) * embed_std
67
+ )
68
+
69
+ def forward(self, x, size=(16,16), x2=None, size2=(16, 16), modalities='image'):
70
+
71
+ if modalities in ['image', 'text']:
72
+ h, w = size
73
+ dtype = x.dtype
74
+ x = x.reshape(x.shape[0], h, w, -1)
75
+ x = self.proj1(x)
76
+ x = self.pooler(x, forward_type=REGIONAL_POOL)
77
+ x = self.act(x)
78
+ x = self.proj2(x)
79
+
80
+
81
+ b, h, w, c = x.shape
82
+ x = torch.cat([
83
+ x,
84
+ self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
85
+ ], dim=2)
86
+ x = x.reshape(b, -1, c)
87
+
88
+ if x2 is not None:
89
+ h2, w2 = size2
90
+ x2 = x2.reshape(x2.shape[0], h2, w2, -1)
91
+ x2 = self.proj1(x2)
92
+ x2 = self.pooler(x2, forward_type=REGIONAL_POOL)
93
+ x2 = self.act(x2)
94
+ x2 = self.proj2(x2)
95
+
96
+ b2, h2, w2, c2 = x2.shape
97
+ x2 = torch.cat([
98
+ x2,
99
+ self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
100
+ ], dim=2)
101
+ x2 = x2.reshape(b, -1, c)
102
+ sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
103
+ x = torch.cat([x, sep, x2], dim=1)
104
+
105
+ begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
106
+ end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
107
+ x = torch.cat([begin, x, end], dim=1)
108
+ return x
109
+ elif modalities in ['video']:
110
+ # x2 is the true feature, ignore x
111
+ h, w = size
112
+ dtype = x.dtype
113
+ x = x.reshape(x.shape[0], h, w, -1)
114
+ x1 = self.proj1(x)
115
+ x1 = self.pooler(x1, forward_type=REGIONAL_POOL)
116
+ x1 = self.proj2(x1).mean() * 0.0
117
+
118
+ h2, w2 = size2
119
+ x2 = x2.reshape(x2.shape[0], h2, w2, -1)
120
+ x2 = self.proj1(x2)
121
+ x2 = self.pooler(x2, forward_type=REGIONAL_POOL)
122
+ x2 = self.act(x2)
123
+ x2 = self.proj2(x2)
124
+
125
+ b2, h2, w2, c = x2.shape
126
+ x2 = torch.cat([
127
+ x2,
128
+ self.image_newline.reshape(1, 1, 1, c).expand(b2, h2, 1, c).to(dtype)
129
+ ], dim=2)
130
+
131
+ x2 = x2.reshape(b2, -1, c)
132
+
133
+ sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, c).to(dtype)
134
+ x2 = torch.cat([x2, sep], dim=1)
135
+
136
+ x2 = x2.flatten(0, 1)
137
+
138
+ begin = self.image_begin.reshape(1, -1).expand(1, c).to(dtype)
139
+ end = self.image_end.reshape(1, -1).expand(1, c).to(dtype)
140
+ x2 = torch.cat([begin, x2, end], dim=0)
141
+ x2 = x2.unsqueeze(0)
142
+ return x2
143
+ else:
144
+ raise ValueError(f'Unknown modalities: {modalities}')
145
+
146
+ def build_vision_projector(config, delay_load=False, **kwargs):
147
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
148
+
149
+ if projector_type == 'linear':
150
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
151
+
152
+ elif projector_type == 'ola_mlp':
153
+ return OlaMLP(config.mm_hidden_size, config.hidden_size, twoview=True)
154
+
155
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
156
+ if mlp_gelu_match:
157
+ mlp_depth = int(mlp_gelu_match.group(1))
158
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
159
+ for _ in range(1, mlp_depth):
160
+ modules.append(nn.GELU())
161
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
162
+ return nn.Sequential(*modules)
163
+
164
+ mlp_gelu_resnet_match = re.match(r'^mlp(\d+)x_res(\d+)x_gelu$', projector_type)
165
+ if mlp_gelu_resnet_match:
166
+ mlp_depth = int(mlp_gelu_resnet_match.group(1))
167
+ res_depth = int(mlp_gelu_resnet_match.group(2))
168
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
169
+ for _ in range(1, mlp_depth):
170
+ modules.append(nn.GELU())
171
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
172
+ for _ in range(res_depth):
173
+ modules.append(SimpleResBlock(config.hidden_size))
174
+ return nn.Sequential(*modules)
175
+
176
+ if projector_type == 'identity':
177
+ return IdentityMap()
178
+
179
+ raise ValueError(f'Unknown projector type: {projector_type}')
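A hedged usage sketch for `build_vision_projector` with the `ola_mlp` projector type. It assumes the `ola` package from this Space is importable (repo root on `PYTHONPATH`, `torch` and `transformers` installed); the `SimpleNamespace` config fields and the dimensions are illustrative placeholders, not values read from the Space's config.

```python
from types import SimpleNamespace
import torch

from ola.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(
    mm_projector_type='ola_mlp',
    mm_hidden_size=1152,   # vision-tower feature dim (assumed)
    hidden_size=3584,      # LLM hidden dim (assumed)
)
projector = build_vision_projector(cfg)

feats = torch.randn(1, 16 * 16, cfg.mm_hidden_size)     # one 16x16 feature grid
tokens = projector(feats, size=(16, 16), modalities='image')
print(tokens.shape)   # torch.Size([1, 74, 3584]) with the default '2x' pooling
```

With the default `REGIONAL_POOL='2x'`, the 16x16 grid is pooled to 8x8, a learned newline token is appended to each of the 8 rows, and begin/end tokens wrap the sequence, giving 8 * 9 + 2 = 74 tokens.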
ola/model/multimodal_projector/pooler_projector.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ from transformers.models.clip.modeling_clip import CLIPVisionModel
+ import os
+
+ if 'NORMALIZE_POOL' in os.environ:
+     NORMALIZE_POOL = bool(int(os.environ['NORMALIZE_POOL']))
+     print(f'NORMALIZE_POOL: {NORMALIZE_POOL}')
+ else:
+     NORMALIZE_POOL = True
+
+
+ class PoolerProjector(nn.Module):
+     def __init__(self, config, vision_cfg):
+         super().__init__()
+         self._config = config
+         self.hw = vision_cfg.image_size // vision_cfg.patch_size
+
+         self.conv_pool = nn.Conv2d(
+             config.mm_hidden_size, config.hidden_size,
+             kernel_size=2, stride=2
+         )
+
+         self.proj = nn.Sequential(
+             nn.GELU(),
+             nn.Linear(config.hidden_size, config.hidden_size),
+         )
+
+     def forward(self, x, *args, **kwargs):
+         height = width = self.hw
+         assert height * width == x.shape[1]
+         x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
+         x = self.conv_pool(x)
+         x = x.flatten(2).transpose(1, 2)
+         x = self.proj(x)
+         return x
+
+     @property
+     def config(self):
+         return {"mm_projector_type": 'pooler'}
+
+
+ class NormalizedDwPooler(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+         self.predictor = nn.Sequential(
+             nn.Linear(dim*2, dim),
+             nn.GELU(),
+             nn.Linear(dim, dim),
+         )
+
+     def forward(self, x, forward_type='2x'):
+         B, H, W, C = x.shape
+
+         if forward_type == '2x':
+             new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
+             pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
+             fused_x = torch.cat([new_x, pooled_x], dim=-1)
+         elif forward_type == '1x':
+             new_x = x.reshape(B, H, W, 1, C)
+             fused_x = torch.cat([new_x, new_x], dim=-1)
+         elif forward_type == '4x':
+             new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
+             pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
+             fused_x = torch.cat([new_x, pooled_x], dim=-1)
+
+         score = self.predictor(fused_x)
+         normalized_score = F.softmax(score, dim=-2)
+         new_x = (new_x * normalized_score).sum(dim=-2)
+         return new_x
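A small shape check for `NormalizedDwPooler`, again assuming the `ola` package is importable: `'2x'` merges each 2x2 window of a `(B, H, W, C)` grid into one vector via a learned softmax over the four positions, halving the spatial resolution while keeping the channel dimension. The dimensions below are illustrative.

```python
import torch
from ola.model.multimodal_projector.pooler_projector import NormalizedDwPooler

pooler = NormalizedDwPooler(dim=64)
x = torch.randn(2, 16, 16, 64)               # (B, H, W, C), H and W divisible by 2
print(pooler(x, forward_type='2x').shape)    # torch.Size([2, 8, 8, 64])
print(pooler(x, forward_type='1x').shape)    # torch.Size([2, 16, 16, 64]) (no pooling)
```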
ola/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc ADDED
Binary file (1.18 kB)
 
ola/model/multimodal_resampler/__pycache__/builder.cpython-38.pyc ADDED
Binary file (1.18 kB)
 
ola/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc ADDED
Binary file (2.81 kB)
 
ola/model/multimodal_resampler/__pycache__/perceiver.cpython-38.pyc ADDED
Binary file (2.83 kB)
 
ola/model/multimodal_resampler/builder.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+
+ from .perceiver import DynamicCompressor
+
+ class IdentityMap(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, *args, **kwargs):
+         return x
+
+     @property
+     def config(self):
+         return {"mm_resampler_type": None}
+
+ def build_vision_resampler(model_args, delay_load=False, **kwargs):
+     # import pdb;pdb.set_trace()
+     resampler_type = getattr(model_args, 'mm_resampler_type', None)
+     if resampler_type == 'dynamic_compressor':
+         return DynamicCompressor(model_args, **kwargs)
+     elif resampler_type is None:
+         return IdentityMap()
+     else:
+         raise ValueError(f'Unknown resampler type: {resampler_type}')
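Finally, a hedged sketch of `build_vision_resampler`: when `mm_resampler_type` is unset it falls back to the pass-through `IdentityMap`, while `'dynamic_compressor'` selects the `DynamicCompressor` from `perceiver.py`. The `SimpleNamespace` stand-in for `model_args` and the feature shape are illustrative, and the import assumes the `ola` package is on `PYTHONPATH`.

```python
from types import SimpleNamespace
import torch

from ola.model.multimodal_resampler.builder import build_vision_resampler

# No resampler configured -> IdentityMap, which returns its input unchanged.
resampler = build_vision_resampler(SimpleNamespace(mm_resampler_type=None))
feats = torch.randn(1, 256, 1152)   # illustrative (batch, tokens, channels)
print(resampler(feats).shape)       # torch.Size([1, 256, 1152])
```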