gordonhubackup committed
Commit e62d81d · 1 Parent(s): 2ba1ee8
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. LICENSE.txt +14 -0
  2. MANIFEST.in +2 -0
  3. README.md +1 -1
  4. app.py +151 -0
  5. bliva/__init__.py +29 -0
  6. bliva/common/config.py +469 -0
  7. bliva/common/dist_utils.py +137 -0
  8. bliva/common/gradcam.py +24 -0
  9. bliva/common/logger.py +193 -0
  10. bliva/common/optims.py +117 -0
  11. bliva/common/registry.py +268 -0
  12. bliva/common/utils.py +424 -0
  13. bliva/configs/default.yaml +7 -0
  14. bliva/configs/models/bliva_flant5xxl.yaml +39 -0
  15. bliva/configs/models/bliva_vicuna7b.yaml +39 -0
  16. bliva/conversation/__init__.py +0 -0
  17. bliva/conversation/conversation.py +180 -0
  18. bliva/models/Qformer.py +1216 -0
  19. bliva/models/__init__.py +208 -0
  20. bliva/models/base_model.py +251 -0
  21. bliva/models/blip2.py +319 -0
  22. bliva/models/bliva_flant5xxl.py +803 -0
  23. bliva/models/bliva_vicuna7b.py +783 -0
  24. bliva/models/clip_vit.py +272 -0
  25. bliva/models/eva_vit.py +442 -0
  26. bliva/models/modeling_llama.py +888 -0
  27. bliva/models/modeling_t5.py +2063 -0
  28. bliva/models/vit.py +527 -0
  29. bliva/processors/__init__.py +38 -0
  30. bliva/processors/base_processor.py +26 -0
  31. bliva/processors/blip_processors.py +239 -0
  32. bliva/processors/clip_processors.py +92 -0
  33. bliva/processors/randaugment.py +398 -0
  34. bliva_vicuna7b.pth +3 -0
  35. evaluate.py +93 -0
  36. hf_vicuna_7b/config.json +23 -0
  37. hf_vicuna_7b/generation_config.json +7 -0
  38. hf_vicuna_7b/pytorch_model-00001-of-00002.bin +3 -0
  39. hf_vicuna_7b/pytorch_model-00002-of-00002.bin +3 -0
  40. hf_vicuna_7b/pytorch_model.bin.index.json +330 -0
  41. hf_vicuna_7b/special_tokens_map.json +23 -0
  42. hf_vicuna_7b/tokenizer.model +3 -0
  43. hf_vicuna_7b/tokenizer_config.json +33 -0
  44. images/example.jpg +0 -0
  45. images/img1.jpg +0 -0
  46. images/img2.jpg +0 -0
  47. images/img3.jpg +0 -0
  48. images/img4.jpg +0 -0
  49. images/img5.jpg +0 -0
  50. images/img6.jpg +0 -0
LICENSE.txt ADDED
@@ -0,0 +1,14 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2022 Salesforce, Inc.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MANIFEST.in ADDED
@@ -0,0 +1,2 @@
+
+ include requirements.txt
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: BLIVA
- emoji: 😻
+ emoji: 🚀
  colorFrom: purple
  colorTo: red
  sdk: gradio
app.py ADDED
@@ -0,0 +1,151 @@
+ import argparse
+ import os
+ import random
+
+ import numpy as np
+ import torch
+ import torch.backends.cudnn as cudnn
+ import gradio as gr
+
+ from bliva.common.config import Config
+ from bliva.common.dist_utils import get_rank
+ from bliva.common.registry import registry
+ from bliva.conversation.conversation import Chat, CONV_VISION, CONV_DIRECT
+
+ # imports modules for registration
+
+ from bliva.models import *
+ from bliva.processors import *
+ from bliva.models import load_model_and_preprocess
+ from evaluate import disable_torch_init
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Demo")
+     parser.add_argument("--model_name", default='bliva_vicuna', type=str, help='model name')
+     parser.add_argument("--gpu_id", type=int, default=0, help="specify the gpu to load the model.")
+     args = parser.parse_args()
+     return args
+
+ # ========================================
+ # Model Initialization
+ # ========================================
+
+ print('Initializing Chat')
+ args = parse_args()
+
+ if torch.cuda.is_available():
+     device = 'cuda:{}'.format(args.gpu_id)
+ else:
+     device = torch.device('cpu')
+
+ disable_torch_init()
+ if args.model_name == "blip2_vicuna_instruct":
+     model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="vicuna7b", is_eval=True, device=device)
+ elif args.model_name == "bliva_vicuna":
+     model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="vicuna7b", is_eval=True, device=device)
+ elif args.model_name == "bliva_flant5":
+     model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="flant5xxl", is_eval=True, device=device)
+ else:
+     print("Model not found")
+
+ vis_processor = vis_processors["eval"]
+
+
+ # vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
+ # vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+ chat = Chat(model, vis_processor, device=device)
+ print('Initialization Finished')
+
+ # ========================================
+ # Gradio Setting
+ # ========================================
+
+ def gradio_reset(chat_state, img_list):
+     if chat_state is not None:
+         chat_state.messages = []
+     if img_list is not None:
+         img_list = []
+     return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
+
+ def upload_img(gr_img, text_input, chat_state):
+     if gr_img is None:
+         return None, None, gr.update(interactive=True), chat_state, None
+     chat_state = CONV_DIRECT.copy()  # CONV_VISION.copy()
+     img_list = []
+     llm_message = chat.upload_img(gr_img, chat_state, img_list)
+     return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
+
+ def gradio_ask(user_message, chatbot, chat_state):
+     if len(user_message) == 0:
+         return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
+     chat.ask(user_message, chat_state)
+     chatbot = chatbot + [[user_message, None]]
+     return '', chatbot, chat_state
+
+
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
+     llm_message = chat.answer(conv=chat_state,
+                               img_list=img_list,
+                               num_beams=num_beams,
+                               temperature=temperature,
+                               max_new_tokens=300,
+                               max_length=2000)[0]
+     chatbot[-1][1] = llm_message[0]
+     return chatbot, chat_state, img_list
+
+ title = """<h1 align="center">Demo of BLIVA</h1>"""
+ description = """<h3>This is the demo of BLIVA. Upload your images and start chatting!</h3>"""
+ article = """<p><a href='https://gordonhu608.github.io/bliva/'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/mlpc-ucsd/BLIVA'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
+ """
+
+ # TODO: show examples below
+
+ with gr.Blocks() as demo:
+     gr.Markdown(title)
+     gr.Markdown(description)
+     gr.Markdown(article)
+
+     with gr.Row():
+         with gr.Column(scale=0.5):
+             image = gr.Image(type="pil")
+             upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
+             clear = gr.Button("Restart 🔄")
+
+             num_beams = gr.Slider(
+                 minimum=1,
+                 maximum=10,
+                 value=5,
+                 step=1,
+                 interactive=True,
+                 label="beam search numbers",
+             )
+
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=2.0,
+                 value=1.0,
+                 step=0.1,
+                 interactive=True,
+                 label="Temperature",
+             )
+
+         with gr.Column():
+             chat_state = gr.State()
+             img_list = gr.State()
+             chatbot = gr.Chatbot(label='BLIVA')
+             text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
+
+     gr.Examples(examples=[
+         [f"images/example.jpg", "Describe this image in detail."],
+         [f"images/img3.jpg", "What is this image about?"],
+         [f"images/img4.jpg", "What is the title of this movie?"],
+     ], inputs=[image, text_input])
+
+     upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
+
+     text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
+         gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
+     )
+     clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
+
+ demo.launch(enable_queue=True)
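
For orientation, the Gradio callbacks above reduce to one upload → ask → answer round trip. A minimal non-UI sketch of the same cycle, using only the objects app.py itself imports and the call signatures it uses (the image path points at one of the bundled example images; running the vicuna7b checkpoint on CPU is only for illustration):

    from PIL import Image
    from bliva.models import load_model_and_preprocess
    from bliva.conversation.conversation import Chat, CONV_DIRECT

    # same loading path as the "bliva_vicuna" branch above
    model, vis_processors, _ = load_model_and_preprocess(
        name="bliva_vicuna", model_type="vicuna7b", is_eval=True, device="cpu")
    chat = Chat(model, vis_processors["eval"], device="cpu")

    state, img_list = CONV_DIRECT.copy(), []                  # as in upload_img()
    chat.upload_img(Image.open("images/example.jpg"), state, img_list)
    chat.ask("Describe this image in detail.", state)          # as in gradio_ask()
    answer = chat.answer(conv=state, img_list=img_list, num_beams=5,
                         temperature=1.0, max_new_tokens=300, max_length=2000)[0]
    print(answer)
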
bliva/__init__.py ADDED
@@ -0,0 +1,29 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import os
+ import sys
+
+ from omegaconf import OmegaConf
+
+ from bliva.common.registry import registry
+
+ from bliva.models import *
+ from bliva.processors import *
+
+
+ root_dir = os.path.dirname(os.path.abspath(__file__))
+ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
+
+ registry.register_path("library_root", root_dir)
+ repo_root = os.path.join(root_dir, "..")
+ registry.register_path("repo_root", repo_root)
+ cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
+ registry.register_path("cache_root", cache_root)
+
+ registry.register("MAX_INT", sys.maxsize)
+ registry.register("SPLIT_NAMES", ["train", "val", "test"])
bliva/common/config.py ADDED
@@ -0,0 +1,469 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import logging
+ import json
+ from typing import Dict
+
+ from omegaconf import OmegaConf
+ from bliva.common.registry import registry
+
+
+ class Config:
+     def __init__(self, args):
+         self.config = {}
+
+         self.args = args
+
+         # Register the config and configuration for setup
+         registry.register("configuration", self)
+
+         user_config = self._build_opt_list(self.args.options)
+
+         config = OmegaConf.load(self.args.cfg_path)
+
+         runner_config = self.build_runner_config(config)
+         model_config = self.build_model_config(config, **user_config)
+         dataset_config = self.build_dataset_config(config)
+
+         # Validate the user-provided runner configuration.
+         # Model and dataset configurations are supposed to be validated by their respective classes.
+         # [TODO] validate the model/dataset configuration
+         # self._validate_runner_config(runner_config)
+
+         # Override the default configuration with user options.
+         self.config = OmegaConf.merge(
+             runner_config, model_config, dataset_config, user_config
+         )
+
+     def _validate_runner_config(self, runner_config):
+         """
+         This method validates the configuration, such that
+             1) all the user-specified options are valid;
+             2) there are no type mismatches between the user-specified options and the config.
+         """
+         runner_config_validator = create_runner_config_validator()
+         runner_config_validator.validate(runner_config)
+
+     def _build_opt_list(self, opts):
+         opts_dot_list = self._convert_to_dot_list(opts)
+         return OmegaConf.from_dotlist(opts_dot_list)
+
+     @staticmethod
+     def build_model_config(config, **kwargs):
+         model = config.get("model", None)
+         assert model is not None, "Missing model configuration file."
+
+         model_cls = registry.get_model_class(model.arch)
+         assert model_cls is not None, f"Model '{model.arch}' has not been registered."
+
+         model_type = kwargs.get("model.model_type", None)
+         if not model_type:
+             model_type = model.get("model_type", None)
+         # else use the model type selected by the user.
+
+         assert model_type is not None, "Missing model_type."
+
+         model_config_path = model_cls.default_config_path(model_type=model_type)
+
+         model_config = OmegaConf.create()
+         # hierarchy override: customized config > default config
+         model_config = OmegaConf.merge(
+             model_config,
+             OmegaConf.load(model_config_path),
+             {"model": config["model"]},
+         )
+
+         return model_config
+
+     @staticmethod
+     def build_runner_config(config):
+         return {"run": config.run}
+
+     @staticmethod
+     def build_dataset_config(config):
+         datasets = config.get("datasets", None)
+         if datasets is None:
+             raise KeyError(
+                 "Expecting 'datasets' as the root key for dataset configuration."
+             )
+
+         dataset_config = OmegaConf.create()
+
+         for dataset_name in datasets:
+             print(f"Building dataset config for {dataset_name}")
+             builder_cls = registry.get_builder_class(dataset_name)
+
+             dataset_config_type = datasets[dataset_name].get("type", "default")
+             dataset_config_path = builder_cls.default_config_path(
+                 type=dataset_config_type
+             )
+
+             # hierarchy override: customized config > default config
+             dataset_config = OmegaConf.merge(
+                 dataset_config,
+                 OmegaConf.load(dataset_config_path),
+                 {"datasets": {dataset_name: config["datasets"][dataset_name]}},
+             )
+
+         return dataset_config
+
+     def _convert_to_dot_list(self, opts):
+         if opts is None:
+             opts = []
+
+         if len(opts) == 0:
+             return opts
+
+         has_equal = opts[0].find("=") != -1
+
+         if has_equal:
+             return opts
+
+         return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
+
+     def get_config(self):
+         return self.config
+
+     @property
+     def run_cfg(self):
+         return self.config.run
+
+     @property
+     def datasets_cfg(self):
+         return self.config.datasets
+
+     @property
+     def model_cfg(self):
+         return self.config.model
+
+     def pretty_print(self):
+         logging.info("\n===== Running Parameters =====")
+         logging.info(self._convert_node_to_json(self.config.run))
+
+         logging.info("\n====== Dataset Attributes ======")
+         datasets = self.config.datasets
+
+         for dataset in datasets:
+             if dataset in self.config.datasets:
+                 logging.info(f"\n======== {dataset} =======")
+                 dataset_config = self.config.datasets[dataset]
+                 logging.info(self._convert_node_to_json(dataset_config))
+             else:
+                 logging.warning(f"No dataset named '{dataset}' in config. Skipping")
+
+         logging.info(f"\n====== Model Attributes ======")
+         logging.info(self._convert_node_to_json(self.config.model))
+
+     def _convert_node_to_json(self, node):
+         container = OmegaConf.to_container(node, resolve=True)
+         return json.dumps(container, indent=4, sort_keys=True)
+
+     def to_dict(self):
+         return OmegaConf.to_container(self.config)
+
+
+ def node_to_dict(node):
+     return OmegaConf.to_container(node)
+
+
+ class ConfigValidator:
+     """
+     This is a preliminary implementation to centralize and validate the configuration.
+     May be altered in the future.
+
+     A helper class to validate configurations from a yaml file.
+
+     This serves the following purposes:
+         1. Ensure all the options in the yaml are defined; raise an error if not.
+         2. When type mismatches are found, the validator will raise an error.
+         3. A central place to store and display helpful messages for supported configurations.
+     """
+
+     class _Argument:
+         def __init__(self, name, choices=None, type=None, help=None):
+             self.name = name
+             self.val = None
+             self.choices = choices
+             self.type = type
+             self.help = help
+
+         def __str__(self):
+             s = f"{self.name}={self.val}"
+             if self.type is not None:
+                 s += f", ({self.type})"
+             if self.choices is not None:
+                 s += f", choices: {self.choices}"
+             if self.help is not None:
+                 s += f", ({self.help})"
+             return s
+
+     def __init__(self, description):
+         self.description = description
+
+         self.arguments = dict()
+
+         self.parsed_args = None
+
+     def __getitem__(self, key):
+         assert self.parsed_args is not None, "No arguments parsed yet."
+
+         return self.parsed_args[key]
+
+     def __str__(self) -> str:
+         return self.format_help()
+
+     def add_argument(self, *args, **kwargs):
+         """
+         Assume the first argument is the name of the argument.
+         """
+         self.arguments[args[0]] = self._Argument(*args, **kwargs)
+
+     def validate(self, config=None):
+         """
+         Validate the yaml config (dict-like) against the registered arguments.
+         """
+         for k, v in config.items():
+             assert (
+                 k in self.arguments
+             ), f"""{k} is not a valid argument. Supported arguments are {self.format_arguments()}."""
+
+             if self.arguments[k].type is not None:
+                 try:
+                     self.arguments[k].val = self.arguments[k].type(v)
+                 except ValueError:
+                     raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
+
+             if self.arguments[k].choices is not None:
+                 assert (
+                     v in self.arguments[k].choices
+                 ), f"""{k} must be one of {self.arguments[k].choices}."""
+
+         return config
+
+     def format_arguments(self):
+         return str([f"{k}" for k in sorted(self.arguments.keys())])
+
+     def format_help(self):
+         # description + key-value pair string for each argument
+         help_msg = str(self.description)
+         return help_msg + ", available arguments: " + self.format_arguments()
+
+     def print_help(self):
+         # display the help message
+         print(self.format_help())
+
+
+ def create_runner_config_validator():
+     validator = ConfigValidator(description="Runner configurations")
+
+     validator.add_argument(
+         "runner",
+         type=str,
+         choices=["runner_base", "runner_iter"],
+         help="""Runner to use. "runner_base" uses epoch-based training, while "runner_iter"
+               runs based on iterations. Default: runner_base""",
+     )
+     # add arguments for training dataset ratios
+     validator.add_argument(
+         "train_dataset_ratios",
+         type=Dict[str, float],
+         help="""Ratios of the training datasets. This is used by the iteration-based runner.
+               Not supported by the epoch-based runner, because defining an epoch becomes tricky.
+               Default: None""",
+     )
+     validator.add_argument(
+         "max_iters",
+         type=float,
+         help="Maximum number of iterations to run.",
+     )
+     validator.add_argument(
+         "max_epoch",
+         type=int,
+         help="Maximum number of epochs to run.",
+     )
+     # add arguments for iters_per_inner_epoch
+     validator.add_argument(
+         "iters_per_inner_epoch",
+         type=float,
+         help="Number of iterations per inner epoch. This is required when the runner is runner_iter.",
+     )
+     lr_scheds_choices = registry.list_lr_schedulers()
+     validator.add_argument(
+         "lr_sched",
+         type=str,
+         choices=lr_scheds_choices,
+         help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
+     )
+     task_choices = registry.list_tasks()
+     validator.add_argument(
+         "task",
+         type=str,
+         choices=task_choices,
+         help="Task to use, from {}".format(task_choices),
+     )
+     # add arguments for init_lr
+     validator.add_argument(
+         "init_lr",
+         type=float,
+         help="Initial learning rate. This will be the learning rate after warmup and before decay.",
+     )
+     # add arguments for min_lr
+     validator.add_argument(
+         "min_lr",
+         type=float,
+         help="Minimum learning rate (after decay).",
+     )
+     # add arguments for warmup_lr
+     validator.add_argument(
+         "warmup_lr",
+         type=float,
+         help="Starting learning rate for warmup.",
+     )
+     # add arguments for the learning rate decay rate
+     validator.add_argument(
+         "lr_decay_rate",
+         type=float,
+         help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
+     )
+     # add arguments for weight decay
+     validator.add_argument(
+         "weight_decay",
+         type=float,
+         help="Weight decay rate.",
+     )
+     # add arguments for the training batch size
+     validator.add_argument(
+         "batch_size_train",
+         type=int,
+         help="Training batch size.",
+     )
+     # add arguments for the evaluation batch size
+     validator.add_argument(
+         "batch_size_eval",
+         type=int,
+         help="Evaluation batch size, including validation and testing.",
+     )
+     # add arguments for the number of workers for data loading
+     validator.add_argument(
+         "num_workers",
+         help="Number of workers for data loading.",
+     )
+     # add arguments for warmup steps
+     validator.add_argument(
+         "warmup_steps",
+         type=int,
+         help="Number of warmup steps. Required if a warmup schedule is used.",
+     )
+     # add arguments for the random seed
+     validator.add_argument(
+         "seed",
+         type=int,
+         help="Random seed.",
+     )
+     # add arguments for the output directory
+     validator.add_argument(
+         "output_dir",
+         type=str,
+         help="Output directory to save checkpoints and logs.",
+     )
+     # add arguments for evaluation-only runs
+     validator.add_argument(
+         "evaluate",
+         help="Whether to only evaluate the model. If true, training will not be performed.",
+     )
+     # add arguments for splits used for training, e.g. ["train", "val"]
+     validator.add_argument(
+         "train_splits",
+         type=list,
+         help="Splits to use for training.",
+     )
+     # add arguments for splits used for validation, e.g. ["val"]
+     validator.add_argument(
+         "valid_splits",
+         type=list,
+         help="Splits to use for validation. If not provided, validation will be skipped.",
+     )
+     # add arguments for splits used for testing, e.g. ["test"]
+     validator.add_argument(
+         "test_splits",
+         type=list,
+         help="Splits to use for testing. If not provided, testing will be skipped.",
+     )
+     # add arguments for accumulating gradients over iterations
+     validator.add_argument(
+         "accum_grad_iters",
+         type=int,
+         help="Number of iterations to accumulate gradients for.",
+     )
+
+     # ====== distributed training ======
+     validator.add_argument(
+         "device",
+         type=str,
+         choices=["cpu", "cuda"],
+         help="Device to use. Supports 'cuda' or 'cpu' for now.",
+     )
+     validator.add_argument(
+         "world_size",
+         type=int,
+         help="Number of processes participating in the job.",
+     )
+     validator.add_argument("dist_url", type=str)
+     validator.add_argument("distributed", type=bool)
+     # add arguments for whether to use a distributed sampler during evaluation
+     validator.add_argument(
+         "use_dist_eval_sampler",
+         type=bool,
+         help="Whether to use a distributed sampler during evaluation or not.",
+     )
+
+     # ====== task specific ======
+     # generation task specific arguments
+     # add arguments for the maximum length of text output
+     validator.add_argument(
+         "max_len",
+         type=int,
+         help="Maximal length of text output.",
+     )
+     # add arguments for the minimum length of text output
+     validator.add_argument(
+         "min_len",
+         type=int,
+         help="Minimal length of text output.",
+     )
+     # add arguments for the number of beams
+     validator.add_argument(
+         "num_beams",
+         type=int,
+         help="Number of beams used for beam search.",
+     )
+
+     # vqa task specific arguments
+     # add arguments for the number of answer candidates
+     validator.add_argument(
+         "num_ans_candidates",
+         type=int,
+         help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
+     )
+     # add arguments for the inference method
+     validator.add_argument(
+         "inference_method",
+         type=str,
+         choices=["generate", "rank"],
+         help="""Inference method to use for question answering. If rank, requires an answer list.""",
+     )
+
+     # ====== model specific ======
+     validator.add_argument(
+         "k_test",
+         type=int,
+         help="Number of top k most similar samples from ITC/VTC selection to be tested.",
+     )
+
+     return validator
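
One detail worth noting in `Config._convert_to_dot_list`: command-line options may arrive either as "key=value" strings or as alternating key/value tokens, and both normalize to the same OmegaConf dotlist. A standalone illustration of that pairing logic (the option names here are placeholders):

    from omegaconf import OmegaConf

    opts = ["model.model_type", "vicuna7b", "run.seed", "42"]
    # pairs every other token, mirroring _convert_to_dot_list above
    dotlist = [f"{k}={v}" for k, v in zip(opts[0::2], opts[1::2])]
    assert dotlist == ["model.model_type=vicuna7b", "run.seed=42"]

    cfg = OmegaConf.from_dotlist(dotlist)
    print(cfg.model.model_type, cfg.run.seed)
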
bliva/common/dist_utils.py ADDED
@@ -0,0 +1,137 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import datetime
+ import functools
+ import os
+
+ import torch
+ import torch.distributed as dist
+ import timm.models.hub as timm_hub
+
+
+ def setup_for_distributed(is_master):
+     """
+     This function disables printing when not in the master process
+     """
+     import builtins as __builtin__
+
+     builtin_print = __builtin__.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop("force", False)
+         if is_master or force:
+             builtin_print(*args, **kwargs)
+
+     __builtin__.print = print
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def is_main_process():
+     return get_rank() == 0
+
+
+ def init_distributed_mode(args):
+     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+         args.rank = int(os.environ["RANK"])
+         args.world_size = int(os.environ["WORLD_SIZE"])
+         args.gpu = int(os.environ["LOCAL_RANK"])
+     elif "SLURM_PROCID" in os.environ:
+         args.rank = int(os.environ["SLURM_PROCID"])
+         args.gpu = args.rank % torch.cuda.device_count()
+     else:
+         print("Not using distributed mode")
+         args.distributed = False
+         return
+
+     args.distributed = True
+
+     torch.cuda.set_device(args.gpu)
+     args.dist_backend = "nccl"
+     print(
+         "| distributed init (rank {}, world {}): {}".format(
+             args.rank, args.world_size, args.dist_url
+         ),
+         flush=True,
+     )
+     torch.distributed.init_process_group(
+         backend=args.dist_backend,
+         init_method=args.dist_url,
+         world_size=args.world_size,
+         rank=args.rank,
+         timeout=datetime.timedelta(
+             days=365
+         ),  # allow auto-downloading and de-compressing
+     )
+     torch.distributed.barrier()
+     setup_for_distributed(args.rank == 0)
+
+
+ def get_dist_info():
+     if torch.__version__ < "1.0":
+         initialized = dist._initialized
+     else:
+         initialized = dist.is_initialized()
+     if initialized:
+         rank = dist.get_rank()
+         world_size = dist.get_world_size()
+     else:  # non-distributed training
+         rank = 0
+         world_size = 1
+     return rank, world_size
+
+
+ def main_process(func):
+     @functools.wraps(func)
+     def wrapper(*args, **kwargs):
+         rank, _ = get_dist_info()
+         if rank == 0:
+             return func(*args, **kwargs)
+
+     return wrapper
+
+
+ def download_cached_file(url, check_hash=True, progress=False):
+     """
+     Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
+     If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
+     """
+
+     def get_cached_file_path():
+         # a hack to sync the file path across processes
+         parts = torch.hub.urlparse(url)
+         filename = os.path.basename(parts.path)
+         cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
+
+         return cached_file
+
+     if is_main_process():
+         timm_hub.download_cached_file(url, check_hash, progress)
+
+     if is_dist_avail_and_initialized():
+         dist.barrier()
+
+     return get_cached_file_path()
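
The `main_process` decorator above is the pattern used to keep side effects such as logging and checkpoint writes on rank 0 only. A short usage sketch (the function and path are hypothetical placeholders):

    from bliva.common.dist_utils import main_process

    @main_process
    def save_result(path):
        # executes on rank 0 only; on other ranks the wrapper returns None
        print(f"writing results to {path}")

    save_result("output/result.json")  # no-op everywhere except the main process
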
bliva/common/gradcam.py ADDED
@@ -0,0 +1,24 @@
+ import numpy as np
+ from matplotlib import pyplot as plt
+ from scipy.ndimage import filters
+ from skimage import transform as skimage_transform
+
+
+ def getAttMap(img, attMap, blur=True, overlap=True):
+     attMap -= attMap.min()
+     if attMap.max() > 0:
+         attMap /= attMap.max()
+     attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
+     if blur:
+         attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
+         attMap -= attMap.min()
+         attMap /= attMap.max()
+     cmap = plt.get_cmap("jet")
+     attMapV = cmap(attMap)
+     attMapV = np.delete(attMapV, 3, 2)
+     if overlap:
+         attMap = (
+             1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
+             + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
+         )
+     return attMap
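
A usage sketch for `getAttMap`, assuming `img` is an HxWx3 float array in [0, 1] and the attention map can have any spatial size (both arrays below are synthetic placeholders):

    import numpy as np
    from bliva.common.gradcam import getAttMap

    img = np.random.rand(224, 224, 3)   # placeholder image in [0, 1]
    att = np.random.rand(16, 16)        # e.g. a 16x16 cross-attention map
    overlay = getAttMap(img, att)       # resized, blurred, jet-colored, blended
    assert overlay.shape == img.shape
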
bliva/common/logger.py ADDED
@@ -0,0 +1,193 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import datetime
+ import logging
+ import time
+ from collections import defaultdict, deque
+
+ import torch
+ import torch.distributed as dist
+
+ from bliva.common import dist_utils
+
+ class SmoothedValue(object):
+     """Track a series of values and provide access to smoothed values over a
+     window or the global series average.
+     """
+
+     def __init__(self, window_size=20, fmt=None):
+         if fmt is None:
+             fmt = "{median:.4f} ({global_avg:.4f})"
+         self.deque = deque(maxlen=window_size)
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+
+     def update(self, value, n=1):
+         self.deque.append(value)
+         self.count += n
+         self.total += value * n
+
+     def synchronize_between_processes(self):
+         """
+         Warning: does not synchronize the deque!
+         """
+         if not dist_utils.is_dist_avail_and_initialized():
+             return
+         t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+         dist.barrier()
+         dist.all_reduce(t)
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+
+     @property
+     def median(self):
+         d = torch.tensor(list(self.deque))
+         return d.median().item()
+
+     @property
+     def avg(self):
+         d = torch.tensor(list(self.deque), dtype=torch.float32)
+         return d.mean().item()
+
+     @property
+     def global_avg(self):
+         return self.total / self.count
+
+     @property
+     def max(self):
+         return max(self.deque)
+
+     @property
+     def value(self):
+         return self.deque[-1]
+
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median,
+             avg=self.avg,
+             global_avg=self.global_avg,
+             max=self.max,
+             value=self.value,
+         )
+
+
+ class MetricLogger(object):
+     def __init__(self, delimiter="\t"):
+         self.meters = defaultdict(SmoothedValue)
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             assert isinstance(v, (float, int))
+             self.meters[k].update(v)
+
+     def __getattr__(self, attr):
+         if attr in self.meters:
+             return self.meters[attr]
+         if attr in self.__dict__:
+             return self.__dict__[attr]
+         raise AttributeError(
+             "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+         )
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append("{}: {}".format(name, str(meter)))
+         return self.delimiter.join(loss_str)
+
+     def global_avg(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+
+     def log_every(self, iterable, print_freq, header=None):
+         i = 0
+         if not header:
+             header = ""
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt="{avg:.4f}")
+         data_time = SmoothedValue(fmt="{avg:.4f}")
+         space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+         log_msg = [
+             header,
+             "[{0" + space_fmt + "}/{1}]",
+             "eta: {eta}",
+             "{meters}",
+             "time: {time}",
+             "data: {data}",
+         ]
+         if torch.cuda.is_available():
+             log_msg.append("max mem: {memory:.0f}")
+         log_msg = self.delimiter.join(log_msg)
+         MB = 1024.0 * 1024.0
+         for obj in iterable:
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0 or i == len(iterable) - 1:
+                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                 if torch.cuda.is_available():
+                     print(
+                         log_msg.format(
+                             i,
+                             len(iterable),
+                             eta=eta_string,
+                             meters=str(self),
+                             time=str(iter_time),
+                             data=str(data_time),
+                             memory=torch.cuda.max_memory_allocated() / MB,
+                         )
+                     )
+                 else:
+                     print(
+                         log_msg.format(
+                             i,
+                             len(iterable),
+                             eta=eta_string,
+                             meters=str(self),
+                             time=str(iter_time),
+                             data=str(data_time),
+                         )
+                     )
+             i += 1
+             end = time.time()
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print(
+             "{} Total time: {} ({:.4f} s / it)".format(
+                 header, total_time_str, total_time / len(iterable)
+             )
+         )
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+ def setup_logger():
+     logging.basicConfig(
+         level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
+         format="%(asctime)s [%(levelname)s] %(message)s",
+         handlers=[logging.StreamHandler()],
+     )
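
Typical use of `MetricLogger.log_every` in a training loop (a minimal sketch; `loader` stands in for any sized iterable such as a DataLoader, and the loss is a placeholder value):

    from bliva.common.logger import MetricLogger, setup_logger

    setup_logger()
    metric_logger = MetricLogger(delimiter="  ")
    loader = range(100)  # placeholder for a DataLoader

    for batch in metric_logger.log_every(loader, print_freq=10, header="Train:"):
        loss = 0.5  # placeholder for a real loss value
        metric_logger.update(loss=loss)  # tracked by a SmoothedValue per key
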
bliva/common/optims.py ADDED
@@ -0,0 +1,117 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import math
+
+ from bliva.common.registry import registry
+
+
+ @registry.register_lr_scheduler("linear_warmup_step_lr")
+ class LinearWarmupStepLRScheduler:
+     def __init__(
+         self,
+         optimizer,
+         max_epoch,
+         min_lr,
+         init_lr,
+         decay_rate=1,
+         warmup_start_lr=-1,
+         warmup_steps=0,
+         **kwargs
+     ):
+         self.optimizer = optimizer
+
+         self.max_epoch = max_epoch
+         self.min_lr = min_lr
+
+         self.decay_rate = decay_rate
+
+         self.init_lr = init_lr
+         self.warmup_steps = warmup_steps
+         self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+     def step(self, cur_epoch, cur_step):
+         if cur_epoch == 0:
+             warmup_lr_schedule(
+                 step=cur_step,
+                 optimizer=self.optimizer,
+                 max_step=self.warmup_steps,
+                 init_lr=self.warmup_start_lr,
+                 max_lr=self.init_lr,
+             )
+         else:
+             step_lr_schedule(
+                 epoch=cur_epoch,
+                 optimizer=self.optimizer,
+                 init_lr=self.init_lr,
+                 min_lr=self.min_lr,
+                 decay_rate=self.decay_rate,
+             )
+
+
+ @registry.register_lr_scheduler("linear_warmup_cosine_lr")
+ class LinearWarmupCosineLRScheduler:
+     def __init__(
+         self,
+         optimizer,
+         max_epoch,
+         min_lr,
+         init_lr,
+         warmup_steps=0,
+         warmup_start_lr=-1,
+         **kwargs
+     ):
+         self.optimizer = optimizer
+
+         self.max_epoch = max_epoch
+         self.min_lr = min_lr
+
+         self.init_lr = init_lr
+         self.warmup_steps = warmup_steps
+         self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+     def step(self, cur_epoch, cur_step):
+         # assuming the warmup iters are less than one epoch
+         if cur_epoch == 0:
+             warmup_lr_schedule(
+                 step=cur_step,
+                 optimizer=self.optimizer,
+                 max_step=self.warmup_steps,
+                 init_lr=self.warmup_start_lr,
+                 max_lr=self.init_lr,
+             )
+         else:
+             cosine_lr_schedule(
+                 epoch=cur_epoch,
+                 optimizer=self.optimizer,
+                 max_epoch=self.max_epoch,
+                 init_lr=self.init_lr,
+                 min_lr=self.min_lr,
+             )
+
+
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
+     """Decay the learning rate"""
+     lr = (init_lr - min_lr) * 0.5 * (
+         1.0 + math.cos(math.pi * epoch / max_epoch)
+     ) + min_lr
+     for param_group in optimizer.param_groups:
+         param_group["lr"] = lr
+
+
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
+     """Warmup the learning rate"""
+     lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
+     for param_group in optimizer.param_groups:
+         param_group["lr"] = lr
+
+
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
+     """Decay the learning rate"""
+     lr = max(min_lr, init_lr * (decay_rate**epoch))
+     for param_group in optimizer.param_groups:
+         param_group["lr"] = lr
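
To make the cosine decay concrete: `cosine_lr_schedule` sets lr = (init_lr - min_lr) * 0.5 * (1 + cos(pi * epoch / max_epoch)) + min_lr, so with the example values init_lr=1e-4, min_lr=1e-6, max_epoch=10 the rate is 1e-4 at epoch 0, about 5.05e-5 at epoch 5, and 1e-6 at epoch 10. A quick standalone check:

    import math

    def cosine(epoch, max_epoch=10, init_lr=1e-4, min_lr=1e-6):
        # same formula as cosine_lr_schedule above, without the optimizer
        return (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * epoch / max_epoch)) + min_lr

    for e in (0, 5, 10):
        print(e, cosine(e))  # 1e-04, ~5.05e-05, 1e-06
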
bliva/common/registry.py ADDED
@@ -0,0 +1,268 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+
+ class Registry:
+     mapping = {
+         "builder_name_mapping": {},
+         "task_name_mapping": {},
+         "processor_name_mapping": {},
+         "model_name_mapping": {},
+         "lr_scheduler_name_mapping": {},
+         "runner_name_mapping": {},
+         "state": {},
+         "paths": {},
+     }
+
+     @classmethod
+     def register_model(cls, name):
+         r"""Register a model to registry with key 'name'
+
+         Args:
+             name: Key with which the model will be registered.
+
+         Usage:
+
+             from bliva.common.registry import registry
+         """
+
+         def wrap(model_cls):
+             from bliva.models import BaseModel
+
+             assert issubclass(
+                 model_cls, BaseModel
+             ), "All models must inherit BaseModel class"
+             if name in cls.mapping["model_name_mapping"]:
+                 raise KeyError(
+                     "Name '{}' already registered for {}.".format(
+                         name, cls.mapping["model_name_mapping"][name]
+                     )
+                 )
+             cls.mapping["model_name_mapping"][name] = model_cls
+             return model_cls
+
+         return wrap
+
+     @classmethod
+     def register_processor(cls, name):
+         r"""Register a processor to registry with key 'name'
+
+         Args:
+             name: Key with which the processor will be registered.
+
+         Usage:
+
+             from bliva.common.registry import registry
+         """
+
+         def wrap(processor_cls):
+             from bliva.processors import BaseProcessor
+
+             assert issubclass(
+                 processor_cls, BaseProcessor
+             ), "All processors must inherit BaseProcessor class"
+             if name in cls.mapping["processor_name_mapping"]:
+                 raise KeyError(
+                     "Name '{}' already registered for {}.".format(
+                         name, cls.mapping["processor_name_mapping"][name]
+                     )
+                 )
+             cls.mapping["processor_name_mapping"][name] = processor_cls
+             return processor_cls
+
+         return wrap
+
+     @classmethod
+     def register_lr_scheduler(cls, name):
+         r"""Register an lr scheduler to registry with key 'name'
+
+         Args:
+             name: Key with which the scheduler will be registered.
+
+         Usage:
+
+             from bliva.common.registry import registry
+         """
+
+         def wrap(lr_sched_cls):
+             if name in cls.mapping["lr_scheduler_name_mapping"]:
+                 raise KeyError(
+                     "Name '{}' already registered for {}.".format(
+                         name, cls.mapping["lr_scheduler_name_mapping"][name]
+                     )
+                 )
+             cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
+             return lr_sched_cls
+
+         return wrap
+
+     @classmethod
+     def register_runner(cls, name):
+         r"""Register a runner to registry with key 'name'
+
+         Args:
+             name: Key with which the runner will be registered.
+
+         Usage:
+
+             from bliva.common.registry import registry
+         """
+
+         def wrap(runner_cls):
+             if name in cls.mapping["runner_name_mapping"]:
+                 raise KeyError(
+                     "Name '{}' already registered for {}.".format(
+                         name, cls.mapping["runner_name_mapping"][name]
+                     )
+                 )
+             cls.mapping["runner_name_mapping"][name] = runner_cls
+             return runner_cls
+
+         return wrap
+
+     @classmethod
+     def register_path(cls, name, path):
+         r"""Register a path to registry with key 'name'
+
+         Args:
+             name: Key with which the path will be registered.
+
+         Usage:
+
+             from bliva.common.registry import registry
+         """
+         assert isinstance(path, str), "All paths must be str."
+         if name in cls.mapping["paths"]:
+             raise KeyError("Name '{}' already registered.".format(name))
+         cls.mapping["paths"][name] = path
+
+     @classmethod
+     def register(cls, name, obj):
+         r"""Register an item to registry with key 'name'
+
+         Args:
+             name: Key with which the item will be registered.
+
+         Usage::
+
+             from bliva.common.registry import registry
+
+             registry.register("config", {})
+         """
+         path = name.split(".")
+         current = cls.mapping["state"]
+
+         for part in path[:-1]:
+             if part not in current:
+                 current[part] = {}
+             current = current[part]
+
+         current[path[-1]] = obj
+
+     # @classmethod
+     # def get_trainer_class(cls, name):
+     #     return cls.mapping["trainer_name_mapping"].get(name, None)
+
+     @classmethod
+     def get_builder_class(cls, name):
+         return cls.mapping["builder_name_mapping"].get(name, None)
+
+     @classmethod
+     def get_model_class(cls, name):
+         return cls.mapping["model_name_mapping"].get(name, None)
+
+     @classmethod
+     def get_task_class(cls, name):
+         return cls.mapping["task_name_mapping"].get(name, None)
+
+     @classmethod
+     def get_processor_class(cls, name):
+         return cls.mapping["processor_name_mapping"].get(name, None)
+
+     @classmethod
+     def get_lr_scheduler_class(cls, name):
+         return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
+
+     @classmethod
+     def get_runner_class(cls, name):
+         return cls.mapping["runner_name_mapping"].get(name, None)
+
+     @classmethod
+     def list_runners(cls):
+         return sorted(cls.mapping["runner_name_mapping"].keys())
+
+     @classmethod
+     def list_models(cls):
+         return sorted(cls.mapping["model_name_mapping"].keys())
+
+     @classmethod
+     def list_tasks(cls):
+         return sorted(cls.mapping["task_name_mapping"].keys())
+
+     @classmethod
+     def list_processors(cls):
+         return sorted(cls.mapping["processor_name_mapping"].keys())
+
+     @classmethod
+     def list_lr_schedulers(cls):
+         return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
+
+     @classmethod
+     def list_datasets(cls):
+         return sorted(cls.mapping["builder_name_mapping"].keys())
+
+     @classmethod
+     def get_path(cls, name):
+         return cls.mapping["paths"].get(name, None)
+
+     @classmethod
+     def get(cls, name, default=None, no_warning=False):
+         r"""Get an item from registry with key 'name'
+
+         Args:
+             name (string): Key whose value needs to be retrieved.
+             default: If passed and key is not in registry, default value will
+                      be returned with a warning. Default: None
+             no_warning (bool): If passed as True, warning when key doesn't exist
+                                will not be generated. Useful for MMF's
+                                internal operations. Default: False
+         """
+         original_name = name
+         name = name.split(".")
+         value = cls.mapping["state"]
+         for subname in name:
+             value = value.get(subname, default)
+             if value is default:
+                 break
+
+         if (
+             "writer" in cls.mapping["state"]
+             and value == default
+             and no_warning is False
+         ):
+             cls.mapping["state"]["writer"].warning(
+                 "Key {} is not present in registry, returning default value "
+                 "of {}".format(original_name, default)
+             )
+         return value
+
+     @classmethod
+     def unregister(cls, name):
+         r"""Remove an item from registry with key 'name'
+
+         Args:
+             name: Key which needs to be removed.
+         Usage::
+
+             from mmf.common.registry import registry
+
+             config = registry.unregister("config")
+         """
+         return cls.mapping["state"].pop(name, None)
+
+
+ registry = Registry()
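
The decorators above are how every model, processor, and scheduler announces itself; registration happens as a side effect of import, which is why app.py does `from bliva.models import *` before looking anything up. A minimal sketch of the scheduler path, the one registration kind with no base-class check (the name "constant_lr" and the class are hypothetical examples, not part of this commit):

    from bliva.common.registry import registry

    @registry.register_lr_scheduler("constant_lr")
    class ConstantLRScheduler:
        def __init__(self, optimizer, **kwargs):
            self.optimizer = optimizer

        def step(self, cur_epoch, cur_step):
            pass  # leave the optimizer's lr untouched

    # lookup returns the class itself; a duplicate name would raise KeyError
    cls = registry.get_lr_scheduler_class("constant_lr")
    assert cls is ConstantLRScheduler
    print(registry.list_lr_schedulers())
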
bliva/common/utils.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from bliva.common.registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading an URL from google drive requires confirmation when
112
+ the file of the size is too big (google drive notifies that
113
+ anti-viral checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), ""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create a symlink from file1 to file2.
264
+ Useful during model checkpointing to symlink to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json, .yaml
280
+ Specifically for .json, users have the option to either append (default)
281
+ or rewrite by passing a Boolean value for append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .txt, .pkl, .pickle, .npy, .json, .yaml, .csv
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If reading in mmap_mode is not successful, we load the data without
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a URL: look for http(s)://, ignoring case.
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, return its size in MB.
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
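A minimal usage sketch for the i/o helpers above (hedged: it assumes this file is importable as bliva.common.utils and that the Google Drive link from the docstring example is reachable; the local paths are illustrative only):

# Sketch only: module path and output locations are assumptions.
from bliva.common.utils import (
    to_google_drive_download_url,
    download_google_drive_url,
    save_file,
    load_file,
)

view_url = "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view"
direct_url = to_google_drive_download_url(view_url)
download_google_drive_url(direct_url, "cache", "example.bin")

# save_file/load_file dispatch on the file extension.
save_file({"step": 1, "loss": 0.5}, "cache/run.json")  # appends one JSON line
record = load_file("cache/run.json")                   # parses it back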
bliva/configs/default.yaml ADDED
@@ -0,0 +1,7 @@
1
+
2
+ env:
3
+ # For default users
4
+ # cache_root: "cache"
5
+ # For internal use with persistent storage
6
+ cache_root: "~/.cache/bliva"
7
+
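Since cache_root starts with "~", whatever consumes this config must expand it. A minimal sketch of reading the file (hedged: the real consumer is the config loader in bliva/common/config.py; this standalone reader is only illustrative):

import os
import yaml

with open("bliva/configs/default.yaml") as f:
    cfg = yaml.safe_load(f)

# Expand "~/.cache/bliva" to an absolute per-user path.
cache_root = os.path.expanduser(cfg["env"]["cache_root"])
os.makedirs(cache_root, exist_ok=True)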
bliva/configs/models/bliva_flant5xxl.yaml ADDED
@@ -0,0 +1,39 @@
1
+
2
+ model:
3
+ arch: flant5xxl
4
+ load_finetuned: True
5
+ load_pretrained: False
6
+
7
+ pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_flanxxl_trimmed.pth"
8
+ finetuned: ''
9
+
10
+ # vit encoder
11
+ image_size: 224
12
+ drop_path_rate: 0
13
+ use_grad_checkpoint: False
14
+ vit_precision: "fp16"
15
+ freeze_vit: True
16
+
17
+ # Q-Former
18
+ num_query_token: 32
19
+
20
+ # T5
21
+ t5_model: "google/flan-t5-xxl"
22
+
23
+ # generation configs
24
+ prompt: ""
25
+
26
+
27
+ preprocess:
28
+ vis_processor:
29
+ train:
30
+ name: "blip_image_train"
31
+ image_size: 224
32
+ eval:
33
+ name: "blip_image_eval"
34
+ image_size: 224
35
+ text_processor:
36
+ train:
37
+ name: "blip_caption"
38
+ eval:
39
+ name: "blip_caption"
bliva/configs/models/bliva_vicuna7b.yaml ADDED
@@ -0,0 +1,39 @@
1
+
2
+ model:
3
+ arch: vicuna7b
4
+ load_finetuned: True
5
+ load_pretrained: False
6
+
7
+ pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth"
8
+ finetuned: 'bliva_vicuna7b.pth'
9
+
10
+ # vit encoder
11
+ image_size: 224 #336
12
+ drop_path_rate: 0
13
+ use_grad_checkpoint: False
14
+ vit_precision: "fp16"
15
+ freeze_vit: True
16
+
17
+ # Q-Former
18
+ num_query_token: 32
19
+
20
+ # path to Vicuna checkpoint
21
+ llm_model: "hf_vicuna_7b"
22
+
23
+ # generation configs
24
+ prompt: ""
25
+
26
+
27
+ preprocess:
28
+ vis_processor:
29
+ train:
30
+ name: "blip2_image_train"
31
+ image_size: 224
32
+ eval:
33
+ name: "blip_image_eval"
34
+ image_size: 224
35
+ text_processor:
36
+ train:
37
+ name: "blip_caption"
38
+ eval:
39
+ name: "blip_caption"
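A hedged sketch of how the fields above fit together (the real entry point is the model registry under bliva/models; the standalone reader below is illustrative only). With load_finetuned: True, the checkpoint named in `finetuned` (bliva_vicuna7b.pth) is the one that gets loaded, and `llm_model` points at the local Vicuna-7B directory:

import yaml

with open("bliva/configs/models/bliva_vicuna7b.yaml") as f:
    model_cfg = yaml.safe_load(f)["model"]

ckpt = model_cfg["finetuned"] if model_cfg["load_finetuned"] else model_cfg["pretrained"]
print("arch:", model_cfg["arch"])                     # vicuna7b
print("query tokens:", model_cfg["num_query_token"])  # 32
print("LLM dir:", model_cfg["llm_model"])             # hf_vicuna_7b
print("weights:", ckpt)                               # bliva_vicuna7b.pth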
bliva/conversation/__init__.py ADDED
File without changes
bliva/conversation/conversation.py ADDED
@@ -0,0 +1,180 @@
1
+ import argparse
2
+ import time
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
7
+ from transformers import StoppingCriteria, StoppingCriteriaList
8
+
9
+ import dataclasses
10
+ from enum import auto, Enum
11
+ from typing import List, Tuple, Any
12
+
13
+
14
+
15
+ class SeparatorStyle(Enum):
16
+ """Different separator style."""
17
+ SINGLE = auto()
18
+ TWO = auto()
19
+ THREE = auto()
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class Conversation:
24
+ """A class that keeps all conversation history."""
25
+ system: str
26
+ roles: List[str]
27
+ messages: List[List[str]]
28
+ offset: int
29
+ # system_img: List[Image.Image] = []
30
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
31
+ sep: str = "###"
32
+ sep2: str = None
33
+
34
+ skip_next: bool = False
35
+ conv_id: Any = None
36
+
37
+ def get_prompt(self):
38
+ if self.sep_style == SeparatorStyle.SINGLE:
39
+ ret = self.system + self.sep
40
+ for role, message in self.messages:
41
+ if message:
42
+ ret += role + ": " + message + self.sep
43
+ else:
44
+ ret += role + ":"
45
+ return ret
46
+ elif self.sep_style == SeparatorStyle.TWO:
47
+ seps = [self.sep, self.sep2]
48
+ ret = self.system + seps[0]
49
+ for i, (role, message) in enumerate(self.messages):
50
+ if message:
51
+ ret += role + ": " + message + seps[i % 2]
52
+ else:
53
+ ret += role + ":"
54
+ return ret
55
+ elif self.sep_style == SeparatorStyle.THREE:
56
+ ret = self.system
57
+ for i, (role, message) in enumerate(self.messages):
58
+ if message:
59
+ if type(message) == list:
60
+ message = message[0]
61
+ ret += role + ": " + message
62
+ else:
63
+ ret += role + ":"
64
+ return ret
65
+ else:
66
+ raise ValueError(f"Invalid style: {self.sep_style}")
67
+
68
+ def append_message(self, role, message):
69
+ self.messages.append([role, message])
70
+
71
+ def to_gradio_chatbot(self):
72
+ print('to_gradio_chatbot')
73
+ print(self.messages)
74
+ ret = []
75
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
76
+ if i % 2 == 0:
77
+ ret.append([msg, None])
78
+ else:
79
+ ret[-1][-1] = msg
80
+ return ret
81
+
82
+ def copy(self):
83
+ return Conversation(
84
+ system=self.system,
85
+ # system_img=self.system_img,
86
+ roles=self.roles,
87
+ messages=[[x, y] for x, y in self.messages],
88
+ offset=self.offset,
89
+ sep_style=self.sep_style,
90
+ sep=self.sep,
91
+ sep2=self.sep2,
92
+ conv_id=self.conv_id)
93
+
94
+ def dict(self):
95
+ return {
96
+ "system": self.system,
97
+ # "system_img": self.system_img,
98
+ "roles": self.roles,
99
+ "messages": self.messages,
100
+ "offset": self.offset,
101
+ "sep": self.sep,
102
+ "sep2": self.sep2,
103
+ "conv_id": self.conv_id,
104
+ }
105
+
106
+
107
+ class StoppingCriteriaSub(StoppingCriteria):
108
+
109
+ def __init__(self, stops=None, encounters=1):
110
+ super().__init__()
111
+ self.stops = stops if stops is not None else []
112
+
113
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
114
+ for stop in self.stops:
115
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
116
+ return True
117
+
118
+ return False
119
+
120
+
121
+ CONV_VISION = Conversation(
122
+ system="A chat between human who asks question and you give helpful, detailed, and insightful answers to his question.",
123
+ roles=(" Question", " Answer"),
124
+ messages=[],
125
+ offset=2,
126
+ sep_style=SeparatorStyle.THREE,
127
+ sep="###",
128
+ )
129
+
130
+ CONV_DIRECT = Conversation(
131
+ system="",
132
+ roles=("", ""),
133
+ messages=[],
134
+ offset=2,
135
+ sep_style=SeparatorStyle.THREE,
136
+ sep="###",
137
+ )
138
+
139
+ class Chat:
140
+ def __init__(self, model, vis_processor, device='cuda:0'):
141
+ self.device = device
142
+ self.model = model
143
+ self.vis_processor = vis_processor
144
+
145
+ def ask(self, text, conv):
146
+ #conv.messages = [] #hack not keeping history.
147
+ conv.append_message(conv.roles[0], text)
148
+
149
+ def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
150
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
151
+ conv.append_message(conv.roles[1], None)
152
+
153
+ question = conv.get_prompt()
154
+ image = img_list[0] #torch.stack(img_list).to(self.device)
155
+ output_text = self.model.generate({"image": image, "prompt": question}, num_beams=num_beams, temperature=temperature)
156
+
157
+ conv.messages[-1][1] = output_text
158
+ return output_text, ''
159
+
160
+ def upload_img(self, image, conv, img_list):
161
+ if isinstance(image, str): # is an image path
162
+ raw_image = Image.open(image).convert('RGB')
163
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
164
+ elif isinstance(image, Image.Image):
165
+ raw_image = image
166
+ raw_image = raw_image.convert('RGB')
167
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
168
+ elif isinstance(image, torch.Tensor):
169
+ if len(image.shape) == 3:
170
+ image = image.unsqueeze(0)
171
+ image = image.to(self.device)
172
+
173
+ #image_emb, _ = self.model.encode_img(image)
174
+ img_list.append(image)
175
+ #conv.append_message(conv.roles[0], "")
176
+ msg = "Received."
177
+ # self.conv.append_message(self.conv.roles[1], msg)
178
+ return msg
179
+
180
+
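A runnable sketch of the prompt format these objects produce (hedged: it assumes the repo root is on PYTHONPATH so the module imports as bliva.conversation.conversation):

from bliva.conversation.conversation import CONV_VISION

conv = CONV_VISION.copy()
conv.append_message(conv.roles[0], "What is shown in this image?")
conv.append_message(conv.roles[1], None)  # empty slot the model will fill
# SeparatorStyle.THREE concatenates the system text, then "<role>: <message>"
# for filled messages and "<role>:" for the pending one:
print(conv.get_prompt())
# -> "<system text> Question: What is shown in this image? Answer:"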
bliva/models/Qformer.py ADDED
@@ -0,0 +1,1216 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class BertEmbeddings(nn.Module):
52
+ """Construct the embeddings from word and position embeddings."""
53
+
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.word_embeddings = nn.Embedding(
57
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
58
+ )
59
+ self.position_embeddings = nn.Embedding(
60
+ config.max_position_embeddings, config.hidden_size
61
+ )
62
+
63
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
64
+ # any TensorFlow checkpoint file
65
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
66
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
67
+
68
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
69
+ self.register_buffer(
70
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
71
+ )
72
+ self.position_embedding_type = getattr(
73
+ config, "position_embedding_type", "absolute"
74
+ )
75
+
76
+ self.config = config
77
+
78
+ def forward(
79
+ self,
80
+ input_ids=None,
81
+ position_ids=None,
82
+ query_embeds=None,
83
+ past_key_values_length=0,
84
+ ):
85
+ if input_ids is not None:
86
+ seq_length = input_ids.size()[1]
87
+ else:
88
+ seq_length = 0
89
+
90
+ if position_ids is None:
91
+ position_ids = self.position_ids[
92
+ :, past_key_values_length : seq_length + past_key_values_length
93
+ ].clone()
94
+
95
+ if input_ids is not None:
96
+ embeddings = self.word_embeddings(input_ids)
97
+ if self.position_embedding_type == "absolute":
98
+ position_embeddings = self.position_embeddings(position_ids)
99
+ embeddings = embeddings + position_embeddings
100
+
101
+ if query_embeds is not None:
102
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
103
+ else:
104
+ embeddings = query_embeds
105
+
106
+ embeddings = self.LayerNorm(embeddings)
107
+ embeddings = self.dropout(embeddings)
108
+ return embeddings
109
+
110
+
111
+ class BertSelfAttention(nn.Module):
112
+ def __init__(self, config, is_cross_attention):
113
+ super().__init__()
114
+ self.config = config
115
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
116
+ config, "embedding_size"
117
+ ):
118
+ raise ValueError(
119
+ "The hidden size (%d) is not a multiple of the number of attention "
120
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
121
+ )
122
+
123
+ self.num_attention_heads = config.num_attention_heads
124
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
125
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
126
+
127
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
128
+ if is_cross_attention:
129
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
130
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
131
+ else:
132
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
133
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
134
+
135
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
136
+ self.position_embedding_type = getattr(
137
+ config, "position_embedding_type", "absolute"
138
+ )
139
+ if (
140
+ self.position_embedding_type == "relative_key"
141
+ or self.position_embedding_type == "relative_key_query"
142
+ ):
143
+ self.max_position_embeddings = config.max_position_embeddings
144
+ self.distance_embedding = nn.Embedding(
145
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
146
+ )
147
+ self.save_attention = False
148
+
149
+ def save_attn_gradients(self, attn_gradients):
150
+ self.attn_gradients = attn_gradients
151
+
152
+ def get_attn_gradients(self):
153
+ return self.attn_gradients
154
+
155
+ def save_attention_map(self, attention_map):
156
+ self.attention_map = attention_map
157
+
158
+ def get_attention_map(self):
159
+ return self.attention_map
160
+
161
+ def transpose_for_scores(self, x):
162
+ new_x_shape = x.size()[:-1] + (
163
+ self.num_attention_heads,
164
+ self.attention_head_size,
165
+ )
166
+ x = x.view(*new_x_shape)
167
+ return x.permute(0, 2, 1, 3)
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states,
172
+ attention_mask=None,
173
+ head_mask=None,
174
+ encoder_hidden_states=None,
175
+ encoder_attention_mask=None,
176
+ past_key_value=None,
177
+ output_attentions=False,
178
+ ):
179
+
180
+ # If this is instantiated as a cross-attention module, the keys
181
+ # and values come from an encoder; the attention mask needs to be
182
+ # such that the encoder's padding tokens are not attended to.
183
+ is_cross_attention = encoder_hidden_states is not None
184
+
185
+ if is_cross_attention:
186
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
187
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
188
+ attention_mask = encoder_attention_mask
189
+ elif past_key_value is not None:
190
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
191
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
192
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
193
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
194
+ else:
195
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
196
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
197
+
198
+ mixed_query_layer = self.query(hidden_states)
199
+
200
+ query_layer = self.transpose_for_scores(mixed_query_layer)
201
+
202
+ past_key_value = (key_layer, value_layer)
203
+
204
+ # Take the dot product between "query" and "key" to get the raw attention scores.
205
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
206
+
207
+ if (
208
+ self.position_embedding_type == "relative_key"
209
+ or self.position_embedding_type == "relative_key_query"
210
+ ):
211
+ seq_length = hidden_states.size()[1]
212
+ position_ids_l = torch.arange(
213
+ seq_length, dtype=torch.long, device=hidden_states.device
214
+ ).view(-1, 1)
215
+ position_ids_r = torch.arange(
216
+ seq_length, dtype=torch.long, device=hidden_states.device
217
+ ).view(1, -1)
218
+ distance = position_ids_l - position_ids_r
219
+ positional_embedding = self.distance_embedding(
220
+ distance + self.max_position_embeddings - 1
221
+ )
222
+ positional_embedding = positional_embedding.to(
223
+ dtype=query_layer.dtype
224
+ ) # fp16 compatibility
225
+
226
+ if self.position_embedding_type == "relative_key":
227
+ relative_position_scores = torch.einsum(
228
+ "bhld,lrd->bhlr", query_layer, positional_embedding
229
+ )
230
+ attention_scores = attention_scores + relative_position_scores
231
+ elif self.position_embedding_type == "relative_key_query":
232
+ relative_position_scores_query = torch.einsum(
233
+ "bhld,lrd->bhlr", query_layer, positional_embedding
234
+ )
235
+ relative_position_scores_key = torch.einsum(
236
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
237
+ )
238
+ attention_scores = (
239
+ attention_scores
240
+ + relative_position_scores_query
241
+ + relative_position_scores_key
242
+ )
243
+
244
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
245
+ if attention_mask is not None:
246
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
247
+ attention_scores = attention_scores + attention_mask
248
+
249
+ # Normalize the attention scores to probabilities.
250
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
251
+
252
+ if is_cross_attention and self.save_attention:
253
+ self.save_attention_map(attention_probs)
254
+ attention_probs.register_hook(self.save_attn_gradients)
255
+
256
+ # This is actually dropping out entire tokens to attend to, which might
257
+ # seem a bit unusual, but is taken from the original Transformer paper.
258
+ attention_probs_dropped = self.dropout(attention_probs)
259
+
260
+ # Mask heads if we want to
261
+ if head_mask is not None:
262
+ attention_probs_dropped = attention_probs_dropped * head_mask
263
+
264
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
265
+
266
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
267
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
268
+ context_layer = context_layer.view(*new_context_layer_shape)
269
+
270
+ outputs = (
271
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
272
+ )
273
+
274
+ outputs = outputs + (past_key_value,)
275
+ return outputs
276
+
277
+
278
+ class BertSelfOutput(nn.Module):
279
+ def __init__(self, config):
280
+ super().__init__()
281
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
282
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
283
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
284
+
285
+ def forward(self, hidden_states, input_tensor):
286
+ hidden_states = self.dense(hidden_states)
287
+ hidden_states = self.dropout(hidden_states)
288
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
289
+ return hidden_states
290
+
291
+
292
+ class BertAttention(nn.Module):
293
+ def __init__(self, config, is_cross_attention=False):
294
+ super().__init__()
295
+ self.self = BertSelfAttention(config, is_cross_attention)
296
+ self.output = BertSelfOutput(config)
297
+ self.pruned_heads = set()
298
+
299
+ def prune_heads(self, heads):
300
+ if len(heads) == 0:
301
+ return
302
+ heads, index = find_pruneable_heads_and_indices(
303
+ heads,
304
+ self.self.num_attention_heads,
305
+ self.self.attention_head_size,
306
+ self.pruned_heads,
307
+ )
308
+
309
+ # Prune linear layers
310
+ self.self.query = prune_linear_layer(self.self.query, index)
311
+ self.self.key = prune_linear_layer(self.self.key, index)
312
+ self.self.value = prune_linear_layer(self.self.value, index)
313
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
314
+
315
+ # Update hyper params and store pruned heads
316
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
317
+ self.self.all_head_size = (
318
+ self.self.attention_head_size * self.self.num_attention_heads
319
+ )
320
+ self.pruned_heads = self.pruned_heads.union(heads)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states,
325
+ attention_mask=None,
326
+ head_mask=None,
327
+ encoder_hidden_states=None,
328
+ encoder_attention_mask=None,
329
+ past_key_value=None,
330
+ output_attentions=False,
331
+ ):
332
+ self_outputs = self.self(
333
+ hidden_states,
334
+ attention_mask,
335
+ head_mask,
336
+ encoder_hidden_states,
337
+ encoder_attention_mask,
338
+ past_key_value,
339
+ output_attentions,
340
+ )
341
+ attention_output = self.output(self_outputs[0], hidden_states)
342
+
343
+ outputs = (attention_output,) + self_outputs[
344
+ 1:
345
+ ] # add attentions if we output them
346
+ return outputs
347
+
348
+
349
+ class BertIntermediate(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
353
+ if isinstance(config.hidden_act, str):
354
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
355
+ else:
356
+ self.intermediate_act_fn = config.hidden_act
357
+
358
+ def forward(self, hidden_states):
359
+ hidden_states = self.dense(hidden_states)
360
+ hidden_states = self.intermediate_act_fn(hidden_states)
361
+ return hidden_states
362
+
363
+
364
+ class BertOutput(nn.Module):
365
+ def __init__(self, config):
366
+ super().__init__()
367
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
368
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
369
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
370
+
371
+ def forward(self, hidden_states, input_tensor):
372
+ hidden_states = self.dense(hidden_states)
373
+ hidden_states = self.dropout(hidden_states)
374
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
375
+ return hidden_states
376
+
377
+
378
+ class BertLayer(nn.Module):
379
+ def __init__(self, config, layer_num):
380
+ super().__init__()
381
+ self.config = config
382
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
383
+ self.seq_len_dim = 1
384
+ self.attention = BertAttention(config)
385
+ self.layer_num = layer_num
386
+ if (
387
+ self.config.add_cross_attention
388
+ and layer_num % self.config.cross_attention_freq == 0
389
+ ):
390
+ self.crossattention = BertAttention(
391
+ config, is_cross_attention=self.config.add_cross_attention
392
+ )
393
+ self.has_cross_attention = True
394
+ else:
395
+ self.has_cross_attention = False
396
+ self.intermediate = BertIntermediate(config)
397
+ self.output = BertOutput(config)
398
+
399
+ self.intermediate_query = BertIntermediate(config)
400
+ self.output_query = BertOutput(config)
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states,
405
+ attention_mask=None,
406
+ head_mask=None,
407
+ encoder_hidden_states=None,
408
+ encoder_attention_mask=None,
409
+ past_key_value=None,
410
+ output_attentions=False,
411
+ query_length=0,
412
+ ):
413
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
414
+ self_attn_past_key_value = (
415
+ past_key_value[:2] if past_key_value is not None else None
416
+ )
417
+ self_attention_outputs = self.attention(
418
+ hidden_states,
419
+ attention_mask,
420
+ head_mask,
421
+ output_attentions=output_attentions,
422
+ past_key_value=self_attn_past_key_value,
423
+ )
424
+ attention_output = self_attention_outputs[0]
425
+ outputs = self_attention_outputs[1:-1]
426
+
427
+ present_key_value = self_attention_outputs[-1]
428
+
429
+ if query_length > 0:
430
+ query_attention_output = attention_output[:, :query_length, :]
431
+
432
+ if self.has_cross_attention:
433
+ assert (
434
+ encoder_hidden_states is not None
435
+ ), "encoder_hidden_states must be given for cross-attention layers"
436
+ cross_attention_outputs = self.crossattention(
437
+ query_attention_output,
438
+ attention_mask,
439
+ head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ output_attentions=output_attentions,
443
+ )
444
+ query_attention_output = cross_attention_outputs[0]
445
+ outputs = (
446
+ outputs + cross_attention_outputs[1:-1]
447
+ ) # add cross attentions if we output attention weights
448
+
449
+ layer_output = apply_chunking_to_forward(
450
+ self.feed_forward_chunk_query,
451
+ self.chunk_size_feed_forward,
452
+ self.seq_len_dim,
453
+ query_attention_output,
454
+ )
455
+ if attention_output.shape[1] > query_length:
456
+ layer_output_text = apply_chunking_to_forward(
457
+ self.feed_forward_chunk,
458
+ self.chunk_size_feed_forward,
459
+ self.seq_len_dim,
460
+ attention_output[:, query_length:, :],
461
+ )
462
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
463
+ else:
464
+ layer_output = apply_chunking_to_forward(
465
+ self.feed_forward_chunk,
466
+ self.chunk_size_feed_forward,
467
+ self.seq_len_dim,
468
+ attention_output,
469
+ )
470
+ outputs = (layer_output,) + outputs
471
+
472
+ outputs = outputs + (present_key_value,)
473
+
474
+ return outputs
475
+
476
+ def feed_forward_chunk(self, attention_output):
477
+ intermediate_output = self.intermediate(attention_output)
478
+ layer_output = self.output(intermediate_output, attention_output)
479
+ return layer_output
480
+
481
+ def feed_forward_chunk_query(self, attention_output):
482
+ intermediate_output = self.intermediate_query(attention_output)
483
+ layer_output = self.output_query(intermediate_output, attention_output)
484
+ return layer_output
485
+
486
+
487
+ class BertEncoder(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.config = config
491
+ self.layer = nn.ModuleList(
492
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
493
+ )
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states,
498
+ attention_mask=None,
499
+ head_mask=None,
500
+ encoder_hidden_states=None,
501
+ encoder_attention_mask=None,
502
+ past_key_values=None,
503
+ use_cache=None,
504
+ output_attentions=False,
505
+ output_hidden_states=False,
506
+ return_dict=True,
507
+ query_length=0,
508
+ ):
509
+ all_hidden_states = () if output_hidden_states else None
510
+ all_self_attentions = () if output_attentions else None
511
+ all_cross_attentions = (
512
+ () if output_attentions and self.config.add_cross_attention else None
513
+ )
514
+
515
+ next_decoder_cache = () if use_cache else None
516
+
517
+ for i in range(self.config.num_hidden_layers):
518
+ layer_module = self.layer[i]
519
+ if output_hidden_states:
520
+ all_hidden_states = all_hidden_states + (hidden_states,)
521
+
522
+ layer_head_mask = head_mask[i] if head_mask is not None else None
523
+ past_key_value = past_key_values[i] if past_key_values is not None else None
524
+
525
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
526
+
527
+ if use_cache:
528
+ logger.warning(
529
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
530
+ )
531
+ use_cache = False
532
+
533
+ def create_custom_forward(module):
534
+ def custom_forward(*inputs):
535
+ return module(
536
+ *inputs, past_key_value, output_attentions, query_length
537
+ )
538
+
539
+ return custom_forward
540
+
541
+ layer_outputs = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(layer_module),
543
+ hidden_states,
544
+ attention_mask,
545
+ layer_head_mask,
546
+ encoder_hidden_states,
547
+ encoder_attention_mask,
548
+ )
549
+ else:
550
+ layer_outputs = layer_module(
551
+ hidden_states,
552
+ attention_mask,
553
+ layer_head_mask,
554
+ encoder_hidden_states,
555
+ encoder_attention_mask,
556
+ past_key_value,
557
+ output_attentions,
558
+ query_length,
559
+ )
560
+
561
+ hidden_states = layer_outputs[0]
562
+ if use_cache:
563
+ next_decoder_cache += (layer_outputs[-1],)
564
+ if output_attentions:
565
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
566
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
567
+
568
+ if output_hidden_states:
569
+ all_hidden_states = all_hidden_states + (hidden_states,)
570
+
571
+ if not return_dict:
572
+ return tuple(
573
+ v
574
+ for v in [
575
+ hidden_states,
576
+ next_decoder_cache,
577
+ all_hidden_states,
578
+ all_self_attentions,
579
+ all_cross_attentions,
580
+ ]
581
+ if v is not None
582
+ )
583
+ return BaseModelOutputWithPastAndCrossAttentions(
584
+ last_hidden_state=hidden_states,
585
+ past_key_values=next_decoder_cache,
586
+ hidden_states=all_hidden_states,
587
+ attentions=all_self_attentions,
588
+ cross_attentions=all_cross_attentions,
589
+ )
590
+
591
+
592
+ class BertPooler(nn.Module):
593
+ def __init__(self, config):
594
+ super().__init__()
595
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
596
+ self.activation = nn.Tanh()
597
+
598
+ def forward(self, hidden_states):
599
+ # We "pool" the model by simply taking the hidden state corresponding
600
+ # to the first token.
601
+ first_token_tensor = hidden_states[:, 0]
602
+ pooled_output = self.dense(first_token_tensor)
603
+ pooled_output = self.activation(pooled_output)
604
+ return pooled_output
605
+
606
+
607
+ class BertPredictionHeadTransform(nn.Module):
608
+ def __init__(self, config):
609
+ super().__init__()
610
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
611
+ if isinstance(config.hidden_act, str):
612
+ self.transform_act_fn = ACT2FN[config.hidden_act]
613
+ else:
614
+ self.transform_act_fn = config.hidden_act
615
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
616
+
617
+ def forward(self, hidden_states):
618
+ hidden_states = self.dense(hidden_states)
619
+ hidden_states = self.transform_act_fn(hidden_states)
620
+ hidden_states = self.LayerNorm(hidden_states)
621
+ return hidden_states
622
+
623
+
624
+ class BertLMPredictionHead(nn.Module):
625
+ def __init__(self, config):
626
+ super().__init__()
627
+ self.transform = BertPredictionHeadTransform(config)
628
+
629
+ # The output weights are the same as the input embeddings, but there is
630
+ # an output-only bias for each token.
631
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
632
+
633
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
634
+
635
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
636
+ self.decoder.bias = self.bias
637
+
638
+ def forward(self, hidden_states):
639
+ hidden_states = self.transform(hidden_states)
640
+ hidden_states = self.decoder(hidden_states)
641
+ return hidden_states
642
+
643
+
644
+ class BertOnlyMLMHead(nn.Module):
645
+ def __init__(self, config):
646
+ super().__init__()
647
+ self.predictions = BertLMPredictionHead(config)
648
+
649
+ def forward(self, sequence_output):
650
+ prediction_scores = self.predictions(sequence_output)
651
+ return prediction_scores
652
+
653
+
654
+ class BertPreTrainedModel(PreTrainedModel):
655
+ """
656
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
657
+ models.
658
+ """
659
+
660
+ config_class = BertConfig
661
+ base_model_prefix = "bert"
662
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
663
+
664
+ def _init_weights(self, module):
665
+ """Initialize the weights"""
666
+ if isinstance(module, (nn.Linear, nn.Embedding)):
667
+ # Slightly different from the TF version which uses truncated_normal for initialization
668
+ # cf https://github.com/pytorch/pytorch/pull/5617
669
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
670
+ elif isinstance(module, nn.LayerNorm):
671
+ module.bias.data.zero_()
672
+ module.weight.data.fill_(1.0)
673
+ if isinstance(module, nn.Linear) and module.bias is not None:
674
+ module.bias.data.zero_()
675
+
676
+
677
+ class BertModel(BertPreTrainedModel):
678
+ """
679
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
680
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
681
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
682
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
683
+ To behave as a decoder the model has to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
684
+ input to the forward pass.
685
+ """
686
+
687
+ def __init__(self, config, add_pooling_layer=False):
688
+ super().__init__(config)
689
+ self.config = config
690
+
691
+ self.embeddings = BertEmbeddings(config)
692
+
693
+ self.encoder = BertEncoder(config)
694
+
695
+ self.pooler = BertPooler(config) if add_pooling_layer else None
696
+
697
+ self.init_weights()
698
+
699
+ def get_input_embeddings(self):
700
+ return self.embeddings.word_embeddings
701
+
702
+ def set_input_embeddings(self, value):
703
+ self.embeddings.word_embeddings = value
704
+
705
+ def _prune_heads(self, heads_to_prune):
706
+ """
707
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
708
+ class PreTrainedModel
709
+ """
710
+ for layer, heads in heads_to_prune.items():
711
+ self.encoder.layer[layer].attention.prune_heads(heads)
712
+
713
+ def get_extended_attention_mask(
714
+ self,
715
+ attention_mask: Tensor,
716
+ input_shape: Tuple[int],
717
+ device: device,
718
+ is_decoder: bool,
719
+ has_query: bool = False,
720
+ ) -> Tensor:
721
+ """
722
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
723
+
724
+ Arguments:
725
+ attention_mask (:obj:`torch.Tensor`):
726
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
727
+ input_shape (:obj:`Tuple[int]`):
728
+ The shape of the input to the model.
729
+ device: (:obj:`torch.device`):
730
+ The device of the input to the model.
731
+
732
+ Returns:
733
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
734
+ """
735
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
736
+ # ourselves in which case we just need to make it broadcastable to all heads.
737
+ if attention_mask.dim() == 3:
738
+ extended_attention_mask = attention_mask[:, None, :, :]
739
+ elif attention_mask.dim() == 2:
740
+ # Provided a padding mask of dimensions [batch_size, seq_length]
741
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
742
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
743
+ if is_decoder:
744
+ batch_size, seq_length = input_shape
745
+
746
+ seq_ids = torch.arange(seq_length, device=device)
747
+ causal_mask = (
748
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
749
+ <= seq_ids[None, :, None]
750
+ )
751
+
752
+ # add a prefix ones mask to the causal mask
753
+ # causal and attention masks must have same type with pytorch version < 1.3
754
+ causal_mask = causal_mask.to(attention_mask.dtype)
755
+
756
+ if causal_mask.shape[1] < attention_mask.shape[1]:
757
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
758
+ if has_query: # UniLM style attention mask
759
+ causal_mask = torch.cat(
760
+ [
761
+ torch.zeros(
762
+ (batch_size, prefix_seq_len, seq_length),
763
+ device=device,
764
+ dtype=causal_mask.dtype,
765
+ ),
766
+ causal_mask,
767
+ ],
768
+ axis=1,
769
+ )
770
+ causal_mask = torch.cat(
771
+ [
772
+ torch.ones(
773
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
774
+ device=device,
775
+ dtype=causal_mask.dtype,
776
+ ),
777
+ causal_mask,
778
+ ],
779
+ axis=-1,
780
+ )
781
+ extended_attention_mask = (
782
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
783
+ )
784
+ else:
785
+ extended_attention_mask = attention_mask[:, None, None, :]
786
+ else:
787
+ raise ValueError(
788
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
789
+ input_shape, attention_mask.shape
790
+ )
791
+ )
792
+
793
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
794
+ # masked positions, this operation will create a tensor which is 0.0 for
795
+ # positions we want to attend and -10000.0 for masked positions.
796
+ # Since we are adding it to the raw scores before the softmax, this is
797
+ # effectively the same as removing these entirely.
798
+ extended_attention_mask = extended_attention_mask.to(
799
+ dtype=self.dtype
800
+ ) # fp16 compatibility
801
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
802
+ return extended_attention_mask
803
+
804
+ def forward(
805
+ self,
806
+ input_ids=None,
807
+ attention_mask=None,
808
+ position_ids=None,
809
+ head_mask=None,
810
+ query_embeds=None,
811
+ encoder_hidden_states=None,
812
+ encoder_attention_mask=None,
813
+ past_key_values=None,
814
+ use_cache=None,
815
+ output_attentions=None,
816
+ output_hidden_states=None,
817
+ return_dict=None,
818
+ is_decoder=False,
819
+ ):
820
+ r"""
821
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
822
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
823
+ the model is configured as a decoder.
824
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
825
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
826
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
827
+ - 1 for tokens that are **not masked**,
828
+ - 0 for tokens that are **masked**.
829
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
830
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
831
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
832
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
833
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
834
+ use_cache (:obj:`bool`, `optional`):
835
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
836
+ decoding (see :obj:`past_key_values`).
837
+ """
838
+ output_attentions = (
839
+ output_attentions
840
+ if output_attentions is not None
841
+ else self.config.output_attentions
842
+ )
843
+ output_hidden_states = (
844
+ output_hidden_states
845
+ if output_hidden_states is not None
846
+ else self.config.output_hidden_states
847
+ )
848
+ return_dict = (
849
+ return_dict if return_dict is not None else self.config.use_return_dict
850
+ )
851
+
852
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
853
+
854
+ if input_ids is None:
855
+ assert (
856
+ query_embeds is not None
857
+ ), "You have to specify query_embeds when input_ids is None"
858
+
859
+ # past_key_values_length
860
+ past_key_values_length = (
861
+ past_key_values[0][0].shape[2] - self.config.query_length
862
+ if past_key_values is not None
863
+ else 0
864
+ )
865
+
866
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
867
+
868
+ embedding_output = self.embeddings(
869
+ input_ids=input_ids,
870
+ position_ids=position_ids,
871
+ query_embeds=query_embeds,
872
+ past_key_values_length=past_key_values_length,
873
+ )
874
+
875
+ input_shape = embedding_output.size()[:-1]
876
+ batch_size, seq_length = input_shape
877
+ device = embedding_output.device
878
+
879
+ if attention_mask is None:
880
+ attention_mask = torch.ones(
881
+ ((batch_size, seq_length + past_key_values_length)), device=device
882
+ )
883
+
884
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
885
+ # ourselves in which case we just need to make it broadcastable to all heads.
886
+ if is_decoder:
887
+ extended_attention_mask = self.get_extended_attention_mask(
888
+ attention_mask,
889
+ input_ids.shape,
890
+ device,
891
+ is_decoder,
892
+ has_query=(query_embeds is not None),
893
+ )
894
+ else:
895
+ extended_attention_mask = self.get_extended_attention_mask(
896
+ attention_mask, input_shape, device, is_decoder
897
+ )
898
+
899
+ # If a 2D or 3D attention mask is provided for the cross-attention
900
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
901
+ if encoder_hidden_states is not None:
902
+ if type(encoder_hidden_states) == list:
903
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
904
+ 0
905
+ ].size()
906
+ else:
907
+ (
908
+ encoder_batch_size,
909
+ encoder_sequence_length,
910
+ _,
911
+ ) = encoder_hidden_states.size()
912
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
913
+
914
+ if type(encoder_attention_mask) == list:
915
+ encoder_extended_attention_mask = [
916
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
917
+ ]
918
+ elif encoder_attention_mask is None:
919
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
920
+ encoder_extended_attention_mask = self.invert_attention_mask(
921
+ encoder_attention_mask
922
+ )
923
+ else:
924
+ encoder_extended_attention_mask = self.invert_attention_mask(
925
+ encoder_attention_mask
926
+ )
927
+ else:
928
+ encoder_extended_attention_mask = None
929
+
930
+ # Prepare head mask if needed
931
+ # 1.0 in head_mask indicate we keep the head
932
+ # attention_probs has shape bsz x n_heads x N x N
933
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
934
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
935
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
936
+
937
+ encoder_outputs = self.encoder(
938
+ embedding_output,
939
+ attention_mask=extended_attention_mask,
940
+ head_mask=head_mask,
941
+ encoder_hidden_states=encoder_hidden_states,
942
+ encoder_attention_mask=encoder_extended_attention_mask,
943
+ past_key_values=past_key_values,
944
+ use_cache=use_cache,
945
+ output_attentions=output_attentions,
946
+ output_hidden_states=output_hidden_states,
947
+ return_dict=return_dict,
948
+ query_length=query_length,
949
+ )
950
+ sequence_output = encoder_outputs[0]
951
+ pooled_output = (
952
+ self.pooler(sequence_output) if self.pooler is not None else None
953
+ )
954
+
955
+ if not return_dict:
956
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
957
+
958
+ return BaseModelOutputWithPoolingAndCrossAttentions(
959
+ last_hidden_state=sequence_output,
960
+ pooler_output=pooled_output,
961
+ past_key_values=encoder_outputs.past_key_values,
962
+ hidden_states=encoder_outputs.hidden_states,
963
+ attentions=encoder_outputs.attentions,
964
+ cross_attentions=encoder_outputs.cross_attentions,
965
+ )
966
+
967
+
968
+ class BertLMHeadModel(BertPreTrainedModel):
969
+
970
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
971
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
972
+
973
+ def __init__(self, config):
974
+ super().__init__(config)
975
+
976
+ self.bert = BertModel(config, add_pooling_layer=False)
977
+ self.cls = BertOnlyMLMHead(config)
978
+
979
+ self.init_weights()
980
+
981
+ def get_output_embeddings(self):
982
+ return self.cls.predictions.decoder
983
+
984
+ def set_output_embeddings(self, new_embeddings):
985
+ self.cls.predictions.decoder = new_embeddings
986
+
987
+ def forward(
988
+ self,
989
+ input_ids=None,
990
+ attention_mask=None,
991
+ position_ids=None,
992
+ head_mask=None,
993
+ query_embeds=None,
994
+ encoder_hidden_states=None,
995
+ encoder_attention_mask=None,
996
+ labels=None,
997
+ past_key_values=None,
998
+ use_cache=True,
999
+ output_attentions=None,
1000
+ output_hidden_states=None,
1001
+ return_dict=None,
1002
+ return_logits=False,
1003
+ is_decoder=True,
1004
+ reduction="mean",
1005
+ ):
1006
+ r"""
1007
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1008
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1009
+ the model is configured as a decoder.
1010
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1011
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1012
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1013
+ - 1 for tokens that are **not masked**,
1014
+ - 0 for tokens that are **masked**.
1015
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1016
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1017
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1018
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
1019
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1020
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1021
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1022
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1023
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1024
+ use_cache (:obj:`bool`, `optional`):
1025
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1026
+ decoding (see :obj:`past_key_values`).
1027
+ Returns:
1028
+ Example::
1029
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1030
+ >>> import torch
1031
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1032
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1033
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1034
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1035
+ >>> outputs = model(**inputs)
1036
+ >>> prediction_logits = outputs.logits
1037
+ """
1038
+ return_dict = (
1039
+ return_dict if return_dict is not None else self.config.use_return_dict
1040
+ )
1041
+ if labels is not None:
1042
+ use_cache = False
1043
+ if past_key_values is not None:
1044
+ query_embeds = None
1045
+
1046
+ outputs = self.bert(
1047
+ input_ids,
1048
+ attention_mask=attention_mask,
1049
+ position_ids=position_ids,
1050
+ head_mask=head_mask,
1051
+ query_embeds=query_embeds,
1052
+ encoder_hidden_states=encoder_hidden_states,
1053
+ encoder_attention_mask=encoder_attention_mask,
1054
+ past_key_values=past_key_values,
1055
+ use_cache=use_cache,
1056
+ output_attentions=output_attentions,
1057
+ output_hidden_states=output_hidden_states,
1058
+ return_dict=return_dict,
1059
+ is_decoder=is_decoder,
1060
+ )
1061
+
1062
+ sequence_output = outputs[0]
1063
+ if query_embeds is not None:
1064
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1065
+
1066
+ prediction_scores = self.cls(sequence_output)
1067
+
1068
+ if return_logits:
1069
+ return prediction_scores[:, :-1, :].contiguous()
1070
+
1071
+ lm_loss = None
1072
+ if labels is not None:
1073
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1074
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1075
+ labels = labels[:, 1:].contiguous()
1076
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1077
+ lm_loss = loss_fct(
1078
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1079
+ labels.view(-1),
1080
+ )
1081
+ if reduction == "none":
1082
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1083
+
1084
+ if not return_dict:
1085
+ output = (prediction_scores,) + outputs[2:]
1086
+ return ((lm_loss,) + output) if lm_loss is not None else output
1087
+
1088
+ return CausalLMOutputWithCrossAttentions(
1089
+ loss=lm_loss,
1090
+ logits=prediction_scores,
1091
+ past_key_values=outputs.past_key_values,
1092
+ hidden_states=outputs.hidden_states,
1093
+ attentions=outputs.attentions,
1094
+ cross_attentions=outputs.cross_attentions,
1095
+ )
1096
+
1097
+ def prepare_inputs_for_generation(
1098
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1099
+ ):
1100
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1101
+ if attention_mask is None:
1102
+ attention_mask = input_ids.new_ones(input_ids.shape)
1103
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1104
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1105
+
1106
+ # cut decoder_input_ids if past is used
1107
+ if past is not None:
1108
+ input_ids = input_ids[:, -1:]
1109
+
1110
+ return {
1111
+ "input_ids": input_ids,
1112
+ "query_embeds": query_embeds,
1113
+ "attention_mask": attention_mask,
1114
+ "past_key_values": past,
1115
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1116
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1117
+ "is_decoder": True,
1118
+ }
1119
+
1120
+ def _reorder_cache(self, past, beam_idx):
1121
+ reordered_past = ()
1122
+ for layer_past in past:
1123
+ reordered_past += (
1124
+ tuple(
1125
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1126
+ ),
1127
+ )
1128
+ return reordered_past
1129
+
1130
+
1131
+ class BertForMaskedLM(BertPreTrainedModel):
1132
+
1133
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1134
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1135
+
1136
+ def __init__(self, config):
1137
+ super().__init__(config)
1138
+
1139
+ self.bert = BertModel(config, add_pooling_layer=False)
1140
+ self.cls = BertOnlyMLMHead(config)
1141
+
1142
+ self.init_weights()
1143
+
1144
+ def get_output_embeddings(self):
1145
+ return self.cls.predictions.decoder
1146
+
1147
+ def set_output_embeddings(self, new_embeddings):
1148
+ self.cls.predictions.decoder = new_embeddings
1149
+
1150
+ def forward(
1151
+ self,
1152
+ input_ids=None,
1153
+ attention_mask=None,
1154
+ position_ids=None,
1155
+ head_mask=None,
1156
+ query_embeds=None,
1157
+ encoder_hidden_states=None,
1158
+ encoder_attention_mask=None,
1159
+ labels=None,
1160
+ output_attentions=None,
1161
+ output_hidden_states=None,
1162
+ return_dict=None,
1163
+ return_logits=False,
1164
+ is_decoder=False,
1165
+ ):
1166
+ r"""
1167
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1168
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1169
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1170
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1171
+ """
1172
+
1173
+ return_dict = (
1174
+ return_dict if return_dict is not None else self.config.use_return_dict
1175
+ )
1176
+
1177
+ outputs = self.bert(
1178
+ input_ids,
1179
+ attention_mask=attention_mask,
1180
+ position_ids=position_ids,
1181
+ head_mask=head_mask,
1182
+ query_embeds=query_embeds,
1183
+ encoder_hidden_states=encoder_hidden_states,
1184
+ encoder_attention_mask=encoder_attention_mask,
1185
+ output_attentions=output_attentions,
1186
+ output_hidden_states=output_hidden_states,
1187
+ return_dict=return_dict,
1188
+ is_decoder=is_decoder,
1189
+ )
1190
+
1191
+ sequence_output = outputs[0]
192
+ if query_embeds is not None:
193
+ sequence_output = sequence_output[:, query_embeds.shape[1] :, :]
+ prediction_scores = self.cls(sequence_output)
1194
+
1195
+ if return_logits:
1196
+ return prediction_scores
1197
+
1198
+ masked_lm_loss = None
1199
+ if labels is not None:
1200
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1201
+ masked_lm_loss = loss_fct(
1202
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1203
+ )
1204
+
1205
+ if not return_dict:
1206
+ output = (prediction_scores,) + outputs[2:]
1207
+ return (
1208
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1209
+ )
1210
+
1211
+ return MaskedLMOutput(
1212
+ loss=masked_lm_loss,
1213
+ logits=prediction_scores,
1214
+ hidden_states=outputs.hidden_states,
1215
+ attentions=outputs.attentions,
1216
+ )
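Note on the hunk above: the language-model head is trained with a shifted next-token objective: position t's logits score token t+1, so the last logit column and the first label are dropped before the cross-entropy (with 0.1 label smoothing). A minimal self-contained sketch of that loss, with invented shapes, not part of the commit:

    import torch
    from torch.nn import CrossEntropyLoss

    batch, seq_len, vocab = 2, 5, 11
    prediction_scores = torch.randn(batch, seq_len, vocab)  # logits from the LM head
    labels = torch.randint(0, vocab, (batch, seq_len))

    shifted = prediction_scores[:, :-1, :].contiguous()  # drop the last position
    targets = labels[:, 1:].contiguous()                 # drop the first token
    loss_fct = CrossEntropyLoss(reduction="mean", label_smoothing=0.1)
    lm_loss = loss_fct(shifted.view(-1, vocab), targets.view(-1))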
bliva/models/__init__.py ADDED
@@ -0,0 +1,208 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+ from bliva.common.registry import registry
12
+
13
+ from bliva.models.base_model import BaseModel
14
+
15
+ from bliva.models.blip2 import Blip2Base
16
+
17
+ from bliva.models.bliva_flant5xxl import BLIVAFlanT5
18
+ from bliva.models.bliva_vicuna7b import BLIVAVicuna
19
+
20
+ from bliva.models.vit import VisionTransformerEncoder
21
+
22
+
23
+ from bliva.processors.base_processor import BaseProcessor
24
+
25
+
26
+ __all__ = [
27
+ "load_model",
28
+ "BaseModel",
29
+ "Blip2Base",
30
+ "BLIVAFlanT5",
31
+ "BLIVAVicuna",
32
+ ]
33
+
34
+
35
+ def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
36
+ """
37
+ Load supported models.
38
+
39
+ To list all available models and types in registry:
40
+ >>> from bliva.models import model_zoo
41
+ >>> print(model_zoo)
42
+
43
+ Args:
44
+ name (str): name of the model.
45
+ model_type (str): type of the model.
46
+ is_eval (bool): whether the model is in eval mode. Default: False.
47
+ device (str): device to use. Default: "cpu".
48
+ checkpoint (str): path or URL to checkpoint. Default: None.
49
+ Note that the checkpoint is expected to have the same keys in state_dict as the model.
50
+
51
+ Returns:
52
+ model (torch.nn.Module): model.
53
+ """
54
+
55
+ model = registry.get_model_class(name).from_pretrained(model_type=model_type)
56
+
57
+ if checkpoint is not None:
58
+ model.load_checkpoint(checkpoint)
59
+
60
+ if is_eval:
61
+ model.eval()
62
+
63
+ if device == "cpu":
64
+ model = model.float()
65
+
66
+ return model.to(device)
67
+
68
+
69
+ def load_preprocess(config):
70
+ """
71
+ Load preprocessor configs and construct preprocessors.
72
+
73
+ If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
74
+
75
+ Args:
76
+ config (dict): preprocessor configs.
77
+
78
+ Returns:
79
+ vis_processors (dict): preprocessors for visual inputs.
80
+ txt_processors (dict): preprocessors for text inputs.
81
+
82
+ Key is "train" or "eval" for processors used in training and evaluation respectively.
83
+ """
84
+
85
+ def _build_proc_from_cfg(cfg):
86
+ return (
87
+ registry.get_processor_class(cfg.name).from_config(cfg)
88
+ if cfg is not None
89
+ else BaseProcessor()
90
+ )
91
+
92
+ vis_processors = dict()
93
+ txt_processors = dict()
94
+
95
+ vis_proc_cfg = config.get("vis_processor")
96
+ txt_proc_cfg = config.get("text_processor")
97
+
98
+ if vis_proc_cfg is not None:
99
+ vis_train_cfg = vis_proc_cfg.get("train")
100
+ vis_eval_cfg = vis_proc_cfg.get("eval")
101
+ else:
102
+ vis_train_cfg = None
103
+ vis_eval_cfg = None
104
+
105
+ vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
106
+ vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
107
+
108
+ if txt_proc_cfg is not None:
109
+ txt_train_cfg = txt_proc_cfg.get("train")
110
+ txt_eval_cfg = txt_proc_cfg.get("eval")
111
+ else:
112
+ txt_train_cfg = None
113
+ txt_eval_cfg = None
114
+
115
+ txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
116
+ txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
117
+
118
+ return vis_processors, txt_processors
119
+
120
+
121
+ def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
122
+ """
123
+ Load model and its related preprocessors.
124
+
125
+ List all available models and types in registry:
126
+ >>> from bliva.models import model_zoo
127
+ >>> print(model_zoo)
128
+
129
+ Args:
130
+ name (str): name of the model.
131
+ model_type (str): type of the model.
132
+ is_eval (bool): whether the model is in eval mode. Default: False.
133
+ device (str): device to use. Default: "cpu".
134
+
135
+ Returns:
136
+ model (torch.nn.Module): model.
137
+ vis_processors (dict): preprocessors for visual inputs.
138
+ txt_processors (dict): preprocessors for text inputs.
139
+ """
140
+ model_cls = registry.get_model_class(name)
141
+
142
+ # load model
143
+ model = model_cls.from_pretrained(model_type=model_type)
144
+
145
+ if is_eval:
146
+ model.eval()
147
+
148
+ # load preprocess
149
+ cfg = OmegaConf.load(model_cls.default_config_path(model_type))
150
+ if cfg is not None:
151
+ preprocess_cfg = cfg.preprocess
152
+
153
+ vis_processors, txt_processors = load_preprocess(preprocess_cfg)
154
+ else:
155
+ vis_processors, txt_processors = None, None
156
+ logging.info(
157
+ f"""No default preprocess for model {name} ({model_type}).
158
+ This can happen if the model is not finetuned on downstream datasets,
159
+ or it is not intended for direct use without finetuning.
160
+ """
161
+ )
162
+
163
+ if device == "cpu" or device == torch.device("cpu"):
164
+ model = model.float()
165
+
166
+ return model.to(device), vis_processors, txt_processors
167
+
168
+
169
+ class ModelZoo:
170
+ """
171
+ A utility class to create string representation of available model architectures and types.
172
+
173
+ >>> from bliva.models import model_zoo
174
+ >>> # list all available models
175
+ >>> print(model_zoo)
176
+ >>> # show total number of models
177
+ >>> print(len(model_zoo))
178
+ """
179
+
180
+ def __init__(self) -> None:
181
+ self.model_zoo = {
182
+ k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
183
+ for k, v in registry.mapping["model_name_mapping"].items()
184
+ }
185
+
186
+ def __str__(self) -> str:
187
+ return (
188
+ "=" * 50
189
+ + "\n"
190
+ + f"{'Architectures':<30} {'Types'}\n"
191
+ + "=" * 50
192
+ + "\n"
193
+ + "\n".join(
194
+ [
195
+ f"{name:<30} {', '.join(types)}"
196
+ for name, types in self.model_zoo.items()
197
+ ]
198
+ )
199
+ )
200
+
201
+ def __iter__(self):
202
+ return iter(self.model_zoo.items())
203
+
204
+ def __len__(self):
205
+ return sum([len(v) for v in self.model_zoo.values()])
206
+
207
+
208
+ model_zoo = ModelZoo()
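The loaders above are the package's public entry points. A usage sketch (hypothetical call; the valid (name, model_type) pairs come from the registry, and "bliva_flant5"/"flant5xxl" is the pair registered later in this commit):

    import torch
    from bliva.models import model_zoo, load_model_and_preprocess

    print(model_zoo)  # table of registered architectures and their types

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, vis_processors, txt_processors = load_model_and_preprocess(
        name="bliva_flant5", model_type="flant5xxl", is_eval=True, device=device
    )
    # vis_processors["eval"] and txt_processors["eval"] prepare raw inputs for the model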
bliva/models/base_model.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import os
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from bliva.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
15
+ from bliva.common.utils import get_abs_path, is_url
16
+ from omegaconf import OmegaConf
17
+
18
+
19
+ class BaseModel(nn.Module):
20
+ """Base class for models."""
21
+
22
+ def __init__(self):
23
+ super().__init__()
24
+
25
+ @property
26
+ def device(self):
27
+ return list(self.parameters())[0].device
28
+
29
+ def load_checkpoint(self, url_or_filename):
30
+ """
31
+ Load from a finetuned checkpoint.
32
+
33
+ This should expect no mismatch in the model keys and the checkpoint keys.
34
+ """
35
+
36
+ if is_url(url_or_filename):
37
+ cached_file = download_cached_file(
38
+ url_or_filename, check_hash=False, progress=True
39
+ )
40
+ checkpoint = torch.load(cached_file, map_location="cuda")  # hack: load straight to GPU instead of CPU
41
+ elif os.path.isfile(url_or_filename):
42
+ checkpoint = torch.load(url_or_filename, map_location="cuda")  # hack: load straight to GPU instead of CPU
43
+ else:
44
+ raise RuntimeError("checkpoint url or path is invalid")
45
+
46
+ if "model" in checkpoint.keys():
47
+ state_dict = checkpoint["model"]
48
+ else:
49
+ state_dict = checkpoint
50
+
51
+ msg = self.load_state_dict(state_dict, strict=False)
52
+
53
+ logging.info("Missing keys {}".format(msg.missing_keys))
54
+ logging.info("load checkpoint from %s" % url_or_filename)
55
+
56
+ return msg
57
+
58
+ @classmethod
59
+ def from_pretrained(cls, model_type):
60
+ """
61
+ Build a pretrained model from default configuration file, specified by model_type.
62
+
63
+ Args:
64
+ - model_type (str): model type, specifying architecture and checkpoints.
65
+
66
+ Returns:
67
+ - model (nn.Module): pretrained or finetuned model, depending on the configuration.
68
+ """
69
+ model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
70
+ model = cls.from_config(model_cfg)
71
+
72
+ return model
73
+
74
+ @classmethod
75
+ def default_config_path(cls, model_type):
76
+ assert (
77
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
78
+ ), "Unknown model type {}".format(model_type)
79
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
80
+
81
+ def load_checkpoint_from_config(self, cfg, **kwargs):
82
+ """
83
+ Load checkpoint as specified in the config file.
84
+
85
+ If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
86
+ When loading the pretrained model, each task-specific architecture may define its
87
+ own load_from_pretrained() method.
88
+ """
89
+ load_finetuned = cfg.get("load_finetuned", True)
90
+ #load_finetuned = False
91
+ if load_finetuned:
92
+ finetune_path = cfg.get("finetuned", None)
93
+ assert (
94
+ finetune_path is not None
95
+ ), "Found load_finetuned is True, but finetune_path is None."
96
+ self.load_checkpoint(url_or_filename=finetune_path)
97
+ else:
98
+ load_pretrained = cfg.get("load_pretrained", True)
99
+ if load_pretrained:
100
+ # load pre-trained weights
101
+ pretrain_path = cfg.get("pretrained", None)
102
+ assert pretrain_path is not None, "Found load_finetuned is False, but pretrain_path is None."
103
+ self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
104
+
105
+
106
+ def before_evaluation(self, **kwargs):
107
+ pass
108
+
109
+ def show_n_params(self, return_str=True):
110
+ tot = 0
111
+ for p in self.parameters():
112
+ w = 1
113
+ for x in p.shape:
114
+ w *= x
115
+ tot += w
116
+ if return_str:
117
+ if tot >= 1e6:
118
+ return "{:.1f}M".format(tot / 1e6)
119
+ else:
120
+ return "{:.1f}K".format(tot / 1e3)
121
+ else:
122
+ return tot
123
+
124
+
125
+ class BaseEncoder(nn.Module):
126
+ """
127
+ Base class for primitive encoders, such as ViT, TimeSformer, etc.
128
+ """
129
+
130
+ def __init__(self):
131
+ super().__init__()
132
+
133
+ def forward_features(self, samples, **kwargs):
134
+ raise NotImplementedError
135
+
136
+ @property
137
+ def device(self):
138
+ return list(self.parameters())[0].device
139
+
140
+
141
+ class SharedQueueMixin:
142
+ @torch.no_grad()
143
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
144
+ # gather keys before updating queue
145
+ image_feats = concat_all_gather(image_feat)
146
+ text_feats = concat_all_gather(text_feat)
147
+
148
+ batch_size = image_feats.shape[0]
149
+
150
+ ptr = int(self.queue_ptr)
151
+ assert self.queue_size % batch_size == 0 # for simplicity
152
+
153
+ # replace the keys at ptr (dequeue and enqueue)
154
+ self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
155
+ self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
156
+
157
+ if idxs is not None:
158
+ idxs = concat_all_gather(idxs)
159
+ self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
160
+
161
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
162
+ self.queue_ptr[0] = ptr
163
+
164
+
165
+ class MomentumDistilationMixin:
166
+ @torch.no_grad()
167
+ def copy_params(self):
168
+ for model_pair in self.model_pairs:
169
+ for param, param_m in zip(
170
+ model_pair[0].parameters(), model_pair[1].parameters()
171
+ ):
172
+ param_m.data.copy_(param.data) # initialize
173
+ param_m.requires_grad = False # not update by gradient
174
+
175
+ @torch.no_grad()
176
+ def _momentum_update(self):
177
+ for model_pair in self.model_pairs:
178
+ for param, param_m in zip(
179
+ model_pair[0].parameters(), model_pair[1].parameters()
180
+ ):
181
+ param_m.data = param_m.data * self.momentum + param.data * (
182
+ 1.0 - self.momentum
183
+ )
184
+
185
+
186
+ class GatherLayer(torch.autograd.Function):
187
+ """
188
+ Gather tensors from all workers with support for backward propagation:
189
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
190
+ """
191
+
192
+ @staticmethod
193
+ def forward(ctx, x):
194
+ output = [
195
+ torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
196
+ ]
197
+ torch.distributed.all_gather(output, x)
198
+ return tuple(output)
199
+
200
+ @staticmethod
201
+ def backward(ctx, *grads):
202
+ all_gradients = torch.stack(grads)
203
+ torch.distributed.all_reduce(all_gradients)
204
+ return all_gradients[torch.distributed.get_rank()]
205
+
206
+
207
+ def all_gather_with_grad(tensors):
208
+ """
209
+ Performs all_gather operation on the provided tensors.
210
+ Graph remains connected for backward grad computation.
211
+ """
212
+ # gather the tensors from all workers (gradient flow preserved by GatherLayer)
213
+ world_size = torch.distributed.get_world_size()
214
+ # There is no need for reduction in the single-proc case
215
+ if world_size == 1:
216
+ return tensors
217
+
218
+ # tensor_all = GatherLayer.apply(tensors)
219
+ tensor_all = GatherLayer.apply(tensors)
220
+
221
+ return torch.cat(tensor_all, dim=0)
222
+
223
+
224
+ @torch.no_grad()
225
+ def concat_all_gather(tensor):
226
+ """
227
+ Performs all_gather operation on the provided tensors.
228
+ *** Warning ***: torch.distributed.all_gather has no gradient.
229
+ """
230
+ # if use distributed training
231
+ if not is_dist_avail_and_initialized():
232
+ return tensor
233
+
234
+ tensors_gather = [
235
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
236
+ ]
237
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
238
+
239
+ output = torch.cat(tensors_gather, dim=0)
240
+ return output
241
+
242
+
243
+ def tile(x, dim, n_tile):
244
+ init_dim = x.size(dim)
245
+ repeat_idx = [1] * x.dim()
246
+ repeat_idx[dim] = n_tile
247
+ x = x.repeat(*(repeat_idx))
248
+ order_index = torch.LongTensor(
249
+ np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
250
+ )
251
+ return torch.index_select(x, dim, order_index.to(x.device))
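The tile() helper above repeats each slice along dim n_tile times consecutively, which makes it equivalent to torch.repeat_interleave. A quick behavior check (sketch, not part of the commit):

    import torch
    from bliva.models.base_model import tile

    x = torch.tensor([1, 2])
    print(tile(x, 0, 3))  # tensor([1, 1, 1, 2, 2, 2])
    assert torch.equal(tile(x, 0, 3), x.repeat_interleave(3, dim=0))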
bliva/models/blip2.py ADDED
@@ -0,0 +1,319 @@
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import contextlib
8
+ import logging
9
+ import os
10
+ import time
11
+ import datetime
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.distributed as dist
16
+ import torch.nn.functional as F
17
+
18
+ import bliva.common.dist_utils as dist_utils
19
+ from bliva.common.dist_utils import download_cached_file
20
+ from bliva.common.utils import is_url
21
+ from bliva.common.logger import MetricLogger
22
+ from bliva.models.base_model import BaseModel
23
+ from bliva.models.Qformer import BertConfig, BertLMHeadModel
24
+ from bliva.models.eva_vit import create_eva_vit_g
25
+ from bliva.models.clip_vit import create_clip_vit_L
26
+ from transformers import BertTokenizer
27
+
28
+
29
+ class Blip2Base(BaseModel):
30
+ @classmethod
31
+ def init_tokenizer(cls, truncation_side="right"):
32
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side)
33
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
34
+ return tokenizer
35
+
36
+ def maybe_autocast(self, dtype=torch.float16):
37
+ # if on cpu, don't use autocast
38
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
39
+ enable_autocast = self.device != torch.device("cpu")
40
+
41
+ if enable_autocast:
42
+ return torch.cuda.amp.autocast(dtype=dtype)
43
+ else:
44
+ return contextlib.nullcontext()
45
+
46
+ @classmethod
47
+ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
48
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
49
+ encoder_config.encoder_width = vision_width
50
+ # insert cross-attention layer every other block
51
+ encoder_config.add_cross_attention = True
52
+ encoder_config.cross_attention_freq = cross_attention_freq
53
+ encoder_config.query_length = num_query_token
54
+ Qformer = BertLMHeadModel.from_pretrained(
55
+ "bert-base-uncased", config=encoder_config
56
+ )
57
+ query_tokens = nn.Parameter(
58
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
59
+ )
60
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) #0.02
61
+ return Qformer, query_tokens
62
+
63
+ def init_vision_encoder(
64
+ self, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
65
+ ):
66
+ assert model_name in [
67
+ "eva_clip_g",
68
+ "eva2_clip_L",
69
+ "clip_L",
70
+ "cpe_eva_clip_g",
71
+ ], "vit model must be eva_clip_g, eva2_clip_L, clip_L or cpe_eva_clip_g"
72
+ if model_name == "eva_clip_g":
73
+ visual_encoder = create_eva_vit_g(
74
+ img_size, drop_path_rate, use_grad_checkpoint, precision
75
+ )
76
+ # elif model_name == "eva2_clip_L":
77
+ # visual_encoder = create_eva2_vit_L(
78
+ # img_size, drop_path_rate, use_grad_checkpoint, precision
79
+ # )
80
+ elif model_name == "clip_L":
81
+ visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
82
+
83
+ ln_vision = LayerNorm(visual_encoder.num_features)
84
+ self.vit_name = model_name
85
+ return visual_encoder, ln_vision
86
+
87
+ def load_from_pretrained(self, url_or_filename):
88
+ if is_url(url_or_filename):
89
+ cached_file = download_cached_file(
90
+ url_or_filename, check_hash=False, progress=True
91
+ )
92
+ checkpoint = torch.load(cached_file, map_location="cpu")
93
+ elif os.path.isfile(url_or_filename):
94
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
95
+ else:
96
+ raise RuntimeError("checkpoint url or path is invalid")
97
+
98
+ state_dict = checkpoint["model"]
99
+
100
+ msg = self.load_state_dict(state_dict, strict=False)
101
+
102
+ # logging.info("Missing keys {}".format(msg.missing_keys))
103
+ logging.info("load checkpoint from %s" % url_or_filename)
104
+
105
+ return msg
106
+
107
+ def get_optimizer_params(self, weight_decay, lr_scale=1):
108
+ if self.vit_name == "eva_clip_g":
109
+ vit_num_layers = self.visual_encoder.get_num_layer()
110
+ lr_scales = list(lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2))
111
+
112
+ parameter_group_names = {}
113
+ parameter_group_vars = {}
114
+
115
+ for name, param in self.named_parameters():
116
+ if not param.requires_grad:
117
+ continue # frozen weights
118
+ if len(param.shape) == 1 or name.endswith(".bias"):
119
+ group_name = "no_decay"
120
+ this_weight_decay = 0.
121
+ else:
122
+ group_name = "decay"
123
+ this_weight_decay = weight_decay
124
+ if 'visual_encoder' in name:
125
+ layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.',''))
126
+ group_name = "vit_layer_%d_%s" % (layer_id, group_name)
127
+ else:
128
+ layer_id = None
129
+
130
+ if group_name not in parameter_group_names:
131
+ if layer_id is not None:
132
+ scale = lr_scales[layer_id]
133
+ else:
134
+ scale = 1
135
+ parameter_group_names[group_name] = {
136
+ "weight_decay": this_weight_decay,
137
+ "params": [],
138
+ "lr_scale": scale
139
+ }
140
+ parameter_group_vars[group_name] = {
141
+ "weight_decay": this_weight_decay,
142
+ "params": [],
143
+ "lr_scale": scale
144
+ }
145
+ parameter_group_vars[group_name]["params"].append(param)
146
+ parameter_group_names[group_name]["params"].append(name)
147
+ # import json
148
+ # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
149
+ optim_params = list(parameter_group_vars.values())
150
+ return optim_params
151
+ else:
152
+ return super().get_optimizer_params(weight_decay, lr_scale)
153
+
154
+ def _lemmatize(self, answers):
155
+ def apply(answer):
156
+ doc = self.lemmatizer(answer)
157
+
158
+ words = []
159
+ for token in doc:
160
+ if token.pos_ in ["NOUN", "VERB"]:
161
+ words.append(token.lemma_)
162
+ else:
163
+ words.append(token.text)
164
+ answer = " ".join(words)
165
+
166
+ return answer
167
+
168
+ return [apply(answer) for answer in answers]
169
+
170
+ @property
171
+ def lemmatizer(self):
172
+ if self._lemmatizer is None:
173
+ try:
174
+ import spacy
175
+
176
+ self._lemmatizer = spacy.load("en_core_web_sm")
177
+ except ImportError:
178
+ logging.error(
179
+ """
180
+ Please install spacy and en_core_web_sm model to apply lemmatization.
181
+ python -m spacy download en_core_web_sm
182
+ OR
183
+ import spacy.cli
184
+ spacy.cli.download("en_core_web_sm")
185
+ """
186
+ )
187
+ exit(1)
188
+
189
+ return self._lemmatizer
190
+
191
+ def disabled_train(self, mode=True):
192
+ """Overwrite model.train with this function to make sure train/eval mode
193
+ does not change anymore."""
194
+ return self
195
+
196
+
197
+ class LayerNorm(nn.LayerNorm):
198
+ """Subclass torch's LayerNorm to handle fp16."""
199
+
200
+ def forward(self, x: torch.Tensor):
201
+ orig_type = x.dtype
202
+ ret = super().forward(x.type(torch.float32))
203
+ return ret.type(orig_type)
204
+
205
+
206
+ def compute_sim_matrix(model, data_loader, **kwargs):
207
+ k_test = kwargs.pop("k_test")
208
+
209
+ metric_logger = MetricLogger(delimiter=" ")
210
+ header = "Evaluation:"
211
+
212
+ logging.info("Computing features for evaluation...")
213
+ start_time = time.time()
214
+
215
+ texts = data_loader.dataset.text
216
+ num_text = len(texts)
217
+ text_bs = 256
218
+ text_ids = []
219
+ text_embeds = []
220
+ text_atts = []
221
+ for i in range(0, num_text, text_bs):
222
+ text = texts[i : min(num_text, i + text_bs)]
223
+ text_input = model.tokenizer(
224
+ text,
225
+ padding="max_length",
226
+ truncation=True,
227
+ max_length=35,
228
+ return_tensors="pt",
229
+ ).to(model.device)
230
+ text_feat = model.forward_text(text_input)
231
+ text_embed = F.normalize(model.text_proj(text_feat))
232
+ text_embeds.append(text_embed)
233
+ text_ids.append(text_input.input_ids)
234
+ text_atts.append(text_input.attention_mask)
235
+
236
+ text_embeds = torch.cat(text_embeds, dim=0)
237
+ text_ids = torch.cat(text_ids, dim=0)
238
+ text_atts = torch.cat(text_atts, dim=0)
239
+
240
+ vit_feats = []
241
+ image_embeds = []
242
+ for samples in data_loader:
243
+ image = samples["image"]
244
+
245
+ image = image.to(model.device)
246
+ image_feat, vit_feat = model.forward_image(image)
247
+ image_embed = model.vision_proj(image_feat)
248
+ image_embed = F.normalize(image_embed, dim=-1)
249
+
250
+ vit_feats.append(vit_feat.cpu())
251
+ image_embeds.append(image_embed)
252
+
253
+ vit_feats = torch.cat(vit_feats, dim=0)
254
+ image_embeds = torch.cat(image_embeds, dim=0)
255
+
256
+ sims_matrix = []
257
+ for image_embed in image_embeds:
258
+ sim_q2t = image_embed @ text_embeds.t()
259
+ sim_i2t, _ = sim_q2t.max(0)
260
+ sims_matrix.append(sim_i2t)
261
+ sims_matrix = torch.stack(sims_matrix, dim=0)
262
+
263
+ score_matrix_i2t = torch.full(
264
+ (len(data_loader.dataset.image), len(texts)), -100.0
265
+ ).to(model.device)
266
+
267
+ num_tasks = dist_utils.get_world_size()
268
+ rank = dist_utils.get_rank()
269
+ step = sims_matrix.size(0) // num_tasks + 1
270
+ start = rank * step
271
+ end = min(sims_matrix.size(0), start + step)
272
+
273
+ for i, sims in enumerate(
274
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
275
+ ):
276
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
277
+ image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
278
+ score = model.compute_itm(
279
+ image_inputs=image_inputs,
280
+ text_ids=text_ids[topk_idx],
281
+ text_atts=text_atts[topk_idx],
282
+ ).float()
283
+ score_matrix_i2t[start + i, topk_idx] = score + topk_sim
284
+
285
+ sims_matrix = sims_matrix.t()
286
+ score_matrix_t2i = torch.full(
287
+ (len(texts), len(data_loader.dataset.image)), -100.0
288
+ ).to(model.device)
289
+
290
+ step = sims_matrix.size(0) // num_tasks + 1
291
+ start = rank * step
292
+ end = min(sims_matrix.size(0), start + step)
293
+
294
+ for i, sims in enumerate(
295
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
296
+ ):
297
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
298
+ image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
299
+ score = model.compute_itm(
300
+ image_inputs=image_inputs,
301
+ text_ids=text_ids[start + i].repeat(k_test, 1),
302
+ text_atts=text_atts[start + i].repeat(k_test, 1),
303
+ ).float()
304
+ score_matrix_t2i[start + i, topk_idx] = score + topk_sim
305
+
306
+ if dist_utils.is_dist_avail_and_initialized():
307
+ dist.barrier()
308
+ torch.distributed.all_reduce(
309
+ score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
310
+ )
311
+ torch.distributed.all_reduce(
312
+ score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
313
+ )
314
+
315
+ total_time = time.time() - start_time
316
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
317
+ logging.info("Evaluation time {}".format(total_time_str))
318
+
319
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
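The LayerNorm subclass above guards against fp16 numerical issues: the normalization runs in float32 and the result is cast back to the caller's dtype. A quick check (sketch, not part of the commit):

    import torch
    from bliva.models.blip2 import LayerNorm

    ln = LayerNorm(8)             # parameters stay in float32
    x = torch.randn(2, 8).half()
    out = ln(x)                   # computed in float32 internally
    print(out.dtype)              # torch.float16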
bliva/models/bliva_flant5xxl.py ADDED
@@ -0,0 +1,803 @@
1
+ import logging
2
+ import string
3
+ import random
4
+ import copy
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.cuda.amp import autocast as autocast
9
+ from transformers import T5TokenizerFast
10
+
11
+ from bliva.common.registry import registry
12
+ from bliva.models.blip2 import Blip2Base, disabled_train
13
+ from bliva.models.modeling_t5 import T5Config, T5ForConditionalGeneration
14
+ from transformers.modeling_outputs import BaseModelOutput
15
+
16
+
17
+ @registry.register_model("bliva_flant5")
18
+ class BLIVAFlanT5(Blip2Base):
19
+
20
+ PRETRAINED_MODEL_CONFIG_DICT = {
21
+ "flant5xxl": "configs/models/bliva_flant5xxl.yaml",
22
+ }
23
+
24
+ def __init__(
25
+ self,
26
+ vit_model="eva_clip_g",
27
+ img_size=224,
28
+ drop_path_rate=0,
29
+ use_grad_checkpoint=False,
30
+ vit_precision="fp16",
31
+ freeze_vit=True,
32
+ num_query_token=32,
33
+ t5_model="google/flan-t5-xl",
34
+ prompt="",
35
+ max_txt_len=128,
36
+ max_output_txt_len=256,
37
+ apply_lemmatizer=False,
38
+ num_few_shot_examples=0,
39
+ few_shot_prob=0,
40
+ qformer_text_input=True,
41
+ ):
42
+ """
43
+ apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas.
44
+ """
45
+ super().__init__()
46
+
47
+ self.tokenizer = self.init_tokenizer(truncation_side="left")
48
+
49
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
50
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
51
+ )
52
+ if freeze_vit:
53
+ for name, param in self.visual_encoder.named_parameters():
54
+ param.requires_grad = False
55
+ self.visual_encoder = self.visual_encoder.eval()
56
+ self.visual_encoder.train = disabled_train
57
+ logging.info("freeze vision encoder")
58
+
59
+ self.Qformer, self.query_tokens = self.init_Qformer(
60
+ num_query_token, self.visual_encoder.num_features
61
+ )
62
+
63
+ if not qformer_text_input:
64
+ self.Qformer.bert.embeddings.word_embeddings = None
65
+ self.Qformer.bert.embeddings.position_embeddings = None
66
+ for layer in self.Qformer.bert.encoder.layer:
67
+ layer.output = None
68
+ layer.intermediate = None
69
+ else:
70
+ self.Qformer.resize_token_embeddings(len(self.tokenizer))
71
+ self.Qformer.cls = None
72
+
73
+ self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='left')
74
+ self.t5_output_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='right')
75
+
76
+ t5_config = T5Config.from_pretrained(t5_model)
77
+ t5_config.dense_act_fn = "gelu"
78
+ self.t5_model = T5ForConditionalGeneration.from_pretrained(
79
+ t5_model, config=t5_config
80
+ )
81
+
82
+ for name, param in self.t5_model.named_parameters():
83
+ param.requires_grad = False
84
+ param.data = param.data.bfloat16()
85
+
86
+ self.t5_proj = nn.Linear(
87
+ self.Qformer.config.hidden_size, self.t5_model.config.hidden_size
88
+ )
89
+
90
+ self.max_txt_len = max_txt_len
91
+ self.max_output_txt_len = max_output_txt_len
92
+ self.prompt = prompt
93
+
94
+ self._apply_lemmatizer = apply_lemmatizer
95
+ self._lemmatizer = None
96
+
97
+ self.num_few_shot_examples = num_few_shot_examples
98
+ self.few_shot_prob = few_shot_prob
99
+
100
+ self.qformer_text_input = qformer_text_input
101
+ self.vision_project = nn.Linear(self.visual_encoder.num_features, self.t5_model.config.hidden_size)
102
+
103
+ def forward(self, samples):
104
+
105
+ image = samples["image"]
106
+ image_features = self.visual_encoder.get_intermediate_layers(image)[-2]  # [batch_size, 257, 1408]
107
+ image_features = image_features[:, 1:]
108
+ add_feature_llm = self.vision_project(image_features)
109
+ atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
110
+
111
+ with self.maybe_autocast():
112
+ image_embeds = self.ln_vision(self.visual_encoder(image))
113
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
114
+
115
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
116
+ if self.qformer_text_input:
117
+ text_Qformer = self.tokenizer(
118
+ samples["text_input"],
119
+ padding='longest',
120
+ truncation=True,
121
+ max_length=self.max_txt_len,
122
+ return_tensors="pt",
123
+ ).to(image.device)
124
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
125
+ Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1)
126
+
127
+ query_output = self.Qformer.bert(
128
+ text_Qformer.input_ids,
129
+ attention_mask=Qformer_atts,
130
+ query_embeds=query_tokens,
131
+ encoder_hidden_states=image_embeds,
132
+ encoder_attention_mask=image_atts,
133
+ return_dict=True,
134
+ )
135
+ else:
136
+ query_output = self.Qformer.bert(
137
+ query_embeds=query_tokens,
138
+ encoder_hidden_states=image_embeds,
139
+ encoder_attention_mask=image_atts,
140
+ return_dict=True,
141
+ )
142
+
143
+ inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
144
+ atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
145
+
146
+ fs_embeds, fs_atts = None, None
147
+ if self.few_shot_prob > 0 and "few_shot_samples" in samples.keys():
148
+ fs_embeds, fs_atts = self.prepare_few_shot_embeds(samples['few_shot_samples'])
149
+
150
+ with self.maybe_autocast(dtype=torch.bfloat16):
151
+ input_tokens = self.t5_tokenizer(
152
+ samples["text_input"],
153
+ padding="longest",
154
+ truncation=True,
155
+ max_length=self.max_txt_len,
156
+ return_tensors="pt",
157
+ ).to(image.device)
158
+ output_tokens = self.t5_output_tokenizer(
159
+ samples["text_output"],
160
+ padding="longest",
161
+ truncation=True,
162
+ max_length=self.max_output_txt_len,
163
+ return_tensors="pt",
164
+ ).to(image.device)
165
+
166
+ encoder_atts = torch.cat([atts_t5, atts_add_feature_llm, input_tokens.attention_mask], dim=1)
167
+
168
+ targets = output_tokens.input_ids.masked_fill(
169
+ output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100
170
+ )
171
+
172
+ inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
173
+ inputs_embeds = torch.cat([inputs_t5, add_feature_llm, inputs_embeds], dim=1)
174
+
175
+ if fs_embeds is not None:
176
+ inputs_embeds = torch.cat([fs_embeds, inputs_embeds], dim=1)
177
+ encoder_atts = torch.cat([fs_atts, encoder_atts], dim=1)
178
+
179
+ outputs = self.t5_model(
180
+ inputs_embeds=inputs_embeds,
181
+ attention_mask=encoder_atts,
182
+ decoder_attention_mask=output_tokens.attention_mask,
183
+ return_dict=True,
184
+ labels=targets,
185
+ )
186
+ loss = outputs.loss
187
+
188
+ return {"loss": loss}
189
+
190
+ def prepare_few_shot_embeds(self, samples):
191
+ this_n_fs = random.choices(
192
+ list(range(self.num_few_shot_examples + 1)),
193
+ weights=[1 - self.few_shot_prob] + [self.few_shot_prob / self.num_few_shot_examples] * self.num_few_shot_examples
194
+ )[0]
195
+
196
+ if this_n_fs == 0:
197
+ return None, None
198
+
199
+ images = []
200
+ text_input = []
201
+ for sample in samples:
202
+ for n in range(this_n_fs):
203
+ images.append(sample['image'][n])
204
+ text_input.append(sample['text_input'][n])
205
+ images = torch.stack(images, dim=0)
206
+
207
+ image = images
208
+
209
+ with self.maybe_autocast():
210
+ image_embeds = self.ln_vision(self.visual_encoder(image))
211
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
212
+ image.device
213
+ )
214
+
215
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
216
+ if self.qformer_text_input:
217
+ text_Qformer = self.tokenizer(
218
+ text_input,
219
+ padding='longest',
220
+ truncation=True,
221
+ max_length=self.max_txt_len,
222
+ return_tensors="pt",
223
+ ).to(image.device)
224
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
225
+ Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1)
226
+ query_output = self.Qformer.bert(
227
+ text_Qformer.input_ids,
228
+ attention_mask = Qformer_atts,
229
+ query_embeds=query_tokens,
230
+ encoder_hidden_states=image_embeds,
231
+ encoder_attention_mask=image_atts,
232
+ return_dict=True,
233
+ )
234
+ else:
235
+ query_output = self.Qformer.bert(
236
+ query_embeds=query_tokens,
237
+ encoder_hidden_states=image_embeds,
238
+ encoder_attention_mask=image_atts,
239
+ return_dict=True,
240
+ )
241
+
242
+ inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
243
+ atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
244
+
245
+ with self.maybe_autocast(dtype=torch.bfloat16):
246
+ input_tokens = self.t5_tokenizer(
247
+ text_input,
248
+ padding="longest",
249
+ truncation=True,
250
+ max_length=self.max_txt_len,
251
+ return_tensors="pt",
252
+ ).to(image.device)
253
+
254
+ encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
255
+
256
+ inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
257
+ inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
258
+
259
+ if this_n_fs > 1:
260
+ encoder_atts = encoder_atts.reshape(encoder_atts.size(0) // this_n_fs, encoder_atts.size(1) * this_n_fs)
261
+ inputs_embeds = inputs_embeds.reshape(inputs_embeds.size(0) // this_n_fs, inputs_embeds.size(1) * this_n_fs, inputs_embeds.size(2))
262
+
263
+ return inputs_embeds, encoder_atts
264
+
265
+ @torch.no_grad()
266
+ def generate(
267
+ self,
268
+ samples,
269
+ use_nucleus_sampling=False,
270
+ num_beams=5,
271
+ max_length=256,
272
+ min_length=1,
273
+ top_p=0.9,
274
+ repetition_penalty=1.5,
275
+ length_penalty=1.0,
276
+ num_captions=1,
277
+ temperature=1,
278
+ ):
279
+ if "prompt" in samples.keys():
280
+ prompt = samples["prompt"]
281
+ else:
282
+ prompt = self.prompt
283
+
284
+ image = samples["image"]
285
+
286
+ bs = image.size(0)
287
+
288
+ if isinstance(prompt, str):
289
+ prompt = [prompt] * bs
290
+ else:
291
+ assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
292
+
293
+ # For TextCaps
294
+ if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
295
+ prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]
296
+ if 'context' in samples.keys() and samples['context'] != '':
297
+ prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
298
+ print('using context')
299
+ query_tokens = self.query_tokens.expand(bs, -1, -1)
300
+ if self.qformer_text_input:
301
+ # remove ocr tokens in q_former (for eval textvqa)
302
+ # qformer_prompt = prompt
303
+ # qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt]
304
+
305
+ text_Qformer = self.tokenizer(
306
+ prompt,
307
+ padding='longest',
308
+ truncation=True,
309
+ max_length=self.max_txt_len,
310
+ return_tensors="pt",
311
+ ).to(image.device)
312
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
313
+ Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1)
314
+
315
+ # For video data
316
+ if image.dim() == 5:
317
+ inputs_t5, atts_t5 = [], []
318
+ add_inputs_llm, add_atts_llm = [], []
319
+ for j in range(image.size(2)):
320
+ this_frame = image[:,:,j,:,:]
321
+ with self.maybe_autocast():
322
+ frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
323
+ frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
324
+ frame_features = self.visual_encoder.get_intermediate_layers(this_frame)[-2]
325
+
326
+ frame_features = frame_features[:, 1:]
327
+
328
+ add_feature_llm = self.vision_project(frame_features)
329
+ atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
330
+
331
+ if self.qformer_text_input:
332
+ frame_query_output = self.Qformer.bert(
333
+ text_Qformer.input_ids,
334
+ attention_mask = Qformer_atts,
335
+ query_embeds=query_tokens,
336
+ encoder_hidden_states=frame_embeds,
337
+ encoder_attention_mask=frame_atts,
338
+ return_dict=True,
339
+ )
340
+ else:
341
+ frame_query_output = self.Qformer.bert(
342
+ query_embeds=query_tokens,
343
+ encoder_hidden_states=frame_embeds,
344
+ encoder_attention_mask=frame_atts,
345
+ return_dict=True,
346
+ )
347
+
348
+ frame_inputs_t5 = self.t5_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
349
+ frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
350
+ inputs_t5.append(frame_inputs_t5)
351
+ atts_t5.append(frame_atts_t5)
352
+ add_inputs_llm.append(add_feature_llm)
353
+ add_atts_llm.append(atts_add_feature_llm)
354
+ inputs_t5 = torch.cat(inputs_t5, dim=1)
355
+ atts_t5 = torch.cat(atts_t5, dim=1)
356
+ add_feature_llm = torch.cat(add_inputs_llm, dim=1)
357
+ atts_add_feature_llm = torch.cat(add_atts_llm, dim=1)
358
+ else:
359
+ with self.maybe_autocast():
360
+ image_embeds = self.ln_vision(self.visual_encoder(image))
361
+ image_features = self.visual_encoder.get_intermediate_layers(image)[-2]
362
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
363
+
364
+ image_features = image_features[:, 1:]
365
+ add_feature_llm = self.vision_project(image_features)
366
+ atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
367
+ if self.qformer_text_input:
368
+ query_output = self.Qformer.bert(
369
+ text_Qformer.input_ids,
370
+ attention_mask=Qformer_atts,
371
+ query_embeds=query_tokens,
372
+ encoder_hidden_states=image_embeds,
373
+ encoder_attention_mask=image_atts,
374
+ return_dict=True,
375
+ )
376
+ else:
377
+ query_output = self.Qformer.bert(
378
+ query_embeds=query_tokens,
379
+ encoder_hidden_states=image_embeds,
380
+ encoder_attention_mask=image_atts,
381
+ return_dict=True,
382
+ )
383
+
384
+ inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
385
+ atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
386
+
387
+ input_tokens = self.t5_tokenizer(
388
+ prompt,
389
+ padding="longest",
390
+ return_tensors="pt"
391
+ ).to(image.device)
392
+
393
+ encoder_atts = torch.cat([atts_t5, atts_add_feature_llm, input_tokens.attention_mask], dim=1)
394
+
395
+ with self.maybe_autocast(dtype=torch.bfloat16):
396
+ inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
397
+ inputs_embeds = torch.cat([inputs_t5, add_feature_llm, inputs_embeds], dim=1)
398
+
399
+ outputs = self.t5_model.generate(
400
+ inputs_embeds=inputs_embeds,
401
+ attention_mask=encoder_atts,
402
+ do_sample=use_nucleus_sampling,
403
+ top_p=top_p,
404
+ temperature=temperature,
405
+ num_beams=num_beams,
406
+ max_new_tokens=max_length,
407
+ min_length=min_length,
408
+ repetition_penalty=repetition_penalty,
409
+ length_penalty=length_penalty,
410
+ num_return_sequences=num_captions,
411
+ )
412
+ output_text = self.t5_tokenizer.batch_decode(
413
+ outputs, skip_special_tokens=True
414
+ )
415
+
416
+ return output_text
417
+
418
+ def predict_answers(
419
+ self,
420
+ samples,
421
+ num_beams=5,
422
+ inference_method="generate",
423
+ max_len=10,
424
+ min_len=1,
425
+ num_ans_candidates=128,
426
+ answer_list=None,
427
+ prompt="",
428
+ length_penalty=-1,
429
+ **kwargs
430
+ ):
431
+ if isinstance(samples["text_input"], str):
432
+ samples["text_input"] = [samples["text_input"]]
433
+
434
+ if prompt:
435
+ if prompt.count("{}") == 2:
436
+ if 'ocr_tokens' in samples:
437
+ text_input = [
438
+ prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["text_input"][i])
439
+ for i in range(len(samples["text_input"]))]
440
+ elif 'choices' in samples:
441
+ text_input = []
442
+ for i in range(len(samples["text_input"])):
443
+ this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])]
444
+ this_choices = " ".join(this_choices)
445
+ text_input.append(prompt.format(samples["text_input"][i], this_choices))
446
+ else:
447
+ text_input = [prompt.format(question) for question in samples["text_input"]]
448
+ else:
449
+ text_input = samples["text_input"]
450
+
451
+ samples["prompt"] = text_input
452
+
453
+ output_text = self.generate(
454
+ samples,
455
+ num_beams=num_beams,
456
+ max_length=max_len,
457
+ min_length=min_len,
458
+ length_penalty=length_penalty
459
+ )
460
+
461
+ if self._apply_lemmatizer or ("apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]):
462
+ output_text = self._lemmatize(output_text)
463
+
464
+ return output_text
465
+
466
+ def predict_class(
467
+ self,
468
+ samples,
469
+ candidates,
470
+ n_segments=1,
471
+ ):
472
+ # If candidates is a list of lists, each sample has its candidates, then we need to iterate one by one
473
+ if type(candidates[0]) == list:
474
+ results = []
475
+
476
+ for i in range(samples["image"].size(0)):
477
+ this_sample = {
478
+ "image": samples["image"][i].unsqueeze(0),
479
+ "prompt": samples["prompt"][i],
480
+ }
481
+
482
+ if "text_input" in samples.keys():
483
+ this_sample["text_input"] = [samples["text_input"][i]]
484
+
485
+ if 'context' in samples.keys():
486
+ this_sample['context'] = [samples["context"][i]]
487
+
488
+ if 'history' in samples.keys():
489
+ this_sample['history'] = [samples["history"][i]]
490
+
491
+ if 'caption' in samples.keys():
492
+ this_sample['caption'] = [samples["caption"][i]]
493
+
494
+ this_result = self._predict_class(this_sample, candidates[i], n_segments)
495
+ results.append(this_result)
496
+
497
+ try:
498
+ results = torch.cat(results, dim=0)
499
+ except Exception:
500
+ results = [res.tolist()[0] for res in results]
501
+
502
+ return results
503
+
504
+ return self._predict_class(samples, candidates, n_segments)
505
+
506
+ def _predict_class(
507
+ self,
508
+ samples,
509
+ candidates,
510
+ n_segments=1,
511
+ ):
512
+ """
513
+ Args:
514
+ samples (dict): A dictionary containing the following keys:
515
+ - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
516
+ - prompt: the instruction
517
+ candidates:
518
+ (list): A list of candidate class names;
519
+ n_segments:
520
+ (int): Split the candidates into n_segments and predict one by one. This is useful when the number of candidates is too large.
521
+ Returns:
522
+ output_class: predicted class index
523
+ """
524
+
525
+ image = samples["image"]
526
+ prompt = samples["prompt"]
527
+
528
+ bs = image.size(0)
529
+
530
+ if isinstance(prompt, str):
531
+ prompt = [prompt] * bs
532
+ else:
533
+ assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
534
+
535
+ if "text_input" in samples.keys():
536
+ if type(samples["text_input"][0]) == list:
537
+ prompt = [prompt[i].format(*samples["text_input"][i]) for i in range(len(prompt))]
538
+ else:
539
+ prompt = [prompt[i].format(samples["text_input"][i]) for i in range(len(prompt))]
540
+
541
+ # scienceqa
542
+ if 'context' in samples.keys() and samples['context'] != '':
543
+ prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
544
+
545
+ # visual dialog
546
+ if 'history' in samples.keys() and samples['history'][0] != '':
547
+ prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(len(prompt))]
548
+
549
+ if 'caption' in samples.keys() and samples['caption'][0] != '':
550
+ prompt = [f'This image has the caption "{samples["caption"][i]}". {prompt[i]}' for i in range(len(prompt))]
551
+
552
+ query_tokens = self.query_tokens.expand(bs, -1, -1)
553
+ if self.qformer_text_input:
554
+ text_Qformer = self.tokenizer(
555
+ prompt,
556
+ padding='longest',
557
+ truncation=True,
558
+ max_length=self.max_txt_len,
559
+ return_tensors="pt"
560
+ ).to(image.device)
561
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
562
+ Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask], dim=1)
563
+
564
+ if image.dim() == 5:
565
+ inputs_t5, atts_t5 = [], []
566
+ add_inputs_llm, add_atts_llm = [], []
567
+ for j in range(image.size(2)):
568
+ this_frame = image[:,:,j,:,:]
569
+ with self.maybe_autocast():
570
+ frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
571
+ frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
572
+ frame_features = self.visual_encoder.get_intermediate_layers(this_frame)[-2]
573
+
574
+ frame_features = frame_features[:, 1:]
575
+
576
+ add_feature_llm = self.vision_project(frame_features)
577
+ atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
578
+ if self.qformer_text_input:
579
+ frame_query_output = self.Qformer.bert(
580
+ text_Qformer.input_ids,
581
+ attention_mask=Qformer_atts,
582
+ query_embeds=query_tokens,
583
+ encoder_hidden_states=frame_embeds,
584
+ encoder_attention_mask=frame_atts,
585
+ return_dict=True,
586
+ )
587
+ else:
588
+ frame_query_output = self.Qformer.bert(
589
+ query_embeds=query_tokens,
590
+ encoder_hidden_states=frame_embeds,
591
+ encoder_attention_mask=frame_atts,
592
+ return_dict=True,
593
+ )
594
+
595
+ frame_inputs_t5 = self.t5_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
596
+ frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
597
+ inputs_t5.append(frame_inputs_t5)
598
+ atts_t5.append(frame_atts_t5)
599
+ add_inputs_llm.append(add_feature_llm)
600
+ add_atts_llm.append(atts_add_feature_llm)
601
+ inputs_t5 = torch.cat(inputs_t5, dim=1)
602
+ atts_t5 = torch.cat(atts_t5, dim=1)
603
+ add_feature_llm = torch.cat(add_inputs_llm, dim=1)
604
+ atts_add_feature_llm = torch.cat(add_atts_llm, dim=1)
605
+ else:
606
+ with self.maybe_autocast():
607
+ image_embeds = self.ln_vision(self.visual_encoder(image))
608
+ image_features = self.visual_encoder.get_intermediate_layers(image)[-2]
609
+
610
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
611
+
612
+ image_features = image_features[:, 1:]
613
+ add_feature_llm = self.vision_project(image_features)
614
+ atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
615
+
616
+ if self.qformer_text_input:
617
+ query_output = self.Qformer.bert(
618
+ text_Qformer.input_ids,
619
+ attention_mask=Qformer_atts,
620
+ query_embeds=query_tokens,
621
+ encoder_hidden_states=image_embeds,
622
+ encoder_attention_mask=image_atts,
623
+ return_dict=True,
624
+ )
625
+ else:
626
+ query_output = self.Qformer.bert(
627
+ query_embeds=query_tokens,
628
+ encoder_hidden_states=image_embeds,
629
+ encoder_attention_mask=image_atts,
630
+ return_dict=True,
631
+ )
632
+
633
+ inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
634
+ atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
635
+
636
+ input_tokens = self.t5_tokenizer(
637
+ prompt, padding="longest", return_tensors="pt"
638
+ ).to(image.device)
639
+ output_tokens = self.t5_tokenizer(
640
+ candidates, padding="longest", return_tensors="pt"
641
+ ).to(image.device)
642
+
643
+ encoder_atts = torch.cat([atts_t5, atts_add_feature_llm, input_tokens.attention_mask], dim=1)
644
+
645
+ n_cands = len(candidates)
646
+
647
+ with self.maybe_autocast(dtype=torch.bfloat16):
648
+ inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
649
+ inputs_embeds = torch.cat([inputs_t5, add_feature_llm, inputs_embeds], dim=1)
650
+
651
+ encoder_outputs = self.t5_model.encoder(
652
+ inputs_embeds=inputs_embeds,
653
+ attention_mask=encoder_atts,
654
+ )
655
+
656
+ all_losses = []
657
+ for n in range(n_segments):
658
+ seg_len = n_cands // n_segments
659
+ if n == (n_segments - 1):
660
+ seg_len = n_cands - seg_len * (n_segments - 1)
661
+
662
+ # this_encoder_outputs = copy.deepcopy(encoder_outputs)
663
+ this_encoder_outputs = BaseModelOutput(
664
+ last_hidden_state=encoder_outputs[0].clone(),
665
+ )
666
+
667
+ this_encoder_outputs['last_hidden_state'] = this_encoder_outputs[0].repeat_interleave(seg_len, dim=0)
668
+ this_encoder_atts = encoder_atts.repeat_interleave(seg_len, dim=0)
669
+
670
+ start_i = n * (n_cands // n_segments)
671
+ end_i = start_i + seg_len
672
+ this_output_tokens_ids = output_tokens.input_ids[start_i:end_i].repeat(bs, 1)
673
+ this_output_tokens_atts = output_tokens.attention_mask[start_i:end_i].repeat(bs, 1)
674
+
675
+ this_targets = this_output_tokens_ids.masked_fill(this_output_tokens_ids == self.t5_tokenizer.pad_token_id, -100)
676
+
677
+ outputs = self.t5_model(
678
+ encoder_outputs=this_encoder_outputs,
679
+ attention_mask=this_encoder_atts,
680
+ decoder_attention_mask=this_output_tokens_atts,
681
+ return_dict=True,
682
+ labels=this_targets,
683
+ reduction="none",
684
+ )
685
+ loss = outputs.loss
686
+
687
+ loss = loss.reshape(bs, seg_len)
688
+ # output_class_ranks = torch.argsort(loss, dim=-1)
689
+ all_losses.append(loss)
690
+
691
+ all_losses = torch.cat(all_losses, dim=-1)
692
+ output_class_ranks = torch.argsort(all_losses, dim=-1)
693
+
694
+ # encoder_outputs['last_hidden_state'] = encoder_outputs[0].repeat_interleave(n_cands, dim=0)
695
+ # encoder_atts = encoder_atts.repeat_interleave(n_cands, dim=0)
696
+ # output_tokens.input_ids = output_tokens.input_ids.repeat(bs, 1)
697
+ # output_tokens.attention_mask = output_tokens.attention_mask.repeat(bs, 1)
698
+
699
+ # # compute the LM loss for each candidate (sum logprob across all tokens) and select the highest
700
+ # targets = output_tokens.input_ids.masked_fill(output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100)
701
+
702
+ # outputs = self.t5_model(
703
+ # encoder_outputs=encoder_outputs,
704
+ # attention_mask=encoder_atts,
705
+ # decoder_attention_mask=output_tokens.attention_mask,
706
+ # return_dict=True,
707
+ # labels=targets,
708
+ # reduction="none",
709
+ # )
710
+ # loss = outputs.loss
711
+
712
+ # loss = loss.reshape(bs, n_cands)
713
+ # output_class_ranks = torch.argsort(loss, dim=-1) # (bs, num_candidates)
714
+
715
+ return output_class_ranks
716
+
717
+ def _lemmatize(self, answers):
718
+ def apply(answer):
719
+ doc = self.lemmatizer(answer)
720
+
721
+ words = []
722
+ for token in doc:
723
+ if token.pos_ in ["NOUN", "VERB"]:
724
+ words.append(token.lemma_)
725
+ else:
726
+ words.append(token.text)
727
+ answer = " ".join(words)
728
+
729
+ return answer
730
+
731
+ return [apply(answer) for answer in answers]
732
+
733
+ @property
734
+ def lemmatizer(self):
735
+ if self._lemmatizer is None:
736
+ try:
737
+ import spacy
738
+
739
+ self._lemmatizer = spacy.load("en_core_web_sm")
740
+ except ImportError:
741
+ logging.error(
742
+ """
743
+ Please install spacy and the en_core_web_sm model to apply lemmatization.
744
+ python -m spacy download en_core_web_sm
745
+ OR
746
+ import spacy.cli
747
+ spacy.cli.download("en_core_web_sm")
748
+ """
749
+ )
750
+ exit(1)
751
+
752
+ return self._lemmatizer
753
+
754
+ @classmethod
755
+ def from_config(cls, cfg):
756
+ vit_model = cfg.get("vit_model", "eva_clip_g")
757
+ img_size = cfg.get("image_size")
758
+ num_query_token = cfg.get("num_query_token")
759
+ t5_model = cfg.get("t5_model")
760
+
761
+ drop_path_rate = cfg.get("drop_path_rate", 0)
762
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
763
+ vit_precision = cfg.get("vit_precision", "fp16")
764
+ freeze_vit = cfg.get("freeze_vit", True)
765
+
766
+ prompt = cfg.get("prompt", "")
767
+ max_txt_len = cfg.get("max_txt_len", 128)
768
+ max_output_txt_len = cfg.get("max_output_txt_len", 256)
769
+
770
+ apply_lemmatizer = cfg.get("apply_lemmatizer", False)
771
+
772
+ num_few_shot_examples = cfg.get("num_few_shot_examples", 0)
773
+ few_shot_prob = cfg.get("few_shot_prob", 0.0)
774
+
775
+ qformer_text_input = cfg.get("qformer_text_input", True)
776
+
777
+ model = cls(
778
+ vit_model=vit_model,
779
+ img_size=img_size,
780
+ drop_path_rate=drop_path_rate,
781
+ use_grad_checkpoint=use_grad_checkpoint,
782
+ vit_precision=vit_precision,
783
+ freeze_vit=freeze_vit,
784
+ num_query_token=num_query_token,
785
+ t5_model=t5_model,
786
+ prompt=prompt,
787
+ max_txt_len=max_txt_len,
788
+ max_output_txt_len=max_output_txt_len,
789
+ apply_lemmatizer=apply_lemmatizer,
790
+ num_few_shot_examples=num_few_shot_examples,
791
+ few_shot_prob=few_shot_prob,
792
+ qformer_text_input=qformer_text_input,
793
+ )
794
+
795
+ # if qformer_text_input:
796
+ # # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal)
797
+ # model.load_from_pretrained(
798
+ # url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth"
799
+ # )
800
+
801
+ model.load_checkpoint_from_config(cfg)
802
+
803
+ return model
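Note on the ranking logic in `_predict_class` above: candidates are scored in `n_segments` chunks so that only `bs * seg_len` sequences are in memory at once, the last chunk absorbs the division remainder, and an argsort over the concatenated per-candidate LM losses yields the final ranking. A minimal, self-contained sketch (the random losses stand in for the real per-segment LM losses):

import torch

bs, n_cands, n_segments = 2, 7, 3  # illustrative sizes
all_losses = []
for n in range(n_segments):
    seg_len = n_cands // n_segments
    if n == (n_segments - 1):
        # the last segment absorbs the remainder (here: 2, 2, 3 candidates)
        seg_len = n_cands - seg_len * (n_segments - 1)
    # placeholder for the per-segment LM loss of shape (bs, seg_len)
    all_losses.append(torch.rand(bs, seg_len))
all_losses = torch.cat(all_losses, dim=-1)               # (bs, n_cands)
output_class_ranks = torch.argsort(all_losses, dim=-1)   # lowest loss ranks first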
bliva/models/bliva_vicuna7b.py ADDED
@@ -0,0 +1,783 @@
1
+ import logging
2
+ import string
3
+ from packaging import version
4
+
5
+ import torch
6
+ from torch.cuda.amp import autocast as autocast
7
+ import torch.nn as nn
8
+
9
+ import transformers
10
+
11
+ from bliva.common.registry import registry
12
+ from bliva.models.blip2 import Blip2Base, disabled_train
13
+
14
+ @registry.register_model("bliva_vicuna")
15
+ class BLIVAVicuna(Blip2Base):
16
+
17
+ PRETRAINED_MODEL_CONFIG_DICT = {
18
+ "vicuna7b": "configs/models/bliva_vicuna7b.yaml",
19
+ }
20
+
21
+ def __init__(
22
+ self,
23
+ vit_model="eva_clip_g",
24
+ img_size=224,
25
+ drop_path_rate=0,
26
+ use_grad_checkpoint=False,
27
+ vit_precision="fp16",
28
+ freeze_vit=True,
29
+ num_query_token=32,
30
+ llm_model="",
31
+ prompt="",
32
+ max_txt_len=128,
33
+ max_output_txt_len=256,
34
+ apply_lemmatizer=False,
35
+ qformer_text_input=True,
36
+ ):
37
+ super().__init__()
38
+ transformers_version = version.parse(transformers.__version__)
39
+ assert transformers_version >= version.parse("4.28"), "BLIVA Vicuna requires transformers>=4.28"
40
+ from transformers import LlamaTokenizer
41
+ from bliva.models.modeling_llama import LlamaForCausalLM
42
+
43
+ self.tokenizer = self.init_tokenizer(truncation_side="left")
44
+
45
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
46
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
47
+ )
48
+
49
+ if freeze_vit:
50
+ for name, param in self.visual_encoder.named_parameters():
51
+ param.requires_grad = False
52
+ self.visual_encoder = self.visual_encoder.eval()
53
+ self.visual_encoder.train = disabled_train
54
+ logging.info("freeze vision encoder")
55
+
56
+ self.Qformer, self.query_tokens = self.init_Qformer(
57
+ num_query_token, self.visual_encoder.num_features
58
+ )
59
+
60
+ if not qformer_text_input:
61
+ self.Qformer.bert.embeddings.word_embeddings = None
62
+ self.Qformer.bert.embeddings.position_embeddings = None
63
+ for layer in self.Qformer.bert.encoder.layer:
64
+ layer.output = None
65
+ layer.intermediate = None
66
+ else:
67
+ self.Qformer.resize_token_embeddings(len(self.tokenizer))
68
+ self.Qformer.cls = None
69
+
70
+ self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_model, use_fast=False, truncation_side="left")
71
+ self.llm_model = LlamaForCausalLM.from_pretrained(
72
+ llm_model, low_cpu_mem_usage=True, torch_dtype=torch.float16
73
+ ).to('cuda:0')  # load_in_8bit=True
74
+ self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
75
+ self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
76
+ self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
77
+ self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
78
+ # self.llm_tokenizer.pad_token = self.llm_tokenizer.unk_token
79
+
80
+ self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
81
+
82
+ # self.eos_token_id = self.llm_tokenizer(
83
+ # self.llm_tokenizer.eos_token, add_special_tokens=False
84
+ # ).input_ids[0]
85
+
86
+ for name, param in self.llm_model.named_parameters():
87
+ param.requires_grad = False
88
+
89
+ self.llm_proj = nn.Linear(
90
+ self.Qformer.config.hidden_size, self.llm_model.config.hidden_size
91
+ )
92
+
93
+ self.max_txt_len = max_txt_len
94
+ self.max_output_txt_len = max_output_txt_len
95
+ self.prompt = prompt
96
+ prompt_tokens = self.llm_tokenizer(self.prompt, return_tensors="pt")
97
+ self.prompt_length = prompt_tokens.attention_mask.sum(1)
98
+
99
+ self._lemmatizer = None
100
+
101
+ self.qformer_text_input = qformer_text_input
102
+
103
+ self.vision_project = nn.Linear(self.visual_encoder.num_features, self.llm_model.config.hidden_size)
104
+
105
+ def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts):
106
+ input_part_targets_len = []
107
+ llm_tokens = {"input_ids": [], "attention_mask": []}
108
+ for i in range(input_ids.size(0)):
109
+ this_input_ones = input_atts[i].sum()
110
+ input_part_targets_len.append(this_input_ones)
111
+ llm_tokens['input_ids'].append(
112
+ torch.cat([
113
+ input_ids[i][:this_input_ones],
114
+ output_ids[i][1:],
115
+ input_ids[i][this_input_ones:]
116
+ ])
117
+ )
118
+ llm_tokens['attention_mask'].append(
119
+ torch.cat([
120
+ input_atts[i][:this_input_ones],
121
+ output_atts[i][1:],
122
+ input_atts[i][this_input_ones:]
123
+ ])
124
+ )
125
+ llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
126
+ llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
127
+ return llm_tokens, input_part_targets_len
128
+
129
+ def forward(self, samples):
130
+ # print('-----------------')
131
+ # print(samples["text_input"])
132
+ # print(samples["text_output"])
133
+ # print(samples["image"].shape)
134
+ # print('-----------------')
135
+
136
+ image = samples["image"]
137
+
138
+ image_features = self.visual_encoder.get_intermediate_layers(image)[-2] # [batch_size, 257, 1408]
139
+ image_features = image_features[:, 1:]
140
+ add_feature_llm = self.vision_project(image_features)
141
+ atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
142
+
143
+ with self.maybe_autocast():
144
+ image_embeds = self.ln_vision(self.visual_encoder(image))
145
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
146
+
147
+ bs = image.size(0)
148
+
149
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
150
+ if self.qformer_text_input:
151
+ text_Qformer = self.tokenizer(
152
+ samples["text_input"],
153
+ padding='longest',
154
+ truncation=True,
155
+ max_length=self.max_txt_len,
156
+ return_tensors="pt",
157
+ ).to(image.device)
158
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
159
+ Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
160
+
161
+ query_output = self.Qformer.bert(
162
+ text_Qformer.input_ids,
163
+ attention_mask=Qformer_atts,
164
+ query_embeds=query_tokens,
165
+ encoder_hidden_states=image_embeds,
166
+ encoder_attention_mask=image_atts,
167
+ return_dict=True,
168
+ )
169
+ else:
170
+ query_output = self.Qformer.bert(
171
+ query_embeds=query_tokens,
172
+ encoder_hidden_states=image_embeds,
173
+ encoder_attention_mask=image_atts,
174
+ return_dict=True,
175
+ )
176
+
177
+ inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
178
+ atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
179
+
180
+ self.llm_tokenizer.padding_side = "right"
181
+ self.llm_tokenizer.truncation_side = 'left'
182
+ text_input_tokens = self.llm_tokenizer(
183
+ samples['text_input'],
184
+ return_tensors="pt",
185
+ padding="longest",
186
+ truncation=True,
187
+ max_length=self.max_txt_len,
188
+ ).to(image.device)
189
+
190
+ self.llm_tokenizer.truncation_side = 'right'
191
+ text_output_tokens = self.llm_tokenizer(
192
+ [t + self.llm_tokenizer.eos_token for t in samples['text_output']],
193
+ return_tensors="pt",
194
+ padding="longest",
195
+ truncation=True,
196
+ max_length=self.max_output_txt_len,
197
+ ).to(image.device)
198
+
199
+ llm_tokens, input_part_targets_len = self.concat_text_input_output(
200
+ text_input_tokens.input_ids,
201
+ text_input_tokens.attention_mask,
202
+ text_output_tokens.input_ids,
203
+ text_output_tokens.attention_mask,
204
+ )
205
+
206
+ # do not apply loss to the padding
207
+ targets = llm_tokens['input_ids'].masked_fill(
208
+ llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100
209
+ )
210
+
211
+ # do not apply loss to the text input (i.e., instruction)
212
+ for i, l in enumerate(input_part_targets_len):
213
+ targets[i][:l] = -100
214
+
215
+ # do not apply loss to the query tokens
216
+ empty_targets = (
217
+ torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
218
+ )
219
+ # do not apply loss to the additional image features
220
+ empty_add_targets = (
221
+ torch.ones(atts_add_feature_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
222
+ )
223
+ #targets = torch.cat([empty_targets, targets], dim=1)
224
+ targets = torch.cat([empty_targets, empty_add_targets, targets], dim=1)
225
+
226
+ inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids'])
227
+ #inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
228
+ #attention_mask = torch.cat([atts_llm, llm_tokens['attention_mask']], dim=1)
229
+ inputs_embeds = torch.cat([inputs_llm, add_feature_llm, inputs_embeds], dim=1)
230
+ attention_mask = torch.cat([atts_llm, atts_add_feature_llm, llm_tokens['attention_mask']], dim=1)
231
+
232
+ with self.maybe_autocast():
233
+ outputs = self.llm_model(
234
+ inputs_embeds=inputs_embeds,
235
+ attention_mask=attention_mask,
236
+ return_dict=True,
237
+ labels=targets,
238
+ )
239
+
240
+ loss = outputs.loss
241
+
242
+ return {"loss": loss}
243
+
244
+ @torch.no_grad()
245
+ def generate(
246
+ self,
247
+ samples,
248
+ use_nucleus_sampling=False,
249
+ num_beams=5,
250
+ max_length=256,
251
+ min_length=1,
252
+ top_p=0.9,
253
+ repetition_penalty=1.5,
254
+ length_penalty=1,
255
+ num_captions=1,
256
+ temperature=1,
257
+ ):
258
+ self.llm_tokenizer.padding_side = "left"
259
+
260
+ if "prompt" in samples.keys():
261
+ prompt = samples["prompt"]
262
+ else:
263
+ prompt = samples["text_input"]
264
+
265
+ image = samples["image"]
266
+
267
+ bs = image.size(0)
268
+
269
+ # if isinstance(prompt, str):
270
+ # prompt = [prompt] * bs
271
+ # else:
272
+ # assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
273
+
274
+ # For TextCaps
275
+ if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
276
+ prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]
277
+
278
+ if 'context' in samples.keys() and samples['context'] != '':
279
+ prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
280
+ print('using context')
281
+ query_tokens = self.query_tokens.expand(bs, -1, -1)
282
+ if self.qformer_text_input:
283
+ # remove ocr tokens in q_former (for eval textvqa)
284
+ # qformer_prompt = prompt
285
+ # qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt]
286
+
287
+ text_Qformer = self.tokenizer(
288
+ prompt,
289
+ padding='longest',
290
+ truncation=True,
291
+ max_length=self.max_txt_len,
292
+ return_tensors="pt",
293
+ ).to(image.device)
294
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
295
+ Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
296
+
297
+ # For video data
298
+ if image.dim() == 5:
299
+ inputs_llm, atts_llm = [], []
300
+ add_inputs_llm, add_atts_llm = [], []
301
+ for j in range(image.size(2)):
302
+ this_frame = image[:,:,j,:,:]
303
+ with self.maybe_autocast():
304
+ frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
305
+ frame_features = self.visual_encoder.get_intermediate_layers(this_frame)[-2]
306
+
307
+ frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
308
+ frame_features = frame_features[:, 1:]
309
+
310
+ add_feature_llm = self.vision_project(frame_features)
311
+ atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
312
+
313
+ if self.qformer_text_input:
314
+ frame_query_output = self.Qformer.bert(
315
+ text_Qformer.input_ids,
316
+ attention_mask=Qformer_atts,
317
+ query_embeds=query_tokens,
318
+ encoder_hidden_states=frame_embeds,
319
+ encoder_attention_mask=frame_atts,
320
+ return_dict=True,
321
+ )
322
+ else:
323
+ frame_query_output = self.Qformer.bert(
324
+ query_embeds=query_tokens,
325
+ encoder_hidden_states=frame_embeds,
326
+ encoder_attention_mask=frame_atts,
327
+ return_dict=True,
328
+ )
329
+ frame_inputs_llm = self.llm_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
330
+ frame_atts_llm = torch.ones(frame_inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
331
+ inputs_llm.append(frame_inputs_llm)
332
+ atts_llm.append(frame_atts_llm)
333
+ add_inputs_llm.append(add_feature_llm)
334
+ add_atts_llm.append(atts_add_feature_llm)
335
+ inputs_llm = torch.cat(inputs_llm, dim=1)
336
+ atts_llm = torch.cat(atts_llm, dim=1)
337
+ add_feature_llm = torch.cat(add_inputs_llm, dim=1)
338
+ atts_add_feature_llm = torch.cat(add_atts_llm, dim=1)
339
+ else:
340
+ with self.maybe_autocast():
341
+ image_embeds = self.ln_vision(self.visual_encoder(image))
342
+
343
+ image_features = self.visual_encoder.get_intermediate_layers(image)[-2] # [batch_size, 257, 1408]
344
+
345
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
346
+
347
+ image_features = image_features[:, 1:]
348
+ add_feature_llm = self.vision_project(image_features)
349
+ atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
350
+
351
+ if self.qformer_text_input:
352
+ query_output = self.Qformer.bert(
353
+ text_Qformer.input_ids,
354
+ attention_mask=Qformer_atts,
355
+ query_embeds=query_tokens,
356
+ encoder_hidden_states=image_embeds,
357
+ encoder_attention_mask=image_atts,
358
+ return_dict=True,
359
+ )
360
+ else:
361
+ query_output = self.Qformer.bert(
362
+ query_embeds=query_tokens,
363
+ encoder_hidden_states=image_embeds,
364
+ encoder_attention_mask=image_atts,
365
+ return_dict=True,
366
+ )
367
+
368
+ inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
369
+ atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
370
+
371
+ llm_tokens = self.llm_tokenizer(
372
+ prompt,
373
+ padding="longest",
374
+ return_tensors="pt"
375
+ ).to(image.device)
376
+
377
+ with self.maybe_autocast():
378
+ inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
379
+ # inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
380
+ # attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1)
381
+ inputs_embeds = torch.cat([inputs_llm, add_feature_llm, inputs_embeds], dim=1)
382
+ attention_mask = torch.cat([atts_llm, atts_add_feature_llm, llm_tokens['attention_mask']], dim=1)
383
+
384
+ outputs = self.llm_model.generate(
385
+ inputs_embeds=inputs_embeds,
386
+ attention_mask=attention_mask,
387
+ do_sample=use_nucleus_sampling,
388
+ top_p=top_p,
389
+ temperature=temperature,
390
+ num_beams=num_beams,
391
+ max_length=max_length,
392
+ min_length=min_length,
393
+ # eos_token_id=self.eos_token_id,
394
+ repetition_penalty=repetition_penalty,
395
+ length_penalty=length_penalty,
396
+ num_return_sequences=num_captions,
397
+ )
398
+
399
+ outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
400
+ output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
401
+ output_text = [text.strip() for text in output_text]
402
+
403
+ return output_text
404
+
405
+ def predict_answers(
406
+ self,
407
+ samples,
408
+ num_beams=5,
409
+ inference_method="generate",
410
+ max_len=10,
411
+ min_len=1,
412
+ num_ans_candidates=128,
413
+ answer_list=None,
414
+ prompt="",
415
+ length_penalty=0,
416
+ **kwargs
417
+ ):
418
+ if isinstance(samples["text_input"], str):
419
+ samples["text_input"] = [samples["text_input"]]
420
+
421
+ if prompt:
422
+ if prompt.count("{}") == 2:
423
+ if 'ocr_tokens' in samples:
424
+ text_input = [
425
+ prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["text_input"][i])
426
+ for i in range(len(samples["text_input"]))]
427
+ elif 'choices' in samples:
428
+ text_input = []
429
+ for i in range(len(samples["text_input"])):
430
+ this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])]
431
+ this_choices = " ".join(this_choices)
432
+ text_input.append(prompt.format(samples["text_input"][i], this_choices))
433
+ else:
434
+ text_input = [prompt.format(question) for question in samples["text_input"]]
435
+ else:
436
+ text_input = samples["text_input"]
437
+
438
+ samples["prompt"] = text_input
439
+
440
+ output_text = self.generate(
441
+ samples,
442
+ num_beams=num_beams,
443
+ max_length=max_len,
444
+ min_length=min_len,
445
+ length_penalty=length_penalty
446
+ )
447
+
448
+ if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
449
+ output_text = self._lemmatize(output_text)
450
+
451
+ return output_text
452
+
453
+ def predict_class(
454
+ self,
455
+ samples,
456
+ candidates,
457
+ n_segments=1,
458
+ ):
459
+ self.llm_tokenizer.padding_side = "left"
460
+
461
+ # If candidates is a list of lists, each sample has its candidates, then we need to iterate one by one
462
+ if isinstance(candidates[0], list):
463
+ results = []
464
+
465
+ for i in range(samples["image"].size(0)):
466
+ this_sample = {
467
+ "image": samples["image"][i].unsqueeze(0),
468
+ "prompt": samples["prompt"][i],
469
+ }
470
+
471
+ if "text_input" in samples.keys():
472
+ this_sample["text_input"] = [samples["text_input"][i]]
473
+
474
+ if 'context' in samples.keys():
475
+ this_sample['context'] = [samples["context"][i]]
476
+
477
+ if 'history' in samples.keys():
478
+ this_sample['history'] = [samples["history"][i]]
479
+
480
+ if 'caption' in samples.keys():
481
+ this_sample['caption'] = [samples["caption"][i]]
482
+
483
+ this_result = self._predict_class(this_sample, candidates[i], n_segments)
484
+ results.append(this_result)
485
+
486
+ try:
487
+ results = torch.cat(results, dim=0)
488
+ except Exception:  # per-sample candidate counts differ, so cat fails; return ragged lists
489
+ results = [res.tolist()[0] for res in results]
490
+
491
+ return results
492
+
493
+ return self._predict_class(samples, candidates, n_segments)
494
+
495
+ def _predict_class(
496
+ self,
497
+ samples,
498
+ candidates,
499
+ n_segments=1,
500
+ ):
501
+ image = samples["image"]
502
+ prompt = samples["prompt"]
503
+
504
+ bs = image.size(0)
505
+
506
+ if isinstance(prompt, str):
507
+ prompt = [prompt] * bs
508
+ else:
509
+ assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
510
+
511
+ if "text_input" in samples.keys():
512
+ if type(samples["text_input"][0]) == list:
513
+ prompt = [prompt[i].format(*samples["text_input"][i]) for i in range(len(prompt))]
514
+ else:
515
+ prompt = [prompt[i].format(samples["text_input"][i]) for i in range(len(prompt))]
516
+
517
+ # scienceqa
518
+ if 'context' in samples.keys() and samples['context'] != '':
519
+ prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
520
+
521
+ # visual dialog
522
+ if 'history' in samples.keys() and samples['history'][0] != '':
523
+ prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(len(prompt))]
524
+
525
+ if 'caption' in samples.keys() and samples['caption'][0] != '':
526
+ prompt = [f'This image has the caption "{samples["caption"][i]}". {prompt[i]}' for i in range(len(prompt))]
527
+
528
+ query_tokens = self.query_tokens.expand(bs, -1, -1)
529
+ if self.qformer_text_input:
530
+ text_Qformer = self.tokenizer(
531
+ prompt,
532
+ padding='longest',
533
+ truncation=True,
534
+ max_length=self.max_txt_len,
535
+ return_tensors="pt"
536
+ ).to(image.device)
537
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
538
+ Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
539
+
540
+ # For video data
541
+ if image.dim() == 5:
542
+ inputs_llm, atts_llm = [], []
543
+ add_inputs_llm, add_atts_llm = [], []
544
+ for j in range(image.size(2)):
545
+ this_frame = image[:,:,j,:,:]
546
+ with self.maybe_autocast():
547
+
548
+ frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
549
+ frame_features = self.visual_encoder.get_intermediate_layers(this_frame)[-2]
550
+
551
+ frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
552
+ frame_features = frame_features[:, 1:]
553
+
554
+ add_feature_llm = self.vision_project(frame_features)
555
+ atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
556
+
557
+ if self.qformer_text_input:
558
+ frame_query_output = self.Qformer.bert(
559
+ text_Qformer.input_ids,
560
+ attention_mask=Qformer_atts,
561
+ query_embeds=query_tokens,
562
+ encoder_hidden_states=frame_embeds,
563
+ encoder_attention_mask=frame_atts,
564
+ return_dict=True,
565
+ )
566
+ else:
567
+ frame_query_output = self.Qformer.bert(
568
+ query_embeds=query_tokens,
569
+ encoder_hidden_states=frame_embeds,
570
+ encoder_attention_mask=frame_atts,
571
+ return_dict=True,
572
+ )
573
+ frame_inputs_llm = self.llm_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
574
+ frame_atts_llm = torch.ones(frame_inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
575
+ inputs_llm.append(frame_inputs_llm)
576
+ atts_llm.append(frame_atts_llm)
577
+ add_inputs_llm.append(add_feature_llm)
578
+ add_atts_llm.append(atts_add_feature_llm)
579
+ inputs_llm = torch.cat(inputs_llm, dim=1)
580
+ atts_llm = torch.cat(atts_llm, dim=1)
581
+ add_feature_llm = torch.cat(add_inputs_llm, dim=1)
582
+ atts_add_feature_llm = torch.cat(add_atts_llm, dim=1)
583
+ else:
584
+ with self.maybe_autocast():
585
+ image_embeds = self.ln_vision(self.visual_encoder(image))
586
+
587
+ image_features = self.visual_encoder.get_intermediate_layers(image)[-2] # [batch_size, 257, 1408]
588
+
589
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
590
+
591
+ image_features = image_features[:, 1:]
592
+ add_feature_llm = self.vision_project(image_features)
593
+ atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
594
+
595
+ if self.qformer_text_input:
596
+ query_output = self.Qformer.bert(
597
+ text_Qformer.input_ids,
598
+ attention_mask=Qformer_atts,
599
+ query_embeds=query_tokens,
600
+ encoder_hidden_states=image_embeds,
601
+ encoder_attention_mask=image_atts,
602
+ return_dict=True,
603
+ )
604
+ else:
605
+ query_output = self.Qformer.bert(
606
+ query_embeds=query_tokens,
607
+ encoder_hidden_states=image_embeds,
608
+ encoder_attention_mask=image_atts,
609
+ return_dict=True,
610
+ )
611
+
612
+ inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
613
+ atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
614
+
615
+ self.llm_tokenizer.padding_side = "right"
616
+ self.llm_tokenizer.truncation_side = 'left'
617
+ text_input_tokens = self.llm_tokenizer(
618
+ prompt,
619
+ return_tensors="pt",
620
+ padding="longest",
621
+ # truncation=True,
622
+ # max_length=self.max_txt_len,
623
+ ).to(image.device)
624
+
625
+ empty_targets = torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
626
+ empty_add_targets = (
627
+ torch.ones(atts_add_feature_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
628
+ )
629
+
630
+ # self.llm_tokenizer.padding_side = "right"
631
+ self.llm_tokenizer.truncation_side = 'right'
632
+ n_cands = len(candidates)
633
+ with self.maybe_autocast(dtype=torch.bfloat16):
634
+ all_losses = []
635
+ for n in range(n_segments):
636
+ seg_len = n_cands // n_segments
637
+ if n == (n_segments - 1):
638
+ seg_len = n_cands - seg_len * (n_segments - 1)
639
+
640
+ start_i = n * (n_cands // n_segments)
641
+ end_i = start_i + seg_len
642
+
643
+ this_output_tokens = self.llm_tokenizer(
644
+ candidates[start_i:end_i],
645
+ return_tensors="pt",
646
+ padding="longest",
647
+ # truncation=True,
648
+ # max_length=self.max_output_txt_len,
649
+ ).to(image.device)
650
+
651
+ this_input_tokens_ids = text_input_tokens.input_ids.repeat_interleave(seg_len, dim=0)
652
+ this_input_tokens_atts = text_input_tokens.attention_mask.repeat_interleave(seg_len, dim=0)
653
+
654
+ this_output_tokens_ids = this_output_tokens.input_ids.repeat(bs, 1)
655
+ this_output_tokens_atts = this_output_tokens.attention_mask.repeat(bs, 1)
656
+
657
+ this_llm_tokens, this_input_targets_len = self.concat_text_input_output(
658
+ this_input_tokens_ids,
659
+ this_input_tokens_atts,
660
+ this_output_tokens_ids,
661
+ this_output_tokens_atts
662
+ )
663
+
664
+ this_llm_input_ids = this_llm_tokens['input_ids']
665
+ this_llm_atts = this_llm_tokens['attention_mask']
666
+ # this_llm_input_ids = torch.cat([this_input_tokens_ids, this_output_tokens_ids], dim=1)
667
+ # this_llm_atts = torch.cat([this_input_tokens_atts, this_output_tokens_atts], dim=1)
668
+
669
+ inputs_embeds = self.llm_model.get_input_embeddings()(this_llm_input_ids)
670
+ inputs_embeds = torch.cat([inputs_llm.repeat_interleave(seg_len, dim=0),
671
+ add_feature_llm.repeat_interleave(seg_len, dim=0), inputs_embeds], dim=1)
672
+ attention_mask = torch.cat([atts_llm.repeat_interleave(seg_len, dim=0),
673
+ atts_add_feature_llm.repeat_interleave(seg_len, dim=0), this_llm_atts], dim=1)
674
+
675
+ this_targets = this_llm_input_ids.masked_fill(this_llm_input_ids == self.llm_tokenizer.pad_token_id, -100)
676
+ # this_targets[:, :this_input_tokens_ids.size(1)] = -100
677
+ for i, l in enumerate(this_input_targets_len):
678
+ this_targets[i][:l] = -100
679
+
680
+ this_targets = torch.cat([empty_targets.repeat_interleave(seg_len, dim=0),
681
+ empty_add_targets.repeat_interleave(seg_len, dim=0), this_targets], dim=1)
682
+
683
+ outputs = self.llm_model(
684
+ inputs_embeds=inputs_embeds,
685
+ attention_mask=attention_mask,
686
+ return_dict=True,
687
+ labels=this_targets,
688
+ reduction="none",
689
+ )
690
+
691
+ loss = outputs.loss
692
+
693
+ loss = loss.reshape(bs, seg_len)
694
+ # output_class_ranks = torch.argsort(loss, dim=-1)
695
+ all_losses.append(loss)
696
+
697
+ all_losses = torch.cat(all_losses, dim=-1)
698
+ output_class_ranks = torch.argsort(all_losses, dim=-1)
699
+
700
+ return output_class_ranks
701
+
702
+ def _lemmatize(self, answers):
703
+ def apply(answer):
704
+ doc = self.lemmatizer(answer)
705
+
706
+ words = []
707
+ for token in doc:
708
+ if token.pos_ in ["NOUN", "VERB"]:
709
+ words.append(token.lemma_)
710
+ else:
711
+ words.append(token.text)
712
+ answer = " ".join(words)
713
+
714
+ return answer
715
+
716
+ return [apply(answer) for answer in answers]
717
+
718
+ @property
719
+ def lemmatizer(self):
720
+ if self._lemmatizer is None:
721
+ try:
722
+ import spacy
723
+
724
+ self._lemmatizer = spacy.load("en_core_web_sm")
725
+ except ImportError:
726
+ logging.error(
727
+ """
728
+ Please install spacy and the en_core_web_sm model to apply lemmatization.
729
+ python -m spacy download en_core_web_sm
730
+ OR
731
+ import spacy.cli
732
+ spacy.cli.download("en_core_web_sm")
733
+ """
734
+ )
735
+ exit(1)
736
+
737
+ return self._lemmatizer
738
+
739
+ @classmethod
740
+ def from_config(cls, cfg):
741
+ vit_model = cfg.get("vit_model", "eva_clip_g")
742
+ img_size = cfg.get("image_size")
743
+ num_query_token = cfg.get("num_query_token")
744
+ llm_model = cfg.get("llm_model")
745
+
746
+ drop_path_rate = cfg.get("drop_path_rate", 0)
747
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
748
+ vit_precision = cfg.get("vit_precision", "fp16")
749
+ freeze_vit = cfg.get("freeze_vit", True)
750
+
751
+ prompt = cfg.get("prompt", "")
752
+ max_txt_len = cfg.get("max_txt_len", 128)
753
+ max_output_txt_len = cfg.get("max_output_txt_len", 256)
754
+
755
+ apply_lemmatizer = cfg.get("apply_lemmatizer", False)
756
+
757
+ qformer_text_input = cfg.get("qformer_text_input", True)
758
+
759
+ model = cls(
760
+ vit_model=vit_model,
761
+ img_size=img_size,
762
+ drop_path_rate=drop_path_rate,
763
+ use_grad_checkpoint=use_grad_checkpoint,
764
+ vit_precision=vit_precision,
765
+ freeze_vit=freeze_vit,
766
+ num_query_token=num_query_token,
767
+ llm_model=llm_model,
768
+ prompt=prompt,
769
+ max_txt_len=max_txt_len,
770
+ max_output_txt_len=max_output_txt_len,
771
+ apply_lemmatizer=apply_lemmatizer,
772
+ qformer_text_input=qformer_text_input,
773
+ )
774
+
775
+ # if qformer_text_input:
776
+ # # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal)
777
+ # model.load_from_pretrained(
778
+ # url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth"
779
+ # )
780
+
781
+ model.load_checkpoint_from_config(cfg)
782
+
783
+ return model
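A hedged usage sketch for the Vicuna variant above. The direct constructor call, the `hf_vicuna_7b` weights directory, and the random tensor standing in for a preprocessed image are illustrative; in practice `from_config` with bliva/configs/models/bliva_vicuna7b.yaml loads the trained projections and checkpoint:

import torch
from bliva.models.bliva_vicuna7b import BLIVAVicuna

model = BLIVAVicuna(llm_model="hf_vicuna_7b").to("cuda:0").eval()
samples = {
    "image": torch.randn(1, 3, 224, 224, device="cuda:0"),  # stand-in for a preprocessed image batch
    "prompt": ["What is shown in this image?"],
}
with torch.no_grad():
    answers = model.generate(samples, num_beams=5, max_length=64)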
bliva/models/clip_vit.py ADDED
@@ -0,0 +1,272 @@
1
+ from collections import OrderedDict
2
+ from itertools import repeat
3
+ import collections.abc
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
11
+
12
+ from bliva.models.eva_vit import convert_weights_to_fp16
13
+ from bliva.common.dist_utils import download_cached_file
14
+
15
+ class Bottleneck(nn.Module):
16
+ expansion = 4
17
+
18
+ def __init__(self, inplanes, planes, stride=1):
19
+ super().__init__()
20
+
21
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
22
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
23
+ self.bn1 = nn.BatchNorm2d(planes)
24
+ self.relu1 = nn.ReLU(inplace=True)
25
+
26
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
27
+ self.bn2 = nn.BatchNorm2d(planes)
28
+ self.relu2 = nn.ReLU(inplace=True)
29
+
30
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
31
+
32
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
33
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
34
+ self.relu3 = nn.ReLU(inplace=True)
35
+
36
+ self.downsample = None
37
+ self.stride = stride
38
+
39
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
40
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
41
+ self.downsample = nn.Sequential(OrderedDict([
42
+ ("-1", nn.AvgPool2d(stride)),
43
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
44
+ ("1", nn.BatchNorm2d(planes * self.expansion))
45
+ ]))
46
+
47
+ def forward(self, x: torch.Tensor):
48
+ identity = x
49
+
50
+ out = self.relu1(self.bn1(self.conv1(x)))
51
+ out = self.relu2(self.bn2(self.conv2(out)))
52
+ out = self.avgpool(out)
53
+ out = self.bn3(self.conv3(out))
54
+
55
+ if self.downsample is not None:
56
+ identity = self.downsample(x)
57
+
58
+ out += identity
59
+ out = self.relu3(out)
60
+ return out
61
+
62
+
63
+ class AttentionPool2d(nn.Module):
64
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
65
+ super().__init__()
66
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
67
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
68
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
69
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
70
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
71
+ self.num_heads = num_heads
72
+
73
+ def forward(self, x):
74
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
75
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
76
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
77
+ x, _ = F.multi_head_attention_forward(
78
+ query=x, key=x, value=x,
79
+ embed_dim_to_check=x.shape[-1],
80
+ num_heads=self.num_heads,
81
+ q_proj_weight=self.q_proj.weight,
82
+ k_proj_weight=self.k_proj.weight,
83
+ v_proj_weight=self.v_proj.weight,
84
+ in_proj_weight=None,
85
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
86
+ bias_k=None,
87
+ bias_v=None,
88
+ add_zero_attn=False,
89
+ dropout_p=0,
90
+ out_proj_weight=self.c_proj.weight,
91
+ out_proj_bias=self.c_proj.bias,
92
+ use_separate_proj_weight=True,
93
+ training=self.training,
94
+ need_weights=False
95
+ )
96
+
97
+ return x[0]
98
+
99
+
100
+ class LayerNorm(nn.LayerNorm):
101
+ """Subclass torch's LayerNorm to handle fp16."""
102
+
103
+ def forward(self, x: torch.Tensor):
104
+ orig_type = x.dtype
105
+ ret = super().forward(x.type(torch.float32))
106
+ return ret.type(orig_type)
107
+
108
+
109
+ class QuickGELU(nn.Module):
110
+ def forward(self, x: torch.Tensor):
111
+ return x * torch.sigmoid(1.702 * x)
112
+
113
+
114
+ class ResidualAttentionBlock(nn.Module):
115
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
116
+ super().__init__()
117
+
118
+ self.attn = nn.MultiheadAttention(d_model, n_head)
119
+ self.ln_1 = LayerNorm(d_model)
120
+ self.mlp = nn.Sequential(OrderedDict([
121
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
122
+ ("gelu", QuickGELU()),
123
+ ("c_proj", nn.Linear(d_model * 4, d_model))
124
+ ]))
125
+ self.ln_2 = LayerNorm(d_model)
126
+ self.attn_mask = attn_mask
127
+
128
+ if use_grad_checkpointing:
129
+ self.attn = checkpoint_wrapper(self.attn)
130
+ self.mlp = checkpoint_wrapper(self.mlp)
131
+
132
+ def attention(self, x: torch.Tensor):
133
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
134
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
135
+
136
+ def forward(self, x: torch.Tensor):
137
+ x = x + self.attention(self.ln_1(x))
138
+ x = x + self.mlp(self.ln_2(x))
139
+ return x
140
+
141
+
142
+ class Transformer(nn.Module):
143
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
144
+ super().__init__()
145
+ self.width = width
146
+ self.layers = layers
147
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i > 12) for i in range(layers)])
148
+
149
+ def forward(self, x: torch.Tensor):
150
+ return self.resblocks(x)
151
+
152
+ def get_second_last_feature(self, x: torch.Tensor):
153
+ for i in range(len(self.resblocks) - 2):
154
+ x = self.resblocks[i](x)
155
+ return x
156
+
157
+
158
+ class VisionTransformer(nn.Module):
159
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool):
160
+ super().__init__()
161
+ self.input_resolution = input_resolution
162
+ self.num_features = width
163
+ self.num_heads = heads
164
+ self.num_patches = (input_resolution // patch_size) ** 2
165
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
166
+
167
+ scale = width ** -0.5
168
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
169
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
170
+ self.ln_pre = LayerNorm(width)
171
+
172
+ self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
173
+
174
+ # self.ln_final = LayerNorm(width)
175
+
176
+ def forward(self, x: torch.Tensor):
177
+
178
+ x = self.conv1(x) # shape = [*, width, grid, grid]
179
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
180
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
181
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
182
+ x = x + self.positional_embedding.to(x.dtype)
183
+ x = self.ln_pre(x)
184
+
185
+ x = x.permute(1, 0, 2) # NLD -> LND
186
+ x = self.transformer(x)
187
+ x = x.permute(1, 0, 2) # LND -> NLD
188
+
189
+ # x = self.ln_final(x)
190
+ return x
191
+
192
+ def get_last_second_feature(self, x):
193
+
194
+ x = self.conv1(x) # shape = [*, width, grid, grid]
195
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
196
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
197
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
198
+ x = x + self.positional_embedding.to(x.dtype)
199
+ x = self.ln_pre(x)
200
+
201
+ x = x.permute(1, 0, 2) # NLD -> LND
202
+ x = self.transformer.get_second_last_feature(x)
203
+ x = x.permute(1, 0, 2) # LND -> NLD
204
+
205
+ return x
206
+
207
+ # From PyTorch internals
208
+ def _ntuple(n):
209
+ def parse(x):
210
+ if isinstance(x, collections.abc.Iterable):
211
+ return x
212
+ return tuple(repeat(x, n))
213
+ return parse
214
+ to_2tuple = _ntuple(2)
215
+ def interpolate_pos_embed(model, state_dict, interpolation: str = 'bicubic', seq_dim=1):
216
+ # Rescale the grid of position embeddings when loading from state_dict
217
+ old_pos_embed = state_dict.get('positional_embedding', None)
218
+
219
+ grid_size = round((model.positional_embedding.shape[0] - 1) ** 0.5)
220
+ if old_pos_embed is None:
221
+ return
222
+ grid_size = to_2tuple(grid_size)
223
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
224
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
225
+ if new_seq_len == old_pos_embed.shape[0]:
226
+ return
227
+
228
+ if extra_tokens:
229
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
230
+ else:
231
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
232
+
233
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
234
+
235
+ print('Resizing position embedding grid-size from %s to %s' % (old_grid_size, grid_size))
236
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
237
+ pos_emb_img = F.interpolate(
238
+ pos_emb_img,
239
+ size=grid_size,
240
+ mode=interpolation,
241
+ align_corners=True,
242
+ )
243
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
244
+ if pos_emb_tok is not None:
245
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
246
+ else:
247
+ new_pos_embed = pos_emb_img
248
+ state_dict['positional_embedding'] = new_pos_embed
249
+
250
+
251
+ def create_clip_vit_L(img_size=224, use_checkpoint=False, precision="fp16"):
252
+ model = VisionTransformer(
253
+ input_resolution=img_size,
254
+ patch_size=14,
255
+ width=1024,
256
+ layers=23,
257
+ heads=16,
258
+ use_grad_checkpointing=use_checkpoint,
259
+ )
260
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/clip_vit_L.pth"
261
+ cached_file = download_cached_file(
262
+ url, check_hash=False, progress=True
263
+ )
264
+ state_dict = torch.load(cached_file, map_location="cpu")
265
+ interpolate_pos_embed(model, state_dict)
266
+
267
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
268
+ # print(incompatible_keys)
269
+
270
+ if precision == "fp16":
271
+ convert_weights_to_fp16(model)
272
+ return model
bliva/models/eva_vit.py ADDED
@@ -0,0 +1,442 @@
1
+ # Based on EVA, BEIT, timm and DeiT code bases
2
+ # https://github.com/baaivision/EVA
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
+ # https://github.com/microsoft/unilm/tree/master/beit
5
+ # https://github.com/facebookresearch/deit/
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------
8
+ import math
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+ from timm.models.registry import register_model
17
+
18
+ from bliva.common.dist_utils import download_cached_file
19
+
20
+ def _cfg(url='', **kwargs):
21
+ return {
22
+ 'url': url,
23
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
24
+ 'crop_pct': .9, 'interpolation': 'bicubic',
25
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
26
+ **kwargs
27
+ }
28
+
29
+
30
+ class DropPath(nn.Module):
31
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
32
+ """
33
+ def __init__(self, drop_prob=None):
34
+ super(DropPath, self).__init__()
35
+ self.drop_prob = drop_prob
36
+
37
+ def forward(self, x):
38
+ return drop_path(x, self.drop_prob, self.training)
39
+
40
+ def extra_repr(self) -> str:
41
+ return 'p={}'.format(self.drop_prob)
42
+
43
+
44
+ class Mlp(nn.Module):
45
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
46
+ super().__init__()
47
+ out_features = out_features or in_features
48
+ hidden_features = hidden_features or in_features
49
+ self.fc1 = nn.Linear(in_features, hidden_features)
50
+ self.act = act_layer()
51
+ self.fc2 = nn.Linear(hidden_features, out_features)
52
+ self.drop = nn.Dropout(drop)
53
+
54
+ def forward(self, x):
55
+ x = self.fc1(x)
56
+ x = self.act(x)
57
+ # x = self.drop(x)
58
+ # dropout here is omitted, following the original BERT implementation
59
+ x = self.fc2(x)
60
+ x = self.drop(x)
61
+ return x
62
+
63
+
64
+ class Attention(nn.Module):
65
+ def __init__(
66
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
67
+ proj_drop=0., window_size=None, attn_head_dim=None):
68
+ super().__init__()
69
+ self.num_heads = num_heads
70
+ head_dim = dim // num_heads
71
+ if attn_head_dim is not None:
72
+ head_dim = attn_head_dim
73
+ all_head_dim = head_dim * self.num_heads
74
+ self.scale = qk_scale or head_dim ** -0.5
75
+
76
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
77
+ if qkv_bias:
78
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
79
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
80
+ else:
81
+ self.q_bias = None
82
+ self.v_bias = None
83
+
84
+ if window_size:
85
+ self.window_size = window_size
86
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
87
+ self.relative_position_bias_table = nn.Parameter(
88
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
89
+ # cls to token & token to cls & cls to cls
90
+
91
+ # get pair-wise relative position index for each token inside the window
92
+ coords_h = torch.arange(window_size[0])
93
+ coords_w = torch.arange(window_size[1])
94
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
95
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
96
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
97
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
98
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
99
+ relative_coords[:, :, 1] += window_size[1] - 1
100
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
101
+ relative_position_index = \
102
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
103
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
104
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
105
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
106
+ relative_position_index[0, 0] = self.num_relative_distance - 1
107
+
108
+ self.register_buffer("relative_position_index", relative_position_index)
109
+ else:
110
+ self.window_size = None
111
+ self.relative_position_bias_table = None
112
+ self.relative_position_index = None
113
+
114
+ self.attn_drop = nn.Dropout(attn_drop)
115
+ self.proj = nn.Linear(all_head_dim, dim)
116
+ self.proj_drop = nn.Dropout(proj_drop)
117
+
118
+ def forward(self, x, rel_pos_bias=None):
119
+ B, N, C = x.shape
120
+ qkv_bias = None
121
+ if self.q_bias is not None:
122
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
123
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
125
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
126
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
127
+
128
+ q = q * self.scale
129
+ attn = (q @ k.transpose(-2, -1))
130
+
131
+ if self.relative_position_bias_table is not None:
132
+ relative_position_bias = \
133
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
134
+ self.window_size[0] * self.window_size[1] + 1,
135
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
136
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
137
+ attn = attn + relative_position_bias.unsqueeze(0)
138
+
139
+ if rel_pos_bias is not None:
140
+ attn = attn + rel_pos_bias
141
+
142
+ attn = attn.softmax(dim=-1)
143
+ attn = self.attn_drop(attn)
144
+
145
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
146
+ x = self.proj(x)
147
+ x = self.proj_drop(x)
148
+ return x
149
+
150
+
151
+ class Block(nn.Module):
152
+
153
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
154
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
155
+ window_size=None, attn_head_dim=None):
156
+ super().__init__()
157
+ self.norm1 = norm_layer(dim)
158
+ self.attn = Attention(
159
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
160
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
161
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
162
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163
+ self.norm2 = norm_layer(dim)
164
+ mlp_hidden_dim = int(dim * mlp_ratio)
165
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
166
+
167
+ if init_values is not None and init_values > 0:
168
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
169
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
170
+ else:
171
+ self.gamma_1, self.gamma_2 = None, None
172
+
173
+ def forward(self, x, rel_pos_bias=None):
174
+ if self.gamma_1 is None:
175
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
176
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
177
+ else:
178
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
179
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
180
+ return x
181
+
182
+
183
+ class PatchEmbed(nn.Module):
184
+ """ Image to Patch Embedding
185
+ """
186
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
187
+ super().__init__()
188
+ img_size = to_2tuple(img_size)
189
+ patch_size = to_2tuple(patch_size)
190
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
191
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
192
+ self.img_size = img_size
193
+ self.patch_size = patch_size
194
+ self.num_patches = num_patches
195
+
196
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
197
+
198
+ def forward(self, x, **kwargs):
199
+ B, C, H, W = x.shape
200
+ # FIXME look at relaxing size constraints
201
+ assert H == self.img_size[0] and W == self.img_size[1], \
202
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
203
+ x = self.proj(x).flatten(2).transpose(1, 2)
204
+ return x
205
+
206
+
207
+ class RelativePositionBias(nn.Module):
208
+
209
+ def __init__(self, window_size, num_heads):
210
+ super().__init__()
211
+ self.window_size = window_size
212
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
213
+ self.relative_position_bias_table = nn.Parameter(
214
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
215
+ # cls to token & token to cls & cls to cls
216
+
217
+ # get pair-wise relative position index for each token inside the window
218
+ coords_h = torch.arange(window_size[0])
219
+ coords_w = torch.arange(window_size[1])
220
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
221
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
222
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
223
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
224
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
225
+ relative_coords[:, :, 1] += window_size[1] - 1
226
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
227
+ relative_position_index = \
228
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
229
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
230
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
231
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
232
+ relative_position_index[0, 0] = self.num_relative_distance - 1
233
+
234
+ self.register_buffer("relative_position_index", relative_position_index)
235
+
236
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
237
+
238
+ def forward(self):
239
+ relative_position_bias = \
240
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
241
+ self.window_size[0] * self.window_size[1] + 1,
242
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww+1, Wh*Ww+1, nH
243
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww+1, Wh*Ww+1
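For context, this shared bias produces one (Wh*Ww+1) x (Wh*Ww+1) map per head, with three extra learned entries for cls-to-token, token-to-cls, and cls-to-cls. A minimal shape-check sketch (not part of the diff; the 16x16 grid and 16 heads are placeholder values matching a 224px image at patch size 14):

```python
import torch

bias = RelativePositionBias(window_size=(16, 16), num_heads=16)
table = bias()  # one bias map per head, cls token included
assert table.shape == (16, 16 * 16 + 1, 16 * 16 + 1)  # (nH, 257, 257)
```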
244
+
245
+
246
+ class VisionTransformer(nn.Module):
247
+ """ Vision Transformer with support for patch or hybrid CNN input stage
248
+ """
249
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
250
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
251
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
252
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
253
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
254
+ super().__init__()
255
+ self.image_size = img_size
256
+ self.num_classes = num_classes
257
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
258
+
259
+ self.patch_embed = PatchEmbed(
260
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
261
+ num_patches = self.patch_embed.num_patches
262
+
263
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
264
+ if use_abs_pos_emb:
265
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
266
+ else:
267
+ self.pos_embed = None
268
+ self.pos_drop = nn.Dropout(p=drop_rate)
269
+
270
+ if use_shared_rel_pos_bias:
271
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
272
+ else:
273
+ self.rel_pos_bias = None
274
+ self.use_checkpoint = use_checkpoint
275
+
276
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
277
+ self.use_rel_pos_bias = use_rel_pos_bias
278
+ self.blocks = nn.ModuleList([
279
+ Block(
280
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
281
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
282
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
283
+ for i in range(depth)])
284
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
285
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
286
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
287
+
288
+ if self.pos_embed is not None:
289
+ trunc_normal_(self.pos_embed, std=.02)
290
+ trunc_normal_(self.cls_token, std=.02)
291
+ # trunc_normal_(self.mask_token, std=.02)
292
+ # if isinstance(self.head, nn.Linear):
293
+ # trunc_normal_(self.head.weight, std=.02)
294
+ self.apply(self._init_weights)
295
+ self.fix_init_weight()
296
+ # if isinstance(self.head, nn.Linear):
297
+ # self.head.weight.data.mul_(init_scale)
298
+ # self.head.bias.data.mul_(init_scale)
299
+
300
+ def fix_init_weight(self):
301
+ def rescale(param, layer_id):
302
+ param.div_(math.sqrt(2.0 * layer_id))
303
+
304
+ for layer_id, layer in enumerate(self.blocks):
305
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
306
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
307
+
308
+ def _init_weights(self, m):
309
+ if isinstance(m, nn.Linear):
310
+ trunc_normal_(m.weight, std=.02)
311
+ if isinstance(m, nn.Linear) and m.bias is not None:
312
+ nn.init.constant_(m.bias, 0)
313
+ elif isinstance(m, nn.LayerNorm):
314
+ nn.init.constant_(m.bias, 0)
315
+ nn.init.constant_(m.weight, 1.0)
316
+
317
+ def get_classifier(self):
318
+ # NOTE: self.head is only created by reset_classifier(); the default head above is commented out.
+ return self.head
319
+
320
+ def reset_classifier(self, num_classes, global_pool=''):
321
+ self.num_classes = num_classes
322
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
323
+
324
+ def forward_features(self, x):
325
+ x = self.patch_embed(x)
326
+ batch_size, seq_len, _ = x.size()
327
+
328
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
329
+ x = torch.cat((cls_tokens, x), dim=1)
330
+ if self.pos_embed is not None:
331
+ x = x + self.pos_embed
332
+ x = self.pos_drop(x)
333
+
334
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
335
+ for blk in self.blocks:
336
+ if self.use_checkpoint:
337
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
338
+ else:
339
+ x = blk(x, rel_pos_bias)
340
+ return x
341
+ # x = self.norm(x)
342
+
343
+ # if self.fc_norm is not None:
344
+ # t = x[:, 1:, :]
345
+ # return self.fc_norm(t.mean(1))
346
+ # else:
347
+ # return x[:, 0]
348
+
349
+ def forward(self, x):
350
+ x = self.forward_features(x)
351
+ # x = self.head(x)
352
+ return x
353
+
354
+ def get_intermediate_layers(self, x):
355
+ x = self.patch_embed(x)
356
+ batch_size, seq_len, _ = x.size()
357
+
358
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
359
+ x = torch.cat((cls_tokens, x), dim=1)
360
+ if self.pos_embed is not None:
361
+ x = x + self.pos_embed
362
+ x = self.pos_drop(x)
363
+
364
+ features = []
365
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
366
+ for blk in self.blocks:
367
+ x = blk(x, rel_pos_bias)
368
+ features.append(x)
369
+
370
+ return features
371
+
372
+
373
+ def interpolate_pos_embed(model, checkpoint_model):
374
+ if 'pos_embed' in checkpoint_model:
375
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
376
+ embedding_size = pos_embed_checkpoint.shape[-1]
377
+ num_patches = model.patch_embed.num_patches
378
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
379
+ # height (== width) for the checkpoint position embedding
380
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
381
+ # height (== width) for the new position embedding
382
+ new_size = int(num_patches ** 0.5)
383
+ # class_token and dist_token are kept unchanged
384
+ if orig_size != new_size:
385
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
386
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
387
+ # only the position tokens are interpolated
388
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
389
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
390
+ pos_tokens = torch.nn.functional.interpolate(
391
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
392
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
393
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
394
+ checkpoint_model['pos_embed'] = new_pos_embed
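A small usage sketch for interpolate_pos_embed (not part of the diff): a hypothetical 336px model loading a 224px checkpoint, both at patch size 14. The tiny embed_dim and depth are placeholders so the sketch stays cheap; it assumes the timm helpers imported at the top of this file.

```python
import torch

model_336 = VisionTransformer(img_size=336, patch_size=14, embed_dim=32, depth=1, num_heads=4)
ckpt = {"pos_embed": torch.randn(1, 1 + 16 * 16, 32)}  # 224 // 14 = 16, so a 16x16 grid
interpolate_pos_embed(model_336, ckpt)  # bicubic 16x16 -> 24x24; cls token kept unchanged
print(ckpt["pos_embed"].shape)  # torch.Size([1, 577, 32]) = 1 cls + 24 * 24 patches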
395
+
396
+
397
+ def convert_weights_to_fp16(model: nn.Module):
398
+ """Convert applicable model parameters to fp16"""
399
+
400
+ def _convert_weights_to_fp16(l):
401
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
402
+ l.weight.data = l.weight.data.half()
403
+ if l.bias is not None:
404
+ l.bias.data = l.bias.data.half()
405
+
406
+ # if isinstance(l, (nn.MultiheadAttention, Attention)):
407
+ # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
408
+ # tensor = getattr(l, attr)
409
+ # if tensor is not None:
410
+ # tensor.data = tensor.data.half()
411
+
412
+ model.apply(_convert_weights_to_fp16)
413
+
414
+
415
+ def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
416
+ model = VisionTransformer(
417
+ img_size=img_size,
418
+ patch_size=14,
419
+ use_mean_pooling=False,
420
+ embed_dim=1408,
421
+ depth=39,
422
+ num_heads=1408//88,
423
+ mlp_ratio=4.3637,
424
+ qkv_bias=True,
425
+ drop_path_rate=drop_path_rate,
426
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
427
+ use_checkpoint=use_checkpoint,
428
+ )
429
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
430
+ cached_file = download_cached_file(
431
+ url, check_hash=False, progress=True
432
+ )
433
+ state_dict = torch.load(cached_file, map_location="cpu")
434
+ interpolate_pos_embed(model, state_dict)
435
+
436
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
437
+ # print(incompatible_keys)
438
+
439
+ if precision == "fp16":
440
+ # model.to("cuda")
441
+ convert_weights_to_fp16(model)
442
+ return model
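A minimal usage sketch for create_eva_vit_g as defined above (not part of the diff; assumes network access to fetch the EVA checkpoint). fp32 keeps the sketch runnable on CPU, whereas the fp16 default is intended for GPU use:

```python
import torch

visual_encoder = create_eva_vit_g(img_size=224, drop_path_rate=0.0, precision="fp32").eval()

with torch.no_grad():
    images = torch.randn(1, 3, 224, 224)
    feats = visual_encoder(images)  # (1, 257, 1408): 1 cls + (224 // 14) ** 2 patch tokens
```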
bliva/models/modeling_llama.py ADDED
@@ -0,0 +1,888 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33
+ from transformers.models.llama.configuration_llama import LlamaConfig
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ _CONFIG_FOR_DOC = "LlamaConfig"
39
+
40
+
41
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
42
+ def _make_causal_mask(
43
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
44
+ ):
45
+ """
46
+ Make causal mask used for bi-directional self-attention.
47
+ """
48
+ bsz, tgt_len = input_ids_shape
49
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
50
+ mask_cond = torch.arange(mask.size(-1), device=device)
51
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
52
+ mask = mask.to(dtype)
53
+
54
+ if past_key_values_length > 0:
55
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
56
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
57
+
58
+
59
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
60
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
61
+ """
62
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
63
+ """
64
+ bsz, src_len = mask.size()
65
+ tgt_len = tgt_len if tgt_len is not None else src_len
66
+
67
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68
+
69
+ inverted_mask = 1.0 - expanded_mask
70
+
71
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
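A sketch of what the two mask helpers above produce for one 4-token sequence whose last position is padding (not part of the diff): entries are 0 where attention is allowed and dtype-min where it is blocked.

```python
import torch

causal = _make_causal_mask((1, 4), torch.float32, torch.device("cpu"))
print(causal.shape)  # torch.Size([1, 1, 4, 4]); strictly upper triangle is finfo.min

pad = _expand_mask(torch.tensor([[1, 1, 1, 0]]), torch.float32)
print(pad.shape)  # torch.Size([1, 1, 4, 4]); the last key column is finfo.min
```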
72
+
73
+
74
+ class LlamaRMSNorm(nn.Module):
75
+ def __init__(self, hidden_size, eps=1e-6):
76
+ """
77
+ LlamaRMSNorm is equivalent to T5LayerNorm
78
+ """
79
+ super().__init__()
80
+ self.weight = nn.Parameter(torch.ones(hidden_size))
81
+ self.variance_epsilon = eps
82
+
83
+ def forward(self, hidden_states):
84
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
85
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
86
+
87
+ # convert into half-precision if necessary
88
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
89
+ hidden_states = hidden_states.to(self.weight.dtype)
90
+
91
+ return self.weight * hidden_states
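A quick numeric check of the normalization above (not part of the diff): unlike LayerNorm, RMSNorm rescales by the root mean square of the features and subtracts no mean. A sketch assuming float32 inputs:

```python
import torch

norm = LlamaRMSNorm(hidden_size=8)
x = torch.randn(2, 3, 8)
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.variance_epsilon)
assert torch.allclose(norm(x), expected, atol=1e-6)  # weight is initialized to ones
```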
92
+
93
+
94
+ class LlamaRotaryEmbedding(torch.nn.Module):
95
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
96
+ super().__init__()
97
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
98
+ self.register_buffer("inv_freq", inv_freq)
99
+
100
+ # Build here to make `torch.jit.trace` work.
101
+ self.max_seq_len_cached = max_position_embeddings
102
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
103
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
104
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
105
+ emb = torch.cat((freqs, freqs), dim=-1)
106
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
107
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
108
+
109
+ def forward(self, x, seq_len=None):
110
+ # x: [bs, num_attention_heads, seq_len, head_size]
111
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
112
+ if seq_len > self.max_seq_len_cached:
113
+ self.max_seq_len_cached = seq_len
114
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
115
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
116
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
117
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
118
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
119
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
120
+ return (
121
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
122
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
123
+ )
124
+
125
+
126
+ def rotate_half(x):
127
+ """Rotates half the hidden dims of the input."""
128
+ x1 = x[..., : x.shape[-1] // 2]
129
+ x2 = x[..., x.shape[-1] // 2 :]
130
+ return torch.cat((-x2, x1), dim=-1)
131
+
132
+
133
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
134
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
135
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
136
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
137
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
138
+ q_embed = (q * cos) + (rotate_half(q) * sin)
139
+ k_embed = (k * cos) + (rotate_half(k) * sin)
140
+ return q_embed, k_embed
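A sanity-check sketch for the rotary helpers above (not part of the diff): each coordinate is rotated against its partner in the other half of the head dimension, so per-position norms are preserved.

```python
import torch

rope = LlamaRotaryEmbedding(dim=8)
q = torch.randn(1, 2, 4, 8)  # [bs, num_heads, seq_len, head_dim]
k = torch.randn(1, 2, 4, 8)
cos, sin = rope(q, seq_len=4)
position_ids = torch.arange(4).unsqueeze(0)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
```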
141
+
142
+
143
+ class LlamaMLP(nn.Module):
144
+ def __init__(
145
+ self,
146
+ hidden_size: int,
147
+ intermediate_size: int,
148
+ hidden_act: str,
149
+ ):
150
+ super().__init__()
151
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
152
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
153
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
154
+ self.act_fn = ACT2FN[hidden_act]
155
+
156
+ def forward(self, x):
157
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
158
+
159
+
160
+ class LlamaAttention(nn.Module):
161
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
162
+
163
+ def __init__(self, config: LlamaConfig):
164
+ super().__init__()
165
+ self.config = config
166
+ self.hidden_size = config.hidden_size
167
+ self.num_heads = config.num_attention_heads
168
+ self.head_dim = self.hidden_size // self.num_heads
169
+ self.max_position_embeddings = config.max_position_embeddings
170
+
171
+ if (self.head_dim * self.num_heads) != self.hidden_size:
172
+ raise ValueError(
173
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
174
+ f" and `num_heads`: {self.num_heads})."
175
+ )
176
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
177
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
178
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
179
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
180
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
181
+
182
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
183
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
184
+
185
+ def forward(
186
+ self,
187
+ hidden_states: torch.Tensor,
188
+ attention_mask: Optional[torch.Tensor] = None,
189
+ position_ids: Optional[torch.LongTensor] = None,
190
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
191
+ output_attentions: bool = False,
192
+ use_cache: bool = False,
193
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
194
+ bsz, q_len, _ = hidden_states.size()
195
+
196
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
197
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
198
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
199
+
200
+ kv_seq_len = key_states.shape[-2]
201
+ if past_key_value is not None:
202
+ kv_seq_len += past_key_value[0].shape[-2]
203
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
204
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
205
+ # [bsz, nh, t, hd]
206
+
207
+ if past_key_value is not None:
208
+ # reuse k, v, self_attention
209
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
210
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
211
+
212
+ past_key_value = (key_states, value_states) if use_cache else None
213
+
214
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
215
+
216
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
217
+ raise ValueError(
218
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
219
+ f" {attn_weights.size()}"
220
+ )
221
+
222
+ if attention_mask is not None:
223
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
224
+ raise ValueError(
225
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
226
+ )
227
+ attn_weights = attn_weights + attention_mask
228
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
229
+
230
+ # upcast attention to fp32
231
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
232
+ attn_output = torch.matmul(attn_weights, value_states)
233
+
234
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
235
+ raise ValueError(
236
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
237
+ f" {attn_output.size()}"
238
+ )
239
+
240
+ attn_output = attn_output.transpose(1, 2)
241
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
242
+
243
+ attn_output = self.o_proj(attn_output)
244
+
245
+ if not output_attentions:
246
+ attn_weights = None
247
+
248
+ return attn_output, attn_weights, past_key_value
249
+
250
+
251
+ class LlamaDecoderLayer(nn.Module):
252
+ def __init__(self, config: LlamaConfig):
253
+ super().__init__()
254
+ self.hidden_size = config.hidden_size
255
+ self.self_attn = LlamaAttention(config=config)
256
+ self.mlp = LlamaMLP(
257
+ hidden_size=self.hidden_size,
258
+ intermediate_size=config.intermediate_size,
259
+ hidden_act=config.hidden_act,
260
+ )
261
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
262
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
263
+
264
+ def forward(
265
+ self,
266
+ hidden_states: torch.Tensor,
267
+ attention_mask: Optional[torch.Tensor] = None,
268
+ position_ids: Optional[torch.LongTensor] = None,
269
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
270
+ output_attentions: Optional[bool] = False,
271
+ use_cache: Optional[bool] = False,
272
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
273
+ """
274
+ Args:
275
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
276
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
277
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
278
+ output_attentions (`bool`, *optional*):
279
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
280
+ returned tensors for more detail.
281
+ use_cache (`bool`, *optional*):
282
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
283
+ (see `past_key_values`).
284
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
285
+ """
286
+
287
+ residual = hidden_states
288
+
289
+ hidden_states = self.input_layernorm(hidden_states)
290
+
291
+ # Self Attention
292
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
293
+ hidden_states=hidden_states,
294
+ attention_mask=attention_mask,
295
+ position_ids=position_ids,
296
+ past_key_value=past_key_value,
297
+ output_attentions=output_attentions,
298
+ use_cache=use_cache,
299
+ )
300
+ hidden_states = residual + hidden_states
301
+
302
+ # Fully Connected
303
+ residual = hidden_states
304
+ hidden_states = self.post_attention_layernorm(hidden_states)
305
+ hidden_states = self.mlp(hidden_states)
306
+ hidden_states = residual + hidden_states
307
+
308
+ outputs = (hidden_states,)
309
+
310
+ if output_attentions:
311
+ outputs += (self_attn_weights,)
312
+
313
+ if use_cache:
314
+ outputs += (present_key_value,)
315
+
316
+ return outputs
317
+
318
+
319
+ LLAMA_START_DOCSTRING = r"""
320
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
321
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
322
+ etc.)
323
+
324
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
325
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
326
+ and behavior.
327
+
328
+ Parameters:
329
+ config ([`LlamaConfig`]):
330
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
331
+ load the weights associated with the model, only the configuration. Check out the
332
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
333
+ """
334
+
335
+
336
+ @add_start_docstrings(
337
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
338
+ LLAMA_START_DOCSTRING,
339
+ )
340
+ class LlamaPreTrainedModel(PreTrainedModel):
341
+ config_class = LlamaConfig
342
+ base_model_prefix = "model"
343
+ supports_gradient_checkpointing = True
344
+ _no_split_modules = ["LlamaDecoderLayer"]
345
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
346
+
347
+ def _init_weights(self, module):
348
+ std = self.config.initializer_range
349
+ if isinstance(module, nn.Linear):
350
+ module.weight.data.normal_(mean=0.0, std=std)
351
+ if module.bias is not None:
352
+ module.bias.data.zero_()
353
+ elif isinstance(module, nn.Embedding):
354
+ module.weight.data.normal_(mean=0.0, std=std)
355
+ if module.padding_idx is not None:
356
+ module.weight.data[module.padding_idx].zero_()
357
+
358
+ def _set_gradient_checkpointing(self, module, value=False):
359
+ if isinstance(module, LlamaModel):
360
+ module.gradient_checkpointing = value
361
+
362
+
363
+ LLAMA_INPUTS_DOCSTRING = r"""
364
+ Args:
365
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
366
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
367
+ it.
368
+
369
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
370
+ [`PreTrainedTokenizer.__call__`] for details.
371
+
372
+ [What are input IDs?](../glossary#input-ids)
373
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
374
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
375
+
376
+ - 1 for tokens that are **not masked**,
377
+ - 0 for tokens that are **masked**.
378
+
379
+ [What are attention masks?](../glossary#attention-mask)
380
+
381
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
382
+ [`PreTrainedTokenizer.__call__`] for details.
383
+
384
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
385
+ `past_key_values`).
386
+
387
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
388
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
389
+ information on the default strategy.
390
+
391
+ - 1 indicates the head is **not masked**,
392
+ - 0 indicates the head is **masked**.
393
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
394
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
395
+ config.n_positions - 1]`.
396
+
397
+ [What are position IDs?](../glossary#position-ids)
398
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
399
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
400
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
401
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
402
+
403
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
404
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
405
+
406
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
407
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
408
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
409
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
410
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
411
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
412
+ model's internal embedding lookup matrix.
413
+ use_cache (`bool`, *optional*):
414
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
415
+ `past_key_values`).
416
+ output_attentions (`bool`, *optional*):
417
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
418
+ tensors for more detail.
419
+ output_hidden_states (`bool`, *optional*):
420
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
421
+ more detail.
422
+ return_dict (`bool`, *optional*):
423
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
424
+ """
425
+
426
+
427
+ @add_start_docstrings(
428
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
429
+ LLAMA_START_DOCSTRING,
430
+ )
431
+ class LlamaModel(LlamaPreTrainedModel):
432
+ """
433
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
434
+
435
+ Args:
436
+ config: LlamaConfig
437
+ """
438
+
439
+ def __init__(self, config: LlamaConfig):
440
+ super().__init__(config)
441
+ self.padding_idx = config.pad_token_id
442
+ self.vocab_size = config.vocab_size
443
+
444
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
445
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
446
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
447
+
448
+ self.gradient_checkpointing = False
449
+ # Initialize weights and apply final processing
450
+ self.post_init()
451
+
452
+ def get_input_embeddings(self):
453
+ return self.embed_tokens
454
+
455
+ def set_input_embeddings(self, value):
456
+ self.embed_tokens = value
457
+
458
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
459
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
460
+ # create causal mask
461
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
462
+ combined_attention_mask = None
463
+ if input_shape[-1] > 1:
464
+ combined_attention_mask = _make_causal_mask(
465
+ input_shape,
466
+ inputs_embeds.dtype,
467
+ device=inputs_embeds.device,
468
+ past_key_values_length=past_key_values_length,
469
+ )
470
+
471
+ if attention_mask is not None:
472
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
473
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
474
+ inputs_embeds.device
475
+ )
476
+ combined_attention_mask = (
477
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
478
+ )
479
+
480
+ return combined_attention_mask
481
+
482
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
483
+ def forward(
484
+ self,
485
+ input_ids: torch.LongTensor = None,
486
+ attention_mask: Optional[torch.Tensor] = None,
487
+ position_ids: Optional[torch.LongTensor] = None,
488
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
489
+ inputs_embeds: Optional[torch.FloatTensor] = None,
490
+ use_cache: Optional[bool] = None,
491
+ output_attentions: Optional[bool] = None,
492
+ output_hidden_states: Optional[bool] = None,
493
+ return_dict: Optional[bool] = None,
494
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
495
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
496
+ output_hidden_states = (
497
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
498
+ )
499
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
500
+
501
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
502
+
503
+ # retrieve input_ids and inputs_embeds
504
+ if input_ids is not None and inputs_embeds is not None:
505
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
506
+ elif input_ids is not None:
507
+ batch_size, seq_length = input_ids.shape
508
+ elif inputs_embeds is not None:
509
+ batch_size, seq_length, _ = inputs_embeds.shape
510
+ else:
511
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
512
+
513
+ seq_length_with_past = seq_length
514
+ past_key_values_length = 0
515
+
516
+ if past_key_values is not None:
517
+ past_key_values_length = past_key_values[0][0].shape[2]
518
+ seq_length_with_past = seq_length_with_past + past_key_values_length
519
+
520
+ if position_ids is None:
521
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
522
+ position_ids = torch.arange(
523
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
524
+ )
525
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
526
+ else:
527
+ position_ids = position_ids.view(-1, seq_length).long()
528
+
529
+ if inputs_embeds is None:
530
+ inputs_embeds = self.embed_tokens(input_ids)
531
+ # embed positions
532
+ if attention_mask is None:
533
+ attention_mask = torch.ones(
534
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
535
+ )
536
+ attention_mask = self._prepare_decoder_attention_mask(
537
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
538
+ )
539
+
540
+ hidden_states = inputs_embeds
541
+
542
+ if self.gradient_checkpointing and self.training:
543
+ if use_cache:
544
+ logger.warning_once(
545
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
546
+ )
547
+ use_cache = False
548
+
549
+ # decoder layers
550
+ all_hidden_states = () if output_hidden_states else None
551
+ all_self_attns = () if output_attentions else None
552
+ next_decoder_cache = () if use_cache else None
553
+
554
+ for idx, decoder_layer in enumerate(self.layers):
555
+ if output_hidden_states:
556
+ all_hidden_states += (hidden_states,)
557
+
558
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
559
+
560
+ if self.gradient_checkpointing and self.training:
561
+
562
+ def create_custom_forward(module):
563
+ def custom_forward(*inputs):
564
+ # None for past_key_value
565
+ return module(*inputs, output_attentions, None)
566
+
567
+ return custom_forward
568
+
569
+ layer_outputs = torch.utils.checkpoint.checkpoint(
570
+ create_custom_forward(decoder_layer),
571
+ hidden_states,
572
+ attention_mask,
573
+ position_ids,
574
+ None,
575
+ )
576
+ else:
577
+ layer_outputs = decoder_layer(
578
+ hidden_states,
579
+ attention_mask=attention_mask,
580
+ position_ids=position_ids,
581
+ past_key_value=past_key_value,
582
+ output_attentions=output_attentions,
583
+ use_cache=use_cache,
584
+ )
585
+
586
+ hidden_states = layer_outputs[0]
587
+
588
+ if use_cache:
589
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
590
+
591
+ if output_attentions:
592
+ all_self_attns += (layer_outputs[1],)
593
+
594
+ hidden_states = self.norm(hidden_states)
595
+
596
+ # add hidden states from the last decoder layer
597
+ if output_hidden_states:
598
+ all_hidden_states += (hidden_states,)
599
+
600
+ next_cache = next_decoder_cache if use_cache else None
601
+ if not return_dict:
602
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
603
+ return BaseModelOutputWithPast(
604
+ last_hidden_state=hidden_states,
605
+ past_key_values=next_cache,
606
+ hidden_states=all_hidden_states,
607
+ attentions=all_self_attns,
608
+ )
609
+
610
+
611
+ class LlamaForCausalLM(LlamaPreTrainedModel):
612
+ def __init__(self, config):
613
+ super().__init__(config)
614
+ self.model = LlamaModel(config)
615
+
616
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
617
+
618
+ # Initialize weights and apply final processing
619
+ self.post_init()
620
+
621
+ def get_input_embeddings(self):
622
+ return self.model.embed_tokens
623
+
624
+ def set_input_embeddings(self, value):
625
+ self.model.embed_tokens = value
626
+
627
+ def get_output_embeddings(self):
628
+ return self.lm_head
629
+
630
+ def set_output_embeddings(self, new_embeddings):
631
+ self.lm_head = new_embeddings
632
+
633
+ def set_decoder(self, decoder):
634
+ self.model = decoder
635
+
636
+ def get_decoder(self):
637
+ return self.model
638
+
639
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
640
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
641
+ def forward(
642
+ self,
643
+ input_ids: torch.LongTensor = None,
644
+ attention_mask: Optional[torch.Tensor] = None,
645
+ position_ids: Optional[torch.LongTensor] = None,
646
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
647
+ inputs_embeds: Optional[torch.FloatTensor] = None,
648
+ labels: Optional[torch.LongTensor] = None,
649
+ use_cache: Optional[bool] = None,
650
+ output_attentions: Optional[bool] = None,
651
+ output_hidden_states: Optional[bool] = None,
652
+ return_dict: Optional[bool] = None,
653
+ reduction: Optional[str] = "mean",
654
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
655
+ r"""
656
+ Args:
657
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
658
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
659
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
660
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
661
+
662
+ Returns:
663
+
664
+ Example:
665
+
666
+ ```python
667
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
668
+
669
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
670
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
671
+
672
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
673
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
674
+
675
+ >>> # Generate
676
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
677
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
678
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
679
+ ```"""
680
+
681
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
682
+ output_hidden_states = (
683
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
684
+ )
685
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
686
+
687
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
688
+ outputs = self.model(
689
+ input_ids=input_ids,
690
+ attention_mask=attention_mask,
691
+ position_ids=position_ids,
692
+ past_key_values=past_key_values,
693
+ inputs_embeds=inputs_embeds,
694
+ use_cache=use_cache,
695
+ output_attentions=output_attentions,
696
+ output_hidden_states=output_hidden_states,
697
+ return_dict=return_dict,
698
+ )
699
+
700
+ hidden_states = outputs[0]
701
+ logits = self.lm_head(hidden_states)
702
+
703
+ loss = None
704
+ if labels is not None:
705
+ # Shift so that tokens < n predict n
706
+ shift_logits = logits[..., :-1, :].contiguous()
707
+ shift_labels = labels[..., 1:].contiguous()
708
+ # Flatten the tokens
709
+ loss_fct = CrossEntropyLoss(reduction=reduction)
710
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
711
+ shift_labels = shift_labels.view(-1)
712
+ # Enable model parallelism
713
+ shift_labels = shift_labels.to(shift_logits.device)
714
+ loss = loss_fct(shift_logits, shift_labels)
715
+ if reduction == "none":
716
+ # loss = loss.view(logits.size(0), -1).sum(1)
717
+ loss = loss.view(logits.size(0), -1).mean(1)
718
+
719
+ if not return_dict:
720
+ output = (logits,) + outputs[1:]
721
+ return (loss,) + output if loss is not None else output
722
+
723
+ return CausalLMOutputWithPast(
724
+ loss=loss,
725
+ logits=logits,
726
+ past_key_values=outputs.past_key_values,
727
+ hidden_states=outputs.hidden_states,
728
+ attentions=outputs.attentions,
729
+ )
730
+
731
+ def prepare_inputs_for_generation(
732
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
733
+ ):
734
+ if past_key_values:
735
+ input_ids = input_ids[:, -1:]
736
+
737
+ position_ids = kwargs.get("position_ids", None)
738
+ if attention_mask is not None and position_ids is None:
739
+ # create position_ids on the fly for batch generation
740
+ position_ids = attention_mask.long().cumsum(-1) - 1
741
+ position_ids.masked_fill_(attention_mask == 0, 1)
742
+ if past_key_values:
743
+ position_ids = position_ids[:, -1].unsqueeze(-1)
744
+
745
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
746
+ if inputs_embeds is not None and past_key_values is None:
747
+ model_inputs = {"inputs_embeds": inputs_embeds}
748
+ else:
749
+ model_inputs = {"input_ids": input_ids}
750
+
751
+ model_inputs.update(
752
+ {
753
+ "position_ids": position_ids,
754
+ "past_key_values": past_key_values,
755
+ "use_cache": kwargs.get("use_cache"),
756
+ "attention_mask": attention_mask,
757
+ }
758
+ )
759
+ return model_inputs
760
+
761
+ @staticmethod
762
+ def _reorder_cache(past_key_values, beam_idx):
763
+ reordered_past = ()
764
+ for layer_past in past_key_values:
765
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
766
+ return reordered_past
767
+
768
+
769
+ @add_start_docstrings(
770
+ """
771
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
772
+
773
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
774
+ (e.g. GPT-2) do.
775
+
776
+ Since it does classification on the last token, it needs to know the position of the last token. If a
777
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
778
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
779
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
780
+ each row of the batch).
781
+ """,
782
+ LLAMA_START_DOCSTRING,
783
+ )
784
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
785
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
786
+
787
+ def __init__(self, config):
788
+ super().__init__(config)
789
+ self.num_labels = config.num_labels
790
+ self.model = LlamaModel(config)
791
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
792
+
793
+ # Initialize weights and apply final processing
794
+ self.post_init()
795
+
796
+ def get_input_embeddings(self):
797
+ return self.model.embed_tokens
798
+
799
+ def set_input_embeddings(self, value):
800
+ self.model.embed_tokens = value
801
+
802
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
803
+ def forward(
804
+ self,
805
+ input_ids: torch.LongTensor = None,
806
+ attention_mask: Optional[torch.Tensor] = None,
807
+ position_ids: Optional[torch.LongTensor] = None,
808
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
809
+ inputs_embeds: Optional[torch.FloatTensor] = None,
810
+ labels: Optional[torch.LongTensor] = None,
811
+ use_cache: Optional[bool] = None,
812
+ output_attentions: Optional[bool] = None,
813
+ output_hidden_states: Optional[bool] = None,
814
+ return_dict: Optional[bool] = None,
815
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
816
+ r"""
817
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
818
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
819
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
820
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
821
+ """
822
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
823
+
824
+ transformer_outputs = self.model(
825
+ input_ids,
826
+ attention_mask=attention_mask,
827
+ position_ids=position_ids,
828
+ past_key_values=past_key_values,
829
+ inputs_embeds=inputs_embeds,
830
+ use_cache=use_cache,
831
+ output_attentions=output_attentions,
832
+ output_hidden_states=output_hidden_states,
833
+ return_dict=return_dict,
834
+ )
835
+ hidden_states = transformer_outputs[0]
836
+ logits = self.score(hidden_states)
837
+
838
+ if input_ids is not None:
839
+ batch_size = input_ids.shape[0]
840
+ else:
841
+ batch_size = inputs_embeds.shape[0]
842
+
843
+ if self.config.pad_token_id is None and batch_size != 1:
844
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
845
+ if self.config.pad_token_id is None:
846
+ sequence_lengths = -1
847
+ else:
848
+ if input_ids is not None:
849
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
850
+ else:
851
+ sequence_lengths = -1
852
+
853
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
854
+
855
+ loss = None
856
+ if labels is not None:
857
+ labels = labels.to(logits.device)
858
+ if self.config.problem_type is None:
859
+ if self.num_labels == 1:
860
+ self.config.problem_type = "regression"
861
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
862
+ self.config.problem_type = "single_label_classification"
863
+ else:
864
+ self.config.problem_type = "multi_label_classification"
865
+
866
+ if self.config.problem_type == "regression":
867
+ loss_fct = MSELoss()
868
+ if self.num_labels == 1:
869
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
870
+ else:
871
+ loss = loss_fct(pooled_logits, labels)
872
+ elif self.config.problem_type == "single_label_classification":
873
+ loss_fct = CrossEntropyLoss()
874
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
875
+ elif self.config.problem_type == "multi_label_classification":
876
+ loss_fct = BCEWithLogitsLoss()
877
+ loss = loss_fct(pooled_logits, labels)
878
+ if not return_dict:
879
+ output = (pooled_logits,) + transformer_outputs[1:]
880
+ return ((loss,) + output) if loss is not None else output
881
+
882
+ return SequenceClassifierOutputWithPast(
883
+ loss=loss,
884
+ logits=pooled_logits,
885
+ past_key_values=transformer_outputs.past_key_values,
886
+ hidden_states=transformer_outputs.hidden_states,
887
+ attentions=transformer_outputs.attentions,
888
+ )
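One notable addition in this copy, relative to the stock Hugging Face LlamaForCausalLM, is the reduction argument on forward: with reduction="none" it returns one token-averaged loss per sequence instead of a scalar. A tiny sketch with placeholder config values (assumes transformers provides LlamaConfig):

```python
import torch
from transformers import LlamaConfig

config = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                     num_attention_heads=4, vocab_size=128)
model = LlamaForCausalLM(config)

input_ids = torch.randint(0, 128, (2, 8))
out = model(input_ids=input_ids, labels=input_ids, reduction="none")
print(out.loss.shape)  # torch.Size([2]): one loss per sequence, averaged over tokens
```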
bliva/models/modeling_t5.py ADDED
@@ -0,0 +1,2063 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch T5 model."""
16
+
17
+
18
+ import copy
19
+ import math
20
+ import os
21
+ import warnings
22
+ from typing import Optional, Tuple, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss
27
+ from torch.utils.checkpoint import checkpoint
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutput,
32
+ BaseModelOutputWithPastAndCrossAttentions,
33
+ Seq2SeqLMOutput,
34
+ Seq2SeqModelOutput,
35
+ )
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.pytorch_utils import (
38
+ ALL_LAYERNORM_LAYERS,
39
+ find_pruneable_heads_and_indices,
40
+ prune_linear_layer,
41
+ )
42
+ from transformers.utils import (
43
+ DUMMY_INPUTS,
44
+ DUMMY_MASK,
45
+ add_start_docstrings,
46
+ add_start_docstrings_to_model_forward,
47
+ is_torch_fx_proxy,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
52
+ from transformers.models.t5.configuration_t5 import T5Config
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ _CONFIG_FOR_DOC = "T5Config"
58
+ _TOKENIZER_FOR_DOC = "T5Tokenizer"
59
+ _CHECKPOINT_FOR_DOC = "t5-small"
60
+
61
+ ####################################################
62
+ # This dict contains ids and associated url
63
+ # for the pretrained weights provided with the models
64
+ ####################################################
65
+ T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
66
+ "t5-small",
67
+ "t5-base",
68
+ "t5-large",
69
+ "t5-3b",
70
+ "t5-11b",
71
+ # See all T5 models at https://huggingface.co/models?filter=t5
72
+ ]
73
+
74
+
75
+ ####################################################
76
+ # This is a conversion method from TF 1.0 to PyTorch
77
+ # More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
78
+ ####################################################
79
+ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
80
+ """Load tf checkpoints in a pytorch model."""
81
+ try:
82
+ import re
83
+
84
+ import numpy as np
85
+ import tensorflow as tf
86
+ except ImportError:
87
+ logger.error(
88
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
89
+ "https://www.tensorflow.org/install/ for installation instructions."
90
+ )
91
+ raise
92
+ tf_path = os.path.abspath(tf_checkpoint_path)
93
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
94
+ # Load weights from TF model
95
+ init_vars = tf.train.list_variables(tf_path)
96
+ names = []
97
+ tf_weights = {}
98
+ for name, shape in init_vars:
99
+ logger.info(f"Loading TF weight {name} with shape {shape}")
100
+ array = tf.train.load_variable(tf_path, name)
101
+ names.append(name)
102
+ tf_weights[name] = array
103
+
104
+ for txt_name in names:
105
+ name = txt_name.split("/")
106
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
107
+ # which are not required for using pretrained model
108
+ if any(
109
+ n
110
+ in [
111
+ "adam_v",
112
+ "adam_m",
113
+ "AdamWeightDecayOptimizer",
114
+ "AdamWeightDecayOptimizer_1",
115
+ "global_step",
116
+ ]
117
+ for n in name
118
+ ):
119
+ logger.info(f"Skipping {'/'.join(name)}")
120
+ tf_weights.pop(txt_name, None)
121
+ continue
122
+ if "_slot_" in name[-1]:
123
+ logger.info(f"Skipping {'/'.join(name)}")
124
+ tf_weights.pop(txt_name, None)
125
+ continue
126
+ pointer = model
127
+ array = tf_weights[txt_name]
128
+
129
+ for m_name in name:
130
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
131
+ scope_names = re.split(r"_(\d+)", m_name)
132
+ else:
133
+ scope_names = [m_name]
134
+ if scope_names[0] in ["kernel", "scale", "embedding"]:
135
+ pointer = getattr(pointer, "weight")
136
+ elif scope_names[0] == "self_attention":
137
+ pointer = getattr(pointer, "layer")
138
+ pointer = pointer[0]
139
+ elif scope_names[0] == "enc_dec_attention":
140
+ pointer = getattr(pointer, "layer")
141
+ pointer = pointer[1]
142
+ elif scope_names[0] == "dense_relu_dense":
143
+ pointer = getattr(pointer, "layer")
144
+ pointer = pointer[2]
145
+ elif scope_names[0] == "rms_norm":
146
+ if hasattr(pointer, "layer_norm"):
147
+ pointer = getattr(pointer, "layer_norm")
148
+ elif hasattr(pointer, "final_layer_norm"):
149
+ pointer = getattr(pointer, "final_layer_norm")
150
+ elif scope_names[0] == "scale":
151
+ pointer = getattr(pointer, "weight")
152
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
153
+ pointer = getattr(pointer, "bias")
154
+ elif scope_names[0] == "squad":
155
+ pointer = getattr(pointer, "classifier")
156
+ elif scope_names[0] == "decoder" and name[1] == "logits":
157
+ continue
158
+ elif scope_names[0] == "logits":
159
+ pointer = getattr(pointer, "lm_head")
160
+ elif (
161
+ scope_names[0] == "wi"
162
+ and len(scope_names) > 1
163
+ and scope_names[1].isdigit()
164
+ ):
165
+ pointer = getattr(pointer, f"wi_{scope_names[1]}")
166
+ continue
167
+ else:
168
+ try:
169
+ pointer = getattr(pointer, scope_names[0])
170
+ except AttributeError:
171
+ logger.info(f"Skipping {'/'.join(name)}")
172
+ continue
173
+ if len(scope_names) >= 2:
174
+ num = int(scope_names[1])
175
+ pointer = pointer[num]
176
+ if scope_names[0] not in ["kernel", "scale", "embedding"]:
177
+ pointer = getattr(pointer, "weight")
178
+ if scope_names[0] != "embedding":
179
+ logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
180
+ array = np.transpose(array)
181
+ try:
182
+ assert (
183
+ pointer.shape == array.shape
184
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
185
+ except AssertionError as e:
186
+ e.args += (pointer.shape, array.shape)
187
+ raise
188
+ logger.info(f"Initialize PyTorch weight {name}")
189
+ pointer.data = torch.from_numpy(array.astype(np.float32))
190
+ tf_weights.pop(txt_name, None)
191
+
192
+ logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
193
+ return model
194
+
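+ # Editor's note: a minimal usage sketch for the converter above (illustrative only; it
+ # requires TensorFlow, and the checkpoint path below is hypothetical):
+ #
+ #     from transformers import T5Config
+ #     config = T5Config.from_pretrained("t5-small")
+ #     model = T5ForConditionalGeneration(config)  # defined later in this file
+ #     model = load_tf_weights_in_t5(model, config, "/path/to/tf_checkpoint")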
195
+
196
+ ####################################################
197
+ # PyTorch Models are constructed by sub-classing
198
+ # - torch.nn.Module for the layers and
199
+ # - PreTrainedModel for the models (itself a sub-class of nn.Module)
200
+ ####################################################
201
+ PARALLELIZE_DOCSTRING = r"""
202
+ This is an experimental feature and is subject to change at a moment's notice.
203
+
204
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
205
+ it will evenly distribute blocks across all devices.
206
+
207
+ Args:
208
+ device_map (`Dict[int, list]`, optional, defaults to None):
209
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
210
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
211
+ have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
212
+ following number of attention modules:
213
+
214
+ - t5-small: 6
215
+ - t5-base: 12
216
+ - t5-large: 24
217
+ - t5-3b: 24
218
+ - t5-11b: 24
219
+
220
+ Example:
221
+
222
+ ```python
223
+ # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
224
+ model = T5ForConditionalGeneration.from_pretrained("t5-3b")
225
+ device_map = {
226
+ 0: [0, 1, 2],
227
+ 1: [3, 4, 5, 6, 7, 8, 9],
228
+ 2: [10, 11, 12, 13, 14, 15, 16],
229
+ 3: [17, 18, 19, 20, 21, 22, 23],
230
+ }
231
+ model.parallelize(device_map)
232
+ ```
233
+ """
234
+ DEPARALLELIZE_DOCSTRING = r"""
235
+ Moves the model to cpu from a model parallel state.
236
+
237
+ Example:
238
+
239
+ ```python
240
+ # On a 4 GPU machine with t5-3b:
241
+ model = T5ForConditionalGeneration.from_pretrained("t5-3b")
242
+ device_map = {
243
+ 0: [0, 1, 2],
244
+ 1: [3, 4, 5, 6, 7, 8, 9],
245
+ 2: [10, 11, 12, 13, 14, 15, 16],
246
+ 3: [17, 18, 19, 20, 21, 22, 23],
247
+ }
248
+ model.parallelize(device_map) # Splits the model across several devices
249
+ model.deparallelize() # Puts the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
250
+ ```
251
+ """
252
+
253
+
254
+ class T5LayerNorm(nn.Module):
255
+ def __init__(self, hidden_size, eps=1e-6):
256
+ """
257
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
258
+ """
259
+ super().__init__()
260
+ self.weight = nn.Parameter(torch.ones(hidden_size))
261
+ self.variance_epsilon = eps
262
+
263
+ def forward(self, hidden_states):
264
+
265
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
266
+ # Square Layer Normalization (https://arxiv.org/abs/1910.07467); thus the variance is calculated
267
+ # without the mean and there is no bias. Additionally, we want to make sure that the accumulation
268
+ # for half-precision inputs is done in fp32.
269
+
270
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
271
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
272
+
273
+ # convert into half-precision if necessary
274
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
275
+ hidden_states = hidden_states.to(self.weight.dtype)
276
+
277
+ return self.weight * hidden_states
278
+
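+ # Editor's note: T5LayerNorm is RMSNorm, i.e. y = weight * x / sqrt(mean(x**2) + eps),
+ # with no mean subtraction and no bias term. A tiny self-check (illustrative only,
+ # assuming the pure-PyTorch class above; apex may replace it just below):
+ #
+ #     x = torch.tensor([[1.0, 2.0, 3.0]])
+ #     ln = T5LayerNorm(hidden_size=3)               # weight is initialized to ones
+ #     expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
+ #     assert torch.allclose(ln(x), expected)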
279
+
280
+ try:
281
+ from apex.normalization import FusedRMSNorm
282
+
283
+ T5LayerNorm = FusedRMSNorm # noqa
284
+
285
+ logger.info(
286
+ "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm"
287
+ )
288
+ except ImportError:
289
+ # using the normal T5LayerNorm
290
+ pass
291
+ except Exception:
292
+ logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
293
+ pass
294
+
295
+ ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
296
+
297
+
298
+ class T5DenseActDense(nn.Module):
299
+ def __init__(self, config: T5Config):
300
+ super().__init__()
301
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
302
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
303
+ self.dropout = nn.Dropout(config.dropout_rate)
304
+ self.act = ACT2FN[config.dense_act_fn]
305
+
306
+ def forward(self, hidden_states):
307
+ hidden_states = self.wi(hidden_states)
308
+ hidden_states = self.act(hidden_states)
309
+ hidden_states = self.dropout(hidden_states)
310
+ hidden_states = self.wo(hidden_states)
311
+ return hidden_states
312
+
313
+
314
+ class T5DenseGatedActDense(nn.Module):
315
+ def __init__(self, config: T5Config):
316
+ super().__init__()
317
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
318
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
319
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
320
+ self.dropout = nn.Dropout(config.dropout_rate)
321
+ self.act = ACT2FN[config.dense_act_fn]
322
+
323
+ def forward(self, hidden_states):
324
+ hidden_gelu = self.act(self.wi_0(hidden_states))
325
+ hidden_linear = self.wi_1(hidden_states)
326
+ hidden_states = hidden_gelu * hidden_linear
327
+ hidden_states = self.dropout(hidden_states)
328
+ hidden_states = self.wo(hidden_states)
329
+ return hidden_states
330
+
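+ # Editor's note: the gated variant above (selected below in T5LayerFF when
+ # config.is_gated_act is True, as in T5 v1.1 / Flan-T5) splits the up-projection in
+ # two: one branch goes through the activation and gates the other, purely linear
+ # branch. In eval mode (dropout disabled) the forward pass reduces to (illustrative
+ # only):
+ #
+ #     y = ff.wo(ff.act(ff.wi_0(x)) * ff.wi_1(x))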
331
+
332
+ class T5LayerFF(nn.Module):
333
+ def __init__(self, config: T5Config):
334
+ super().__init__()
335
+ if config.is_gated_act:
336
+ self.DenseReluDense = T5DenseGatedActDense(config)
337
+ else:
338
+ self.DenseReluDense = T5DenseActDense(config)
339
+
340
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
341
+ self.dropout = nn.Dropout(config.dropout_rate)
342
+
343
+ def forward(self, hidden_states):
344
+ forwarded_states = self.layer_norm(hidden_states)
345
+ forwarded_states = self.DenseReluDense(forwarded_states)
346
+ hidden_states = hidden_states + self.dropout(forwarded_states)
347
+ return hidden_states
348
+
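+ # Editor's note: T5 uses pre-norm residual blocks throughout; each sub-layer sees the
+ # normalized input while the residual adds the sub-layer output back onto the
+ # un-normalized stream, i.e. x + dropout(sublayer(layer_norm(x))). The same pattern
+ # appears in T5LayerSelfAttention and T5LayerCrossAttention below.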
349
+
350
+ class T5Attention(nn.Module):
351
+ def __init__(self, config: T5Config, has_relative_attention_bias=False):
352
+ super().__init__()
353
+ self.is_decoder = config.is_decoder
354
+ self.has_relative_attention_bias = has_relative_attention_bias
355
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
356
+ self.relative_attention_max_distance = config.relative_attention_max_distance
357
+ self.d_model = config.d_model
358
+ self.key_value_proj_dim = config.d_kv
359
+ self.n_heads = config.num_heads
360
+ self.dropout = config.dropout_rate
361
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
362
+
363
+ # Mesh TensorFlow initialization to avoid scaling before softmax
364
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
365
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
366
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
367
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
368
+
369
+ if self.has_relative_attention_bias:
370
+ self.relative_attention_bias = nn.Embedding(
371
+ self.relative_attention_num_buckets, self.n_heads
372
+ )
373
+ self.pruned_heads = set()
374
+ self.gradient_checkpointing = False
375
+
376
+ def prune_heads(self, heads):
377
+ if len(heads) == 0:
378
+ return
379
+ heads, index = find_pruneable_heads_and_indices(
380
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
381
+ )
382
+ # Prune linear layers
383
+ self.q = prune_linear_layer(self.q, index)
384
+ self.k = prune_linear_layer(self.k, index)
385
+ self.v = prune_linear_layer(self.v, index)
386
+ self.o = prune_linear_layer(self.o, index, dim=1)
387
+ # Update hyper params
388
+ self.n_heads = self.n_heads - len(heads)
389
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
390
+ self.pruned_heads = self.pruned_heads.union(heads)
391
+
392
+ @staticmethod
393
+ def _relative_position_bucket(
394
+ relative_position, bidirectional=True, num_buckets=32, max_distance=128
395
+ ):
396
+ """
397
+ Adapted from Mesh Tensorflow:
398
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
399
+
400
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
401
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
402
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
403
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
404
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
405
+ This should allow for more graceful generalization to longer sequences than the model has been trained on.
406
+
407
+ Args:
408
+ relative_position: an int32 Tensor
409
+ bidirectional: a boolean - whether the attention is bidirectional
410
+ num_buckets: an integer
411
+ max_distance: an integer
412
+
413
+ Returns:
414
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
415
+ """
416
+ relative_buckets = 0
417
+ if bidirectional:
418
+ num_buckets //= 2
419
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
420
+ relative_position = torch.abs(relative_position)
421
+ else:
422
+ relative_position = -torch.min(
423
+ relative_position, torch.zeros_like(relative_position)
424
+ )
425
+ # now relative_position is in the range [0, inf)
426
+
427
+ # half of the buckets are for exact increments in positions
428
+ max_exact = num_buckets // 2
429
+ is_small = relative_position < max_exact
430
+
431
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
432
+ relative_position_if_large = max_exact + (
433
+ torch.log(relative_position.float() / max_exact)
434
+ / math.log(max_distance / max_exact)
435
+ * (num_buckets - max_exact)
436
+ ).to(torch.long)
437
+ relative_position_if_large = torch.min(
438
+ relative_position_if_large,
439
+ torch.full_like(relative_position_if_large, num_buckets - 1),
440
+ )
441
+
442
+ relative_buckets += torch.where(
443
+ is_small, relative_position, relative_position_if_large
444
+ )
445
+ return relative_buckets
446
+
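+ # Editor's note: a quick way to inspect the bucketing above (illustrative only):
+ #
+ #     rel = torch.arange(-8, 9).unsqueeze(0)   # relative positions -8 .. 8
+ #     T5Attention._relative_position_bucket(
+ #         rel, bidirectional=True, num_buckets=32, max_distance=128
+ #     )
+ #
+ # With bidirectional=True the 32 buckets split in half: zero and negative offsets land
+ # in [0, 16) and positive offsets in [16, 32), with exact buckets near zero and
+ # logarithmically coarser buckets out to max_distance.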
447
+ def compute_bias(self, query_length, key_length, device=None):
448
+ """Compute binned relative position bias"""
449
+ if device is None:
450
+ device = self.relative_attention_bias.weight.device
451
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[
452
+ :, None
453
+ ]
454
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
455
+ None, :
456
+ ]
457
+ relative_position = (
458
+ memory_position - context_position
459
+ ) # shape (query_length, key_length)
460
+ relative_position_bucket = self._relative_position_bucket(
461
+ relative_position, # shape (query_length, key_length)
462
+ bidirectional=(not self.is_decoder),
463
+ num_buckets=self.relative_attention_num_buckets,
464
+ max_distance=self.relative_attention_max_distance,
465
+ )
466
+ values = self.relative_attention_bias(
467
+ relative_position_bucket
468
+ ) # shape (query_length, key_length, num_heads)
469
+ values = values.permute([2, 0, 1]).unsqueeze(
470
+ 0
471
+ ) # shape (1, num_heads, query_length, key_length)
472
+ return values
473
+
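+ # Editor's note: compute_bias returns an additive (1, n_heads, query_length, key_length)
+ # bias. Only the first block of each stack owns a relative_attention_bias embedding
+ # (see has_relative_attention_bias=bool(i == 0) in T5Stack below); the resulting
+ # position_bias tensor is then reused by every subsequent layer.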
474
+ def forward(
475
+ self,
476
+ hidden_states,
477
+ mask=None,
478
+ key_value_states=None,
479
+ position_bias=None,
480
+ past_key_value=None,
481
+ layer_head_mask=None,
482
+ query_length=None,
483
+ use_cache=False,
484
+ output_attentions=False,
485
+ ):
486
+ """
487
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
488
+ """
489
+ # Input is (batch_size, seq_length, dim)
490
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
491
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
492
+ batch_size, seq_length = hidden_states.shape[:2]
493
+
494
+ real_seq_length = seq_length
495
+
496
+ if past_key_value is not None:
497
+ assert (
498
+ len(past_key_value) == 2
499
+ ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
500
+ real_seq_length += (
501
+ past_key_value[0].shape[2] if query_length is None else query_length
502
+ )
503
+
504
+ key_length = (
505
+ real_seq_length if key_value_states is None else key_value_states.shape[1]
506
+ )
507
+
508
+ def shape(states):
509
+ """projection"""
510
+ return states.view(
511
+ batch_size, -1, self.n_heads, self.key_value_proj_dim
512
+ ).transpose(1, 2)
513
+
514
+ def unshape(states):
515
+ """reshape"""
516
+ return (
517
+ states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
518
+ )
519
+
520
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
521
+ """projects hidden states correctly to key/query states"""
522
+ if key_value_states is None:
523
+ # self-attn
524
+ # (batch_size, n_heads, seq_length, dim_per_head)
525
+ hidden_states = shape(proj_layer(hidden_states))
526
+ elif past_key_value is None:
527
+ # cross-attn
528
+ # (batch_size, n_heads, seq_length, dim_per_head)
529
+ hidden_states = shape(proj_layer(key_value_states))
530
+
531
+ if past_key_value is not None:
532
+ if key_value_states is None:
533
+ # self-attn
534
+ # (batch_size, n_heads, key_length, dim_per_head)
535
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
536
+ else:
537
+ # cross-attn
538
+ hidden_states = past_key_value
539
+ return hidden_states
540
+
541
+ # get query states
542
+ query_states = shape(
543
+ self.q(hidden_states)
544
+ ) # (batch_size, n_heads, seq_length, dim_per_head)
545
+
546
+ # get key/value states
547
+ key_states = project(
548
+ hidden_states,
549
+ self.k,
550
+ key_value_states,
551
+ past_key_value[0] if past_key_value is not None else None,
552
+ )
553
+ value_states = project(
554
+ hidden_states,
555
+ self.v,
556
+ key_value_states,
557
+ past_key_value[1] if past_key_value is not None else None,
558
+ )
559
+
560
+ # compute scores
561
+ scores = torch.matmul(
562
+ query_states, key_states.transpose(3, 2)
563
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
564
+
565
+ if position_bias is None:
566
+ if not self.has_relative_attention_bias:
567
+ position_bias = torch.zeros(
568
+ (1, self.n_heads, real_seq_length, key_length),
569
+ device=scores.device,
570
+ dtype=scores.dtype,
571
+ )
572
+ if self.gradient_checkpointing and self.training:
573
+ position_bias.requires_grad = True
574
+ else:
575
+ position_bias = self.compute_bias(
576
+ real_seq_length, key_length, device=scores.device
577
+ )
578
+
579
+ # if keys and values are already calculated
580
+ # we want only the last query position bias
581
+ if past_key_value is not None:
582
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
583
+
584
+ if mask is not None:
585
+ position_bias = (
586
+ position_bias + mask
587
+ ) # (batch_size, n_heads, seq_length, key_length)
588
+
589
+ if self.pruned_heads:
590
+ mask = torch.ones(position_bias.shape[1])
591
+ mask[list(self.pruned_heads)] = 0
592
+ position_bias_masked = position_bias[:, mask.bool()]
593
+ else:
594
+ position_bias_masked = position_bias
595
+
596
+ scores += position_bias_masked
597
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
598
+ scores
599
+ ) # (batch_size, n_heads, seq_length, key_length)
600
+ attn_weights = nn.functional.dropout(
601
+ attn_weights, p=self.dropout, training=self.training
602
+ ) # (batch_size, n_heads, seq_length, key_length)
603
+
604
+ # Mask heads if we want to
605
+ if layer_head_mask is not None:
606
+ attn_weights = attn_weights * layer_head_mask
607
+
608
+ attn_output = unshape(
609
+ torch.matmul(attn_weights, value_states)
610
+ ) # (batch_size, seq_length, dim)
611
+ attn_output = self.o(attn_output)
612
+
613
+ present_key_value_state = (
614
+ (key_states, value_states) if (self.is_decoder and use_cache) else None
615
+ )
616
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
617
+
618
+ if output_attentions:
619
+ outputs = outputs + (attn_weights,)
620
+ return outputs
621
+
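+ # Editor's note: unlike most attention implementations, T5Attention does not divide the
+ # dot-product scores by sqrt(d_kv); the scale is instead folded into the Mesh-TF-style
+ # initialization of `q` (std = factor * (d_model * d_kv) ** -0.5 in _init_weights
+ # below). Stripped of caching and masking, the core computation is (illustrative only):
+ #
+ #     scores = torch.matmul(q_states, k_states.transpose(3, 2))  # no / sqrt(d) here
+ #     scores += position_bias                                    # additive relative bias
+ #     probs = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+ #     context = torch.matmul(probs, v_states)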
622
+
623
+ class T5LayerSelfAttention(nn.Module):
624
+ def __init__(self, config, has_relative_attention_bias=False):
625
+ super().__init__()
626
+ self.SelfAttention = T5Attention(
627
+ config, has_relative_attention_bias=has_relative_attention_bias
628
+ )
629
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
630
+ self.dropout = nn.Dropout(config.dropout_rate)
631
+
632
+ def forward(
633
+ self,
634
+ hidden_states,
635
+ attention_mask=None,
636
+ position_bias=None,
637
+ layer_head_mask=None,
638
+ past_key_value=None,
639
+ use_cache=False,
640
+ output_attentions=False,
641
+ ):
642
+ normed_hidden_states = self.layer_norm(hidden_states)
643
+ attention_output = self.SelfAttention(
644
+ normed_hidden_states,
645
+ mask=attention_mask,
646
+ position_bias=position_bias,
647
+ layer_head_mask=layer_head_mask,
648
+ past_key_value=past_key_value,
649
+ use_cache=use_cache,
650
+ output_attentions=output_attentions,
651
+ )
652
+ hidden_states = hidden_states + self.dropout(attention_output[0])
653
+ outputs = (hidden_states,) + attention_output[
654
+ 1:
655
+ ] # add attentions if we output them
656
+ return outputs
657
+
658
+
659
+ class T5LayerCrossAttention(nn.Module):
660
+ def __init__(self, config):
661
+ super().__init__()
662
+ self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
663
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
664
+ self.dropout = nn.Dropout(config.dropout_rate)
665
+
666
+ def forward(
667
+ self,
668
+ hidden_states,
669
+ key_value_states,
670
+ attention_mask=None,
671
+ position_bias=None,
672
+ layer_head_mask=None,
673
+ past_key_value=None,
674
+ use_cache=False,
675
+ query_length=None,
676
+ output_attentions=False,
677
+ ):
678
+ normed_hidden_states = self.layer_norm(hidden_states)
679
+ attention_output = self.EncDecAttention(
680
+ normed_hidden_states,
681
+ mask=attention_mask,
682
+ key_value_states=key_value_states,
683
+ position_bias=position_bias,
684
+ layer_head_mask=layer_head_mask,
685
+ past_key_value=past_key_value,
686
+ use_cache=use_cache,
687
+ query_length=query_length,
688
+ output_attentions=output_attentions,
689
+ )
690
+ layer_output = hidden_states + self.dropout(attention_output[0])
691
+ outputs = (layer_output,) + attention_output[
692
+ 1:
693
+ ] # add attentions if we output them
694
+ return outputs
695
+
696
+
697
+ class T5Block(nn.Module):
698
+ def __init__(self, config, has_relative_attention_bias=False):
699
+ super().__init__()
700
+ self.is_decoder = config.is_decoder
701
+ self.layer = nn.ModuleList()
702
+ self.layer.append(
703
+ T5LayerSelfAttention(
704
+ config, has_relative_attention_bias=has_relative_attention_bias
705
+ )
706
+ )
707
+ if self.is_decoder:
708
+ self.layer.append(T5LayerCrossAttention(config))
709
+
710
+ self.layer.append(T5LayerFF(config))
711
+
712
+ def forward(
713
+ self,
714
+ hidden_states,
715
+ attention_mask=None,
716
+ position_bias=None,
717
+ encoder_hidden_states=None,
718
+ encoder_attention_mask=None,
719
+ encoder_decoder_position_bias=None,
720
+ layer_head_mask=None,
721
+ cross_attn_layer_head_mask=None,
722
+ past_key_value=None,
723
+ use_cache=False,
724
+ output_attentions=False,
725
+ return_dict=True,
726
+ ):
727
+
728
+ if past_key_value is not None:
729
+ if not self.is_decoder:
730
+ logger.warning(
731
+ "`past_key_values` is passed to the encoder. Please make sure this is intended."
732
+ )
733
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
734
+
735
+ if len(past_key_value) != expected_num_past_key_values:
736
+ raise ValueError(
737
+ f"There should be {expected_num_past_key_values} past states. "
738
+ f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
739
+ f"Got {len(past_key_value)} past key / value states"
740
+ )
741
+
742
+ self_attn_past_key_value = past_key_value[:2]
743
+ cross_attn_past_key_value = past_key_value[2:]
744
+ else:
745
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
746
+
747
+ self_attention_outputs = self.layer[0](
748
+ hidden_states,
749
+ attention_mask=attention_mask,
750
+ position_bias=position_bias,
751
+ layer_head_mask=layer_head_mask,
752
+ past_key_value=self_attn_past_key_value,
753
+ use_cache=use_cache,
754
+ output_attentions=output_attentions,
755
+ )
756
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
757
+ attention_outputs = self_attention_outputs[
758
+ 2:
759
+ ] # Keep self-attention outputs and relative position weights
760
+
761
+ # clamp inf values to enable fp16 training
762
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
763
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
764
+ hidden_states = torch.clamp(
765
+ hidden_states, min=-clamp_value, max=clamp_value
766
+ )
767
+
768
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
769
+ if do_cross_attention:
770
+ # the actual query length is unknown for cross attention
771
+ # if using past key value states. Need to inject it here
772
+ if present_key_value_state is not None:
773
+ query_length = present_key_value_state[0].shape[2]
774
+ else:
775
+ query_length = None
776
+
777
+ cross_attention_outputs = self.layer[1](
778
+ hidden_states,
779
+ key_value_states=encoder_hidden_states,
780
+ attention_mask=encoder_attention_mask,
781
+ position_bias=encoder_decoder_position_bias,
782
+ layer_head_mask=cross_attn_layer_head_mask,
783
+ past_key_value=cross_attn_past_key_value,
784
+ query_length=query_length,
785
+ use_cache=use_cache,
786
+ output_attentions=output_attentions,
787
+ )
788
+ hidden_states = cross_attention_outputs[0]
789
+
790
+ # clamp inf values to enable fp16 training
791
+ if (
792
+ hidden_states.dtype == torch.float16
793
+ and torch.isinf(hidden_states).any()
794
+ ):
795
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
796
+ hidden_states = torch.clamp(
797
+ hidden_states, min=-clamp_value, max=clamp_value
798
+ )
799
+
800
+ # Combine self attn and cross attn key value states
801
+ if present_key_value_state is not None:
802
+ present_key_value_state = (
803
+ present_key_value_state + cross_attention_outputs[1]
804
+ )
805
+
806
+ # Keep cross-attention outputs and relative position weights
807
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
808
+
809
+ # Apply Feed Forward layer
810
+ hidden_states = self.layer[-1](hidden_states)
811
+
812
+ # clamp inf values to enable fp16 training
813
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
814
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
815
+ hidden_states = torch.clamp(
816
+ hidden_states, min=-clamp_value, max=clamp_value
817
+ )
818
+
819
+ outputs = (hidden_states,)
820
+
821
+ if use_cache:
822
+ outputs = outputs + (present_key_value_state,) + attention_outputs
823
+ else:
824
+ outputs = outputs + attention_outputs
825
+
826
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
827
+
828
+
829
+ class T5PreTrainedModel(PreTrainedModel):
830
+ """
831
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
832
+ models.
833
+ """
834
+
835
+ config_class = T5Config
836
+ load_tf_weights = load_tf_weights_in_t5
837
+ base_model_prefix = "transformer"
838
+ is_parallelizable = True
839
+ supports_gradient_checkpointing = True
840
+ _no_split_modules = ["T5Block"]
841
+
842
+ @property
843
+ def dummy_inputs(self):
844
+ input_ids = torch.tensor(DUMMY_INPUTS)
845
+ input_mask = torch.tensor(DUMMY_MASK)
846
+ dummy_inputs = {
847
+ "decoder_input_ids": input_ids,
848
+ "input_ids": input_ids,
849
+ "decoder_attention_mask": input_mask,
850
+ }
851
+ return dummy_inputs
852
+
853
+ def _init_weights(self, module):
854
+ """Initialize the weights"""
855
+ factor = (
856
+ self.config.initializer_factor
857
+ ) # Used for testing weights initialization
858
+ if isinstance(module, T5LayerNorm):
859
+ module.weight.data.fill_(factor * 1.0)
860
+ elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
861
+ # Mesh TensorFlow embeddings initialization
862
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
863
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
864
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
865
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
866
+ elif isinstance(module, T5DenseActDense):
867
+ # Mesh TensorFlow FF initialization
868
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
869
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
870
+ module.wi.weight.data.normal_(
871
+ mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
872
+ )
873
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
874
+ module.wi.bias.data.zero_()
875
+ module.wo.weight.data.normal_(
876
+ mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)
877
+ )
878
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
879
+ module.wo.bias.data.zero_()
880
+ elif isinstance(module, T5DenseGatedActDense):
881
+ module.wi_0.weight.data.normal_(
882
+ mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
883
+ )
884
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
885
+ module.wi_0.bias.data.zero_()
886
+ module.wi_1.weight.data.normal_(
887
+ mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
888
+ )
889
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
890
+ module.wi_1.bias.data.zero_()
891
+ module.wo.weight.data.normal_(
892
+ mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)
893
+ )
894
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
895
+ module.wo.bias.data.zero_()
896
+ elif isinstance(module, T5Attention):
897
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
898
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
899
+ d_model = self.config.d_model
900
+ key_value_proj_dim = self.config.d_kv
901
+ n_heads = self.config.num_heads
902
+ module.q.weight.data.normal_(
903
+ mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)
904
+ )
905
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
906
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
907
+ module.o.weight.data.normal_(
908
+ mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)
909
+ )
910
+ if module.has_relative_attention_bias:
911
+ module.relative_attention_bias.weight.data.normal_(
912
+ mean=0.0, std=factor * ((d_model) ** -0.5)
913
+ )
914
+
915
+ def _set_gradient_checkpointing(self, module, value=False):
916
+ if isinstance(module, (T5Attention, T5Stack)):
917
+ module.gradient_checkpointing = value
918
+
919
+ def _shift_right(self, input_ids):
920
+ decoder_start_token_id = self.config.decoder_start_token_id
921
+ pad_token_id = self.config.pad_token_id
922
+
923
+ assert decoder_start_token_id is not None, (
924
+ "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
925
+ " See T5 docs for more information"
926
+ )
927
+
928
+ # shift inputs to the right
929
+ if is_torch_fx_proxy(input_ids):
930
+ # Item assignment is not supported natively for proxies.
931
+ shifted_input_ids = torch.full(
932
+ input_ids.shape[:-1] + (1,), decoder_start_token_id
933
+ )
934
+ shifted_input_ids = torch.cat(
935
+ [shifted_input_ids, input_ids[..., :-1]], dim=-1
936
+ )
937
+ else:
938
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
939
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
940
+ shifted_input_ids[..., 0] = decoder_start_token_id
941
+
942
+ assert (
943
+ pad_token_id is not None
944
+ ), "self.model.config.pad_token_id has to be defined."
945
+ # replace possible -100 values in labels by `pad_token_id`
946
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
947
+
948
+ return shifted_input_ids
949
+
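+ # Editor's note: _shift_right turns labels into decoder inputs by prepending the decoder
+ # start token (the pad token for T5) and dropping the last position, replacing any -100
+ # ignore-index values with pad_token_id. Illustrative example (assuming pad_token_id == 0
+ # and decoder_start_token_id == 0):
+ #
+ #     labels            = [[37, 32099,     1]]
+ #     decoder_input_ids = [[ 0,    37, 32099]]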
950
+
951
+ class T5Stack(T5PreTrainedModel):
952
+ def __init__(self, config, embed_tokens=None):
953
+ super().__init__(config)
954
+
955
+ self.embed_tokens = embed_tokens
956
+ self.is_decoder = config.is_decoder
957
+
958
+ self.block = nn.ModuleList(
959
+ [
960
+ T5Block(config, has_relative_attention_bias=bool(i == 0))
961
+ for i in range(config.num_layers)
962
+ ]
963
+ )
964
+ self.final_layer_norm = T5LayerNorm(
965
+ config.d_model, eps=config.layer_norm_epsilon
966
+ )
967
+ self.dropout = nn.Dropout(config.dropout_rate)
968
+
969
+ # Initialize weights and apply final processing
970
+ self.post_init()
971
+ # Model parallel
972
+ self.model_parallel = False
973
+ self.device_map = None
974
+ self.gradient_checkpointing = False
975
+
976
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
977
+ def parallelize(self, device_map=None):
978
+ # Check validity of device_map
979
+ self.device_map = (
980
+ get_device_map(len(self.block), range(torch.cuda.device_count()))
981
+ if device_map is None
982
+ else device_map
983
+ )
984
+ assert_device_map(self.device_map, len(self.block))
985
+ self.model_parallel = True
986
+ self.first_device = (
987
+ "cpu"
988
+ if "cpu" in self.device_map.keys()
989
+ else "cuda:" + str(min(self.device_map.keys()))
990
+ )
991
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
992
+ # Load onto devices
993
+ for k, v in self.device_map.items():
994
+ for layer in v:
995
+ cuda_device = "cuda:" + str(k)
996
+ self.block[layer] = self.block[layer].to(cuda_device)
997
+
998
+ # Set embed_tokens to first layer
999
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
1000
+ # Set final layer norm to last device
1001
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
1002
+
1003
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1004
+ def deparallelize(self):
1005
+ self.model_parallel = False
1006
+ self.device_map = None
1007
+ self.first_device = "cpu"
1008
+ self.last_device = "cpu"
1009
+ for i in range(len(self.block)):
1010
+ self.block[i] = self.block[i].to("cpu")
1011
+ self.embed_tokens = self.embed_tokens.to("cpu")
1012
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
1013
+ torch.cuda.empty_cache()
1014
+
1015
+ def get_input_embeddings(self):
1016
+ return self.embed_tokens
1017
+
1018
+ def set_input_embeddings(self, new_embeddings):
1019
+ self.embed_tokens = new_embeddings
1020
+
1021
+ def forward(
1022
+ self,
1023
+ input_ids=None,
1024
+ attention_mask=None,
1025
+ encoder_hidden_states=None,
1026
+ encoder_attention_mask=None,
1027
+ inputs_embeds=None,
1028
+ head_mask=None,
1029
+ cross_attn_head_mask=None,
1030
+ past_key_values=None,
1031
+ use_cache=None,
1032
+ output_attentions=None,
1033
+ output_hidden_states=None,
1034
+ return_dict=None,
1035
+ ):
1036
+ # Model parallel
1037
+ if self.model_parallel:
1038
+ torch.cuda.set_device(self.first_device)
1039
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
1040
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1041
+ output_attentions = (
1042
+ output_attentions
1043
+ if output_attentions is not None
1044
+ else self.config.output_attentions
1045
+ )
1046
+ output_hidden_states = (
1047
+ output_hidden_states
1048
+ if output_hidden_states is not None
1049
+ else self.config.output_hidden_states
1050
+ )
1051
+ return_dict = (
1052
+ return_dict if return_dict is not None else self.config.use_return_dict
1053
+ )
1054
+
1055
+ if input_ids is not None and inputs_embeds is not None:
1056
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
1057
+ raise ValueError(
1058
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
1059
+ )
1060
+ elif input_ids is not None:
1061
+ input_shape = input_ids.size()
1062
+ input_ids = input_ids.view(-1, input_shape[-1])
1063
+ elif inputs_embeds is not None:
1064
+ input_shape = inputs_embeds.size()[:-1]
1065
+ else:
1066
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
1067
+ raise ValueError(
1068
+ f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
1069
+ )
1070
+
1071
+ if inputs_embeds is None:
1072
+ assert (
1073
+ self.embed_tokens is not None
1074
+ ), "You have to initialize the model with valid token embeddings"
1075
+ inputs_embeds = self.embed_tokens(input_ids)
1076
+
1077
+ batch_size, seq_length = input_shape
1078
+
1079
+ # required mask seq length can be calculated via length of past
1080
+ mask_seq_length = (
1081
+ past_key_values[0][0].shape[2] + seq_length
1082
+ if past_key_values is not None
1083
+ else seq_length
1084
+ )
1085
+
1086
+ if use_cache is True:
1087
+ assert (
1088
+ self.is_decoder
1089
+ ), f"`use_cache` can only be set to `True` if {self} is used as a decoder"
1090
+
1091
+ if attention_mask is None:
1092
+ attention_mask = torch.ones(
1093
+ batch_size, mask_seq_length, device=inputs_embeds.device
1094
+ )
1095
+ if (
1096
+ self.is_decoder
1097
+ and encoder_attention_mask is None
1098
+ and encoder_hidden_states is not None
1099
+ ):
1100
+ encoder_seq_length = encoder_hidden_states.shape[1]
1101
+ encoder_attention_mask = torch.ones(
1102
+ batch_size,
1103
+ encoder_seq_length,
1104
+ device=inputs_embeds.device,
1105
+ dtype=torch.long,
1106
+ )
1107
+
1108
+ # initialize past_key_values with `None` if past does not exist
1109
+ if past_key_values is None:
1110
+ past_key_values = [None] * len(self.block)
1111
+
1112
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1113
+ # ourselves in which case we just need to make it broadcastable to all heads.
1114
+ extended_attention_mask = self.get_extended_attention_mask(
1115
+ attention_mask, input_shape
1116
+ )
1117
+
1118
+ # If a 2D or 3D attention mask is provided for the cross-attention
1119
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
1120
+ if self.is_decoder and encoder_hidden_states is not None:
1121
+ (
1122
+ encoder_batch_size,
1123
+ encoder_sequence_length,
1124
+ _,
1125
+ ) = encoder_hidden_states.size()
1126
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1127
+ if encoder_attention_mask is None:
1128
+ encoder_attention_mask = torch.ones(
1129
+ encoder_hidden_shape, device=inputs_embeds.device
1130
+ )
1131
+ encoder_extended_attention_mask = self.invert_attention_mask(
1132
+ encoder_attention_mask
1133
+ )
1134
+ else:
1135
+ encoder_extended_attention_mask = None
1136
+
1137
+ # Prepare head mask if needed
1138
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
1139
+ cross_attn_head_mask = self.get_head_mask(
1140
+ cross_attn_head_mask, self.config.num_layers
1141
+ )
1142
+ present_key_value_states = () if use_cache else None
1143
+ all_hidden_states = () if output_hidden_states else None
1144
+ all_attentions = () if output_attentions else None
1145
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
1146
+ position_bias = None
1147
+ encoder_decoder_position_bias = None
1148
+
1149
+ hidden_states = self.dropout(inputs_embeds)
1150
+
1151
+ for i, (layer_module, past_key_value) in enumerate(
1152
+ zip(self.block, past_key_values)
1153
+ ):
1154
+ layer_head_mask = head_mask[i]
1155
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
1156
+ # Model parallel
1157
+ if self.model_parallel:
1158
+ torch.cuda.set_device(hidden_states.device)
1159
+ # Ensure that attention_mask is always on the same device as hidden_states
1160
+ if attention_mask is not None:
1161
+ attention_mask = attention_mask.to(hidden_states.device)
1162
+ if position_bias is not None:
1163
+ position_bias = position_bias.to(hidden_states.device)
1164
+ if encoder_hidden_states is not None:
1165
+ encoder_hidden_states = encoder_hidden_states.to(
1166
+ hidden_states.device
1167
+ )
1168
+ if encoder_extended_attention_mask is not None:
1169
+ encoder_extended_attention_mask = (
1170
+ encoder_extended_attention_mask.to(hidden_states.device)
1171
+ )
1172
+ if encoder_decoder_position_bias is not None:
1173
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(
1174
+ hidden_states.device
1175
+ )
1176
+ if layer_head_mask is not None:
1177
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
1178
+ if cross_attn_layer_head_mask is not None:
1179
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
1180
+ hidden_states.device
1181
+ )
1182
+ if output_hidden_states:
1183
+ all_hidden_states = all_hidden_states + (hidden_states,)
1184
+
1185
+ if self.gradient_checkpointing and self.training:
1186
+ if use_cache:
1187
+ logger.warning(
1188
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1189
+ )
1190
+ use_cache = False
1191
+
1192
+ def create_custom_forward(module):
1193
+ def custom_forward(*inputs):
1194
+ return tuple(module(*inputs, use_cache, output_attentions))
1195
+
1196
+ return custom_forward
1197
+
1198
+ layer_outputs = checkpoint(
1199
+ create_custom_forward(layer_module),
1200
+ hidden_states,
1201
+ extended_attention_mask,
1202
+ position_bias,
1203
+ encoder_hidden_states,
1204
+ encoder_extended_attention_mask,
1205
+ encoder_decoder_position_bias,
1206
+ layer_head_mask,
1207
+ cross_attn_layer_head_mask,
1208
+ None, # past_key_value is always None with gradient checkpointing
1209
+ )
1210
+ else:
1211
+ layer_outputs = layer_module(
1212
+ hidden_states,
1213
+ attention_mask=extended_attention_mask,
1214
+ position_bias=position_bias,
1215
+ encoder_hidden_states=encoder_hidden_states,
1216
+ encoder_attention_mask=encoder_extended_attention_mask,
1217
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
1218
+ layer_head_mask=layer_head_mask,
1219
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1220
+ past_key_value=past_key_value,
1221
+ use_cache=use_cache,
1222
+ output_attentions=output_attentions,
1223
+ )
1224
+
1225
+ # layer_outputs is a tuple with:
1226
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1227
+ if use_cache is False:
1228
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
1229
+
1230
+ hidden_states, present_key_value_state = layer_outputs[:2]
1231
+
1232
+ # We share the position biases between the layers - the first layer stores them
1233
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
1234
+ # (cross-attention position bias), (cross-attention weights)
1235
+ position_bias = layer_outputs[2]
1236
+ if self.is_decoder and encoder_hidden_states is not None:
1237
+ encoder_decoder_position_bias = layer_outputs[
1238
+ 4 if output_attentions else 3
1239
+ ]
1240
+ # append next layer key value states
1241
+ if use_cache:
1242
+ present_key_value_states = present_key_value_states + (
1243
+ present_key_value_state,
1244
+ )
1245
+
1246
+ if output_attentions:
1247
+ all_attentions = all_attentions + (layer_outputs[3],)
1248
+ if self.is_decoder:
1249
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
1250
+
1251
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1252
+ if self.model_parallel:
1253
+ for k, v in self.device_map.items():
1254
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1255
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1256
+
1257
+ hidden_states = self.final_layer_norm(hidden_states)
1258
+ hidden_states = self.dropout(hidden_states)
1259
+
1260
+ # Add last layer
1261
+ if output_hidden_states:
1262
+ all_hidden_states = all_hidden_states + (hidden_states,)
1263
+
1264
+ if not return_dict:
1265
+ return tuple(
1266
+ v
1267
+ for v in [
1268
+ hidden_states,
1269
+ present_key_value_states,
1270
+ all_hidden_states,
1271
+ all_attentions,
1272
+ all_cross_attentions,
1273
+ ]
1274
+ if v is not None
1275
+ )
1276
+ return BaseModelOutputWithPastAndCrossAttentions(
1277
+ last_hidden_state=hidden_states,
1278
+ past_key_values=present_key_value_states,
1279
+ hidden_states=all_hidden_states,
1280
+ attentions=all_attentions,
1281
+ cross_attentions=all_cross_attentions,
1282
+ )
1283
+
1284
+
1285
+ T5_START_DOCSTRING = r"""
1286
+
1287
+ The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
1288
+ Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
1289
+ Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
1290
+ text-to-text denoising generative setting.
1291
+
1292
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1293
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1294
+ etc.)
1295
+
1296
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1297
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1298
+ and behavior.
1299
+
1300
+ Parameters:
1301
+ config ([`T5Config`]): Model configuration class with all the parameters of the model.
1302
+ Initializing with a config file does not load the weights associated with the model, only the
1303
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1304
+ """
1305
+
1306
+ T5_INPUTS_DOCSTRING = r"""
1307
+ Args:
1308
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1309
+ Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
1310
+ should be able to pad the inputs on both the right and the left.
1311
+
1312
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1313
+ [`PreTrainedTokenizer.__call__`] for details.
1314
+
1315
+ [What are input IDs?](../glossary#input-ids)
1316
+
1317
+ To learn more about how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
1318
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1319
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1320
+
1321
+ - 1 for tokens that are **not masked**,
1322
+ - 0 for tokens that are **masked**.
1323
+
1324
+ [What are attention masks?](../glossary#attention-mask)
1325
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1326
+ Indices of decoder input sequence tokens in the vocabulary.
1327
+
1328
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1329
+ [`PreTrainedTokenizer.__call__`] for details.
1330
+
1331
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
1332
+
1333
+ T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
1334
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
1335
+
1336
+ To learn more about how to prepare `decoder_input_ids` for pretraining, take a look at [T5
1337
+ Training](./t5#training).
1338
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1339
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
1340
+ be used by default.
1341
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1342
+ Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
1343
+ 1]`:
1344
+
1345
+ - 1 indicates the head is **not masked**,
1346
+ - 0 indicates the head is **masked**.
1347
+
1348
+ decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1349
+ Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
1350
+ 1]`:
1351
+
1352
+ - 1 indicates the head is **not masked**,
1353
+ - 0 indicates the head is **masked**.
1354
+
1355
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1356
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
1357
+ `[0, 1]`:
1358
+
1359
+ - 1 indicates the head is **not masked**,
1360
+ - 0 indicates the head is **masked**.
1361
+
1362
+ encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
1363
+ Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
1364
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
1365
+ the output of the last layer of the encoder. Used in the cross-attention of the decoder.
1366
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1367
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1368
+
1369
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1370
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1371
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1372
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1373
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1374
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1375
+ model's internal embedding lookup matrix.
1376
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
1377
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
1378
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
1379
+ input (see `past_key_values`). This is useful if you want more control over how to convert
1380
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
1381
+
1382
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
1383
+ of `inputs_embeds`.
1384
+
1385
+ use_cache (`bool`, *optional*):
1386
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1387
+ `past_key_values`).
1388
+
1389
+ output_attentions (`bool`, *optional*):
1390
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1391
+ tensors for more detail.
1392
+ output_hidden_states (`bool`, *optional*):
1393
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1394
+ more detail.
1395
+ return_dict (`bool`, *optional*):
1396
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1397
+ """
1398
+
1399
+ T5_ENCODER_INPUTS_DOCSTRING = r"""
1400
+ Args:
1401
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1402
+ Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
1403
+ should be able to pad the inputs on both the right and the left.
1404
+
1405
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1406
+ [`PreTrainedTokenizer.__call__`] for details.
1407
+
1408
+ To learn more about how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
1409
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1410
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1411
+
1412
+ - 1 for tokens that are **not masked**,
1413
+ - 0 for tokens that are **masked**.
1414
+
1415
+ [What are attention masks?](../glossary#attention-mask)
1416
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1417
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
1418
+
1419
+ - 1 indicates the head is **not masked**,
1420
+ - 0 indicates the head is **masked**.
1421
+
1422
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1423
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1424
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1425
+ model's internal embedding lookup matrix.
1426
+ output_attentions (`bool`, *optional*):
1427
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1428
+ tensors for more detail.
1429
+ output_hidden_states (`bool`, *optional*):
1430
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1431
+ more detail.
1432
+ return_dict (`bool`, *optional*):
1433
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1434
+ """
1435
+
1436
+ # Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1437
+ __HEAD_MASK_WARNING_MSG = """
1438
+ The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
1439
+ `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
1440
+ If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
1441
+ num_heads)`.
1442
+ """
1443
+
1444
+
1445
+ @add_start_docstrings(
1446
+ "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
1447
+ T5_START_DOCSTRING,
1448
+ )
1449
+ class T5Model(T5PreTrainedModel):
1450
+ _keys_to_ignore_on_load_missing = [
1451
+ r"encoder.embed_tokens.weight",
1452
+ r"decoder.embed_tokens.weight",
1453
+ ]
1454
+ _keys_to_ignore_on_load_unexpected = [
1455
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1456
+ ]
1457
+
1458
+ def __init__(self, config: T5Config):
1459
+ super().__init__(config)
1460
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1461
+
1462
+ encoder_config = copy.deepcopy(config)
1463
+ encoder_config.is_decoder = False
1464
+ encoder_config.use_cache = False
1465
+ encoder_config.is_encoder_decoder = False
1466
+ self.encoder = T5Stack(encoder_config, self.shared)
1467
+
1468
+ decoder_config = copy.deepcopy(config)
1469
+ decoder_config.is_decoder = True
1470
+ decoder_config.is_encoder_decoder = False
1471
+ decoder_config.num_layers = config.num_decoder_layers
1472
+ self.decoder = T5Stack(decoder_config, self.shared)
1473
+
1474
+ # Initialize weights and apply final processing
1475
+ self.post_init()
1476
+
1477
+ # Model parallel
1478
+ self.model_parallel = False
1479
+ self.device_map = None
1480
+
1481
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1482
+ def parallelize(self, device_map=None):
1483
+ self.device_map = (
1484
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1485
+ if device_map is None
1486
+ else device_map
1487
+ )
1488
+ assert_device_map(self.device_map, len(self.encoder.block))
1489
+ self.encoder.parallelize(self.device_map)
1490
+ self.decoder.parallelize(self.device_map)
1491
+ self.model_parallel = True
1492
+
1493
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1494
+ def deparallelize(self):
1495
+ self.encoder.deparallelize()
1496
+ self.decoder.deparallelize()
1497
+ self.encoder = self.encoder.to("cpu")
1498
+ self.decoder = self.decoder.to("cpu")
1499
+ self.model_parallel = False
1500
+ self.device_map = None
1501
+ torch.cuda.empty_cache()
1502
+
1503
+ def get_input_embeddings(self):
1504
+ return self.shared
1505
+
1506
+ def set_input_embeddings(self, new_embeddings):
1507
+ self.shared = new_embeddings
1508
+ self.encoder.set_input_embeddings(new_embeddings)
1509
+ self.decoder.set_input_embeddings(new_embeddings)
1510
+
1511
+ def get_encoder(self):
1512
+ return self.encoder
1513
+
1514
+ def get_decoder(self):
1515
+ return self.decoder
1516
+
1517
+ def _prune_heads(self, heads_to_prune):
1518
+ """
1519
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1520
+ class PreTrainedModel
1521
+ """
1522
+ for layer, heads in heads_to_prune.items():
1523
+ self.encoder.layer[layer].attention.prune_heads(heads)
1524
+
1525
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
1526
+ @replace_return_docstrings(
1527
+ output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC
1528
+ )
1529
+ def forward(
1530
+ self,
1531
+ input_ids: Optional[torch.LongTensor] = None,
1532
+ attention_mask: Optional[torch.FloatTensor] = None,
1533
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1534
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1535
+ head_mask: Optional[torch.FloatTensor] = None,
1536
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1537
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1538
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1539
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1540
+ inputs_embeds: Optional[torch.Tensor] = None,
1541
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
1542
+ use_cache: Optional[bool] = None,
1543
+ output_attentions: Optional[bool] = None,
1544
+ output_hidden_states: Optional[bool] = None,
1545
+ return_dict: Optional[bool] = None,
1546
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
1547
+ r"""
1548
+ Returns:
1549
+
1550
+ Example:
1551
+
1552
+ ```python
1553
+ >>> from transformers import T5Tokenizer, T5Model
1554
+
1555
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1556
+ >>> model = T5Model.from_pretrained("t5-small")
1557
+
1558
+ >>> input_ids = tokenizer(
1559
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1560
+ ... ).input_ids # Batch size 1
1561
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1562
+
1563
+ >>> # preprocess: prepend decoder_input_ids with the start token, which is the pad token for T5Model.
1564
+ >>> # This is not needed for T5ForConditionalGeneration, which does this internally using the labels argument.
1565
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids)
1566
+
1567
+ >>> # forward pass
1568
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1569
+ >>> last_hidden_states = outputs.last_hidden_state
1570
+ ```"""
1571
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1572
+ return_dict = (
1573
+ return_dict if return_dict is not None else self.config.use_return_dict
1574
+ )
1575
+
1576
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1577
+ if head_mask is not None and decoder_head_mask is None:
1578
+ if self.config.num_layers == self.config.num_decoder_layers:
1579
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
1580
+ decoder_head_mask = head_mask
1581
+
1582
+ # Encode if needed (training, first prediction pass)
1583
+ if encoder_outputs is None:
1584
+ encoder_outputs = self.encoder(
1585
+ input_ids=input_ids,
1586
+ attention_mask=attention_mask,
1587
+ inputs_embeds=inputs_embeds,
1588
+ head_mask=head_mask,
1589
+ output_attentions=output_attentions,
1590
+ output_hidden_states=output_hidden_states,
1591
+ return_dict=return_dict,
1592
+ )
1593
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1594
+ encoder_outputs = BaseModelOutput(
1595
+ last_hidden_state=encoder_outputs[0],
1596
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1597
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1598
+ )
1599
+
1600
+ hidden_states = encoder_outputs[0]
1601
+
1602
+ # Set device for model parallelism
1603
+ if self.model_parallel:
1604
+ torch.cuda.set_device(self.decoder.first_device)
1605
+ hidden_states = hidden_states.to(self.decoder.first_device)
1606
+ if decoder_input_ids is not None:
1607
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1608
+ if attention_mask is not None:
1609
+ attention_mask = attention_mask.to(self.decoder.first_device)
1610
+ if decoder_attention_mask is not None:
1611
+ decoder_attention_mask = decoder_attention_mask.to(
1612
+ self.decoder.first_device
1613
+ )
1614
+
1615
+ # Decode
1616
+ decoder_outputs = self.decoder(
1617
+ input_ids=decoder_input_ids,
1618
+ attention_mask=decoder_attention_mask,
1619
+ inputs_embeds=decoder_inputs_embeds,
1620
+ past_key_values=past_key_values,
1621
+ encoder_hidden_states=hidden_states,
1622
+ encoder_attention_mask=attention_mask,
1623
+ head_mask=decoder_head_mask,
1624
+ cross_attn_head_mask=cross_attn_head_mask,
1625
+ use_cache=use_cache,
1626
+ output_attentions=output_attentions,
1627
+ output_hidden_states=output_hidden_states,
1628
+ return_dict=return_dict,
1629
+ )
1630
+
1631
+ if not return_dict:
1632
+ return decoder_outputs + encoder_outputs
1633
+
1634
+ return Seq2SeqModelOutput(
1635
+ last_hidden_state=decoder_outputs.last_hidden_state,
1636
+ past_key_values=decoder_outputs.past_key_values,
1637
+ decoder_hidden_states=decoder_outputs.hidden_states,
1638
+ decoder_attentions=decoder_outputs.attentions,
1639
+ cross_attentions=decoder_outputs.cross_attentions,
1640
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1641
+ encoder_hidden_states=encoder_outputs.hidden_states,
1642
+ encoder_attentions=encoder_outputs.attentions,
1643
+ )
1644
+
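Because `forward` above only runs the encoder when `encoder_outputs` is `None`, the encoder pass can be computed once and reused across decoder calls; a minimal sketch:

```python
# Sketch: precompute encoder outputs once, then feed them to forward()
# so only the decoder runs on subsequent calls.
import torch
from transformers import T5Tokenizer, T5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")

enc = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
encoder_outputs = model.get_encoder()(**enc)  # runs the encoder a single time

decoder_input_ids = model._shift_right(tokenizer("Studies show that", return_tensors="pt").input_ids)
outputs = model(
    encoder_outputs=encoder_outputs,
    attention_mask=enc.attention_mask,
    decoder_input_ids=decoder_input_ids,
)
print(outputs.last_hidden_state.shape)
```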
1645
+
1646
+ @add_start_docstrings(
1647
+ """T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING
1648
+ )
1649
+ class T5ForConditionalGeneration(T5PreTrainedModel):
1650
+ _keys_to_ignore_on_load_missing = [
1651
+ r"encoder.embed_tokens.weight",
1652
+ r"decoder.embed_tokens.weight",
1653
+ r"lm_head.weight",
1654
+ ]
1655
+ _keys_to_ignore_on_load_unexpected = [
1656
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1657
+ ]
1658
+
1659
+ def __init__(self, config: T5Config):
1660
+ super().__init__(config)
1661
+ self.model_dim = config.d_model
1662
+
1663
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1664
+
1665
+ encoder_config = copy.deepcopy(config)
1666
+ encoder_config.is_decoder = False
1667
+ encoder_config.use_cache = False
1668
+ encoder_config.is_encoder_decoder = False
1669
+ self.encoder = T5Stack(encoder_config, self.shared)
1670
+
1671
+ decoder_config = copy.deepcopy(config)
1672
+ decoder_config.is_decoder = True
1673
+ decoder_config.is_encoder_decoder = False
1674
+ decoder_config.num_layers = config.num_decoder_layers
1675
+ self.decoder = T5Stack(decoder_config, self.shared)
1676
+
1677
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1678
+
1679
+ # Initialize weights and apply final processing
1680
+ self.post_init()
1681
+
1682
+ # Model parallel
1683
+ self.model_parallel = False
1684
+ self.device_map = None
1685
+
1686
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1687
+ def parallelize(self, device_map=None):
1688
+ self.device_map = (
1689
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1690
+ if device_map is None
1691
+ else device_map
1692
+ )
1693
+ assert_device_map(self.device_map, len(self.encoder.block))
1694
+ self.encoder.parallelize(self.device_map)
1695
+ self.decoder.parallelize(self.device_map)
1696
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
1697
+ self.model_parallel = True
1698
+
1699
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1700
+ def deparallelize(self):
1701
+ self.encoder.deparallelize()
1702
+ self.decoder.deparallelize()
1703
+ self.encoder = self.encoder.to("cpu")
1704
+ self.decoder = self.decoder.to("cpu")
1705
+ self.lm_head = self.lm_head.to("cpu")
1706
+ self.model_parallel = False
1707
+ self.device_map = None
1708
+ torch.cuda.empty_cache()
1709
+
1710
+ def get_input_embeddings(self):
1711
+ return self.shared
1712
+
1713
+ def set_input_embeddings(self, new_embeddings):
1714
+ self.shared = new_embeddings
1715
+ self.encoder.set_input_embeddings(new_embeddings)
1716
+ self.decoder.set_input_embeddings(new_embeddings)
1717
+
1718
+ def set_output_embeddings(self, new_embeddings):
1719
+ self.lm_head = new_embeddings
1720
+
1721
+ def get_output_embeddings(self):
1722
+ return self.lm_head
1723
+
1724
+ def get_encoder(self):
1725
+ return self.encoder
1726
+
1727
+ def get_decoder(self):
1728
+ return self.decoder
1729
+
1730
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
1731
+ @replace_return_docstrings(
1732
+ output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
1733
+ )
1734
+ def forward(
1735
+ self,
1736
+ input_ids: Optional[torch.LongTensor] = None,
1737
+ attention_mask: Optional[torch.FloatTensor] = None,
1738
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1739
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1740
+ head_mask: Optional[torch.FloatTensor] = None,
1741
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1742
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1743
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1744
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1745
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1746
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1747
+ labels: Optional[torch.LongTensor] = None,
1748
+ use_cache: Optional[bool] = None,
1749
+ output_attentions: Optional[bool] = None,
1750
+ output_hidden_states: Optional[bool] = None,
1751
+ return_dict: Optional[bool] = None,
1752
+ reduction: Optional[str] = "mean",
1753
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
1754
+ r"""
1755
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1756
+ Labels for computing the language modeling loss. Indices should be in `[-100, 0, ...,
1757
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
1758
+ labels in `[0, ..., config.vocab_size - 1]`.
1759
+
1760
+ Returns:
1761
+
1762
+ Examples:
1763
+
1764
+ ```python
1765
+ >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
1766
+
1767
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1768
+ >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
1769
+
1770
+ >>> # training
1771
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
1772
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
1773
+ >>> outputs = model(input_ids=input_ids, labels=labels)
1774
+ >>> loss = outputs.loss
1775
+ >>> logits = outputs.logits
1776
+
1777
+ >>> # inference
1778
+ >>> input_ids = tokenizer(
1779
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
1780
+ ... ).input_ids # Batch size 1
1781
+ >>> outputs = model.generate(input_ids)
1782
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1783
+ >>> # studies have shown that owning a dog is good for you.
1784
+ ```"""
1785
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1786
+ return_dict = (
1787
+ return_dict if return_dict is not None else self.config.use_return_dict
1788
+ )
1789
+
1790
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1791
+ if head_mask is not None and decoder_head_mask is None:
1792
+ if self.config.num_layers == self.config.num_decoder_layers:
1793
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
1794
+ decoder_head_mask = head_mask
1795
+
1796
+ # Encode if needed (training, first prediction pass)
1797
+ if encoder_outputs is None:
1798
+ # Convert encoder inputs in embeddings if needed
1799
+ encoder_outputs = self.encoder(
1800
+ input_ids=input_ids,
1801
+ attention_mask=attention_mask,
1802
+ inputs_embeds=inputs_embeds,
1803
+ head_mask=head_mask,
1804
+ output_attentions=output_attentions,
1805
+ output_hidden_states=output_hidden_states,
1806
+ return_dict=return_dict,
1807
+ )
1808
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1809
+ encoder_outputs = BaseModelOutput(
1810
+ last_hidden_state=encoder_outputs[0],
1811
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1812
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1813
+ )
1814
+
1815
+ hidden_states = encoder_outputs[0]
1816
+
1817
+ if self.model_parallel:
1818
+ torch.cuda.set_device(self.decoder.first_device)
1819
+
1820
+ if (
1821
+ labels is not None
1822
+ and decoder_input_ids is None
1823
+ and decoder_inputs_embeds is None
1824
+ ):
1825
+ # get decoder inputs from shifting lm labels to the right
1826
+ decoder_input_ids = self._shift_right(labels)
1827
+
1828
+ # Set device for model parallelism
1829
+ if self.model_parallel:
1830
+ torch.cuda.set_device(self.decoder.first_device)
1831
+ hidden_states = hidden_states.to(self.decoder.first_device)
1832
+ if decoder_input_ids is not None:
1833
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1834
+ if attention_mask is not None:
1835
+ attention_mask = attention_mask.to(self.decoder.first_device)
1836
+ if decoder_attention_mask is not None:
1837
+ decoder_attention_mask = decoder_attention_mask.to(
1838
+ self.decoder.first_device
1839
+ )
1840
+
1841
+ # Decode
1842
+ decoder_outputs = self.decoder(
1843
+ input_ids=decoder_input_ids,
1844
+ attention_mask=decoder_attention_mask,
1845
+ inputs_embeds=decoder_inputs_embeds,
1846
+ past_key_values=past_key_values,
1847
+ encoder_hidden_states=hidden_states,
1848
+ encoder_attention_mask=attention_mask,
1849
+ head_mask=decoder_head_mask,
1850
+ cross_attn_head_mask=cross_attn_head_mask,
1851
+ use_cache=use_cache,
1852
+ output_attentions=output_attentions,
1853
+ output_hidden_states=output_hidden_states,
1854
+ return_dict=return_dict,
1855
+ )
1856
+
1857
+ sequence_output = decoder_outputs[0]
1858
+
1859
+ # Set device for model parallelism
1860
+ if self.model_parallel:
1861
+ torch.cuda.set_device(self.encoder.first_device)
1862
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
1863
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
1864
+
1865
+ if self.config.tie_word_embeddings:
1866
+ # Rescale output before projecting on vocab
1867
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1868
+ sequence_output = sequence_output * (self.model_dim**-0.5)
1869
+
1870
+ lm_logits = self.lm_head(sequence_output)
1871
+
1872
+ loss = None
1873
+ if labels is not None:
1874
+ loss_fct = CrossEntropyLoss(ignore_index=-100, reduction=reduction)
1875
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1876
+ if reduction == "none":
1877
+ loss = loss.view(lm_logits.size(0), -1).sum(1)
1878
+
1879
+ if not return_dict:
1880
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
1881
+ return ((loss,) + output) if loss is not None else output
1882
+
1883
+ return Seq2SeqLMOutput(
1884
+ loss=loss,
1885
+ logits=lm_logits,
1886
+ past_key_values=decoder_outputs.past_key_values,
1887
+ decoder_hidden_states=decoder_outputs.hidden_states,
1888
+ decoder_attentions=decoder_outputs.attentions,
1889
+ cross_attentions=decoder_outputs.cross_attentions,
1890
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1891
+ encoder_hidden_states=encoder_outputs.hidden_states,
1892
+ encoder_attentions=encoder_outputs.attentions,
1893
+ )
1894
+
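The `reduction` argument is this fork's addition to the stock Hugging Face signature: with `reduction="none"`, the returned loss is a per-example sum over tokens rather than a scalar mean. A short sketch, reusing `model`, `input_ids`, and `labels` from the training example in the docstring above:

```python
# Sketch of the fork-specific `reduction` argument: CrossEntropyLoss is built
# with reduction="none" and the per-token losses are summed per sample.
outputs = model(input_ids=input_ids, labels=labels, reduction="none")
per_sample_loss = outputs.loss  # shape: (batch_size,)
```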
1895
+ def prepare_inputs_for_generation(
1896
+ self,
1897
+ input_ids,
1898
+ past=None,
1899
+ attention_mask=None,
1900
+ head_mask=None,
1901
+ decoder_head_mask=None,
1902
+ cross_attn_head_mask=None,
1903
+ use_cache=None,
1904
+ encoder_outputs=None,
1905
+ **kwargs,
1906
+ ):
1907
+
1908
+ # cut decoder_input_ids if past is used
1909
+ if past is not None:
1910
+ input_ids = input_ids[:, -1:]
1911
+
1912
+ return {
1913
+ "decoder_input_ids": input_ids,
1914
+ "past_key_values": past,
1915
+ "encoder_outputs": encoder_outputs,
1916
+ "attention_mask": attention_mask,
1917
+ "head_mask": head_mask,
1918
+ "decoder_head_mask": decoder_head_mask,
1919
+ "cross_attn_head_mask": cross_attn_head_mask,
1920
+ "use_cache": use_cache,
1921
+ }
1922
+
1923
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1924
+ return self._shift_right(labels)
1925
+
1926
+ def _reorder_cache(self, past, beam_idx):
1927
+ # if decoder past is not included in output
1928
+ # speedy decoding is disabled and no need to reorder
1929
+ if past is None:
1930
+ logger.warning(
1931
+ "You might want to consider setting `use_cache=True` to speed up decoding"
1932
+ )
1933
+ return past
1934
+
1935
+ reordered_decoder_past = ()
1936
+ for layer_past_states in past:
1937
+ # select the surviving beam indices from each cached state;
1938
+ # the cached key/value states carry their batch (beam) dimension at position 0
1939
+ reordered_layer_past_states = ()
1940
+ for layer_past_state in layer_past_states:
1941
+ # need to set correct `past` for each of the four key / value states
1942
+ reordered_layer_past_states = reordered_layer_past_states + (
1943
+ layer_past_state.index_select(
1944
+ 0, beam_idx.to(layer_past_state.device)
1945
+ ),
1946
+ )
1947
+
1948
+ assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
1949
+ assert len(reordered_layer_past_states) == len(layer_past_states)
1950
+
1951
+ reordered_decoder_past = reordered_decoder_past + (
1952
+ reordered_layer_past_states,
1953
+ )
1954
+ return reordered_decoder_past
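An illustrative toy example (shapes hypothetical) of what the `index_select` inside `_reorder_cache` does to each cached key/value tensor during beam search:

```python
# Toy illustration: realign a cached state with the beams that survived a
# search step; the batch/beam dimension of the cached states is dim 0.
import torch

layer_past_state = torch.randn(4, 8, 10, 64)  # (beams, heads, seq_len, head_dim)
beam_idx = torch.tensor([2, 2, 0, 3])         # beams kept after this step
reordered = layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
assert reordered.shape == layer_past_state.shape
```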
1955
+
1956
+
1957
+ @add_start_docstrings(
1958
+ "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
1959
+ T5_START_DOCSTRING,
1960
+ )
1961
+ class T5EncoderModel(T5PreTrainedModel):
1962
+ authorized_missing_keys = [
1963
+ r"encoder.embed_tokens.weight",
1964
+ ]
1965
+
1966
+ def __init__(self, config: T5Config):
1967
+ super().__init__(config)
1968
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1969
+
1970
+ encoder_config = copy.deepcopy(config)
1971
+ encoder_config.use_cache = False
1972
+ encoder_config.is_encoder_decoder = False
1973
+ self.encoder = T5Stack(encoder_config, self.shared)
1974
+
1975
+ # Initialize weights and apply final processing
1976
+ self.post_init()
1977
+
1978
+ # Model parallel
1979
+ self.model_parallel = False
1980
+ self.device_map = None
1981
+
1982
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1983
+ def parallelize(self, device_map=None):
1984
+ self.device_map = (
1985
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1986
+ if device_map is None
1987
+ else device_map
1988
+ )
1989
+ assert_device_map(self.device_map, len(self.encoder.block))
1990
+ self.encoder.parallelize(self.device_map)
1991
+ self.model_parallel = True
1992
+
1993
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1994
+ def deparallelize(self):
1995
+ self.encoder.deparallelize()
1996
+ self.encoder = self.encoder.to("cpu")
1997
+ self.model_parallel = False
1998
+ self.device_map = None
1999
+ torch.cuda.empty_cache()
2000
+
2001
+ def get_input_embeddings(self):
2002
+ return self.shared
2003
+
2004
+ def set_input_embeddings(self, new_embeddings):
2005
+ self.shared = new_embeddings
2006
+ self.encoder.set_input_embeddings(new_embeddings)
2007
+
2008
+ def get_encoder(self):
2009
+ return self.encoder
2010
+
2011
+ def _prune_heads(self, heads_to_prune):
2012
+ """
2013
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
2014
+ class PreTrainedModel
2015
+ """
2016
+ for layer, heads in heads_to_prune.items():
2017
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
2018
+
2019
+ @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
2020
+ @replace_return_docstrings(
2021
+ output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC
2022
+ )
2023
+ def forward(
2024
+ self,
2025
+ input_ids: Optional[torch.LongTensor] = None,
2026
+ attention_mask: Optional[torch.FloatTensor] = None,
2027
+ head_mask: Optional[torch.FloatTensor] = None,
2028
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2029
+ output_attentions: Optional[bool] = None,
2030
+ output_hidden_states: Optional[bool] = None,
2031
+ return_dict: Optional[bool] = None,
2032
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
2033
+ r"""
2034
+ Returns:
2035
+
2036
+ Example:
2037
+
2038
+ ```python
2039
+ >>> from transformers import T5Tokenizer, T5EncoderModel
2040
+
2041
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
2042
+ >>> model = T5EncoderModel.from_pretrained("t5-small")
2043
+ >>> input_ids = tokenizer(
2044
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
2045
+ ... ).input_ids # Batch size 1
2046
+ >>> outputs = model(input_ids=input_ids)
2047
+ >>> last_hidden_states = outputs.last_hidden_state
2048
+ ```"""
2049
+ return_dict = (
2050
+ return_dict if return_dict is not None else self.config.use_return_dict
2051
+ )
2052
+
2053
+ encoder_outputs = self.encoder(
2054
+ input_ids=input_ids,
2055
+ attention_mask=attention_mask,
2056
+ inputs_embeds=inputs_embeds,
2057
+ head_mask=head_mask,
2058
+ output_attentions=output_attentions,
2059
+ output_hidden_states=output_hidden_states,
2060
+ return_dict=return_dict,
2061
+ )
2062
+
2063
+ return encoder_outputs
bliva/models/vit.py ADDED
@@ -0,0 +1,527 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+
7
+ Based on timm code base
8
+ https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ """
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from functools import partial
16
+
17
+ from timm.models.vision_transformer import _cfg, PatchEmbed
18
+ from timm.models.registry import register_model
19
+ from timm.models.layers import trunc_normal_, DropPath
20
+ from timm.models.helpers import named_apply, adapt_input_conv
21
+
22
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
23
+ from bliva.models.base_model import BaseEncoder
24
+
25
+
26
+ class Mlp(nn.Module):
27
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
28
+
29
+ def __init__(
30
+ self,
31
+ in_features,
32
+ hidden_features=None,
33
+ out_features=None,
34
+ act_layer=nn.GELU,
35
+ drop=0.0,
36
+ ):
37
+ super().__init__()
38
+ out_features = out_features or in_features
39
+ hidden_features = hidden_features or in_features
40
+ self.fc1 = nn.Linear(in_features, hidden_features)
41
+ self.act = act_layer()
42
+ self.fc2 = nn.Linear(hidden_features, out_features)
43
+ self.drop = nn.Dropout(drop)
44
+
45
+ def forward(self, x):
46
+ x = self.fc1(x)
47
+ x = self.act(x)
48
+ x = self.drop(x)
49
+ x = self.fc2(x)
50
+ x = self.drop(x)
51
+ return x
52
+
53
+
54
+ class Attention(nn.Module):
55
+ def __init__(
56
+ self,
57
+ dim,
58
+ num_heads=8,
59
+ qkv_bias=False,
60
+ qk_scale=None,
61
+ attn_drop=0.0,
62
+ proj_drop=0.0,
63
+ ):
64
+ super().__init__()
65
+ self.num_heads = num_heads
66
+ head_dim = dim // num_heads
67
+ # NOTE: the scale factor was wrong in an earlier version; it can be set manually to stay compatible with previous weights
68
+ self.scale = qk_scale or head_dim**-0.5
69
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
70
+ self.attn_drop = nn.Dropout(attn_drop)
71
+ self.proj = nn.Linear(dim, dim)
72
+ self.proj_drop = nn.Dropout(proj_drop)
73
+ self.attn_gradients = None
74
+ self.attention_map = None
75
+
76
+ def save_attn_gradients(self, attn_gradients):
77
+ self.attn_gradients = attn_gradients
78
+
79
+ def get_attn_gradients(self):
80
+ return self.attn_gradients
81
+
82
+ def save_attention_map(self, attention_map):
83
+ self.attention_map = attention_map
84
+
85
+ def get_attention_map(self):
86
+ return self.attention_map
87
+
88
+ def forward(self, x, register_hook=False):
89
+ B, N, C = x.shape
90
+ qkv = (
91
+ self.qkv(x)
92
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
93
+ .permute(2, 0, 3, 1, 4)
94
+ )
95
+ q, k, v = (
96
+ qkv[0],
97
+ qkv[1],
98
+ qkv[2],
99
+ ) # make torchscript happy (cannot use tensor as tuple)
100
+
101
+ attn = (q @ k.transpose(-2, -1)) * self.scale
102
+ attn = attn.softmax(dim=-1)
103
+ attn = self.attn_drop(attn)
104
+
105
+ if register_hook:
106
+ self.save_attention_map(attn)
107
+ attn.register_hook(self.save_attn_gradients)
108
+
109
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
110
+ x = self.proj(x)
111
+ x = self.proj_drop(x)
112
+ return x
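A hedged usage sketch of the hook mechanism above, which caches the attention map and its gradients (useful for Grad-CAM-style visualization):

```python
# Sketch: run the Attention module defined above with hooks enabled, then
# read back the cached attention map and its gradient after backward().
import torch

attn = Attention(dim=768, num_heads=12)
x = torch.randn(1, 197, 768, requires_grad=True)
out = attn(x, register_hook=True)
out.sum().backward()
print(attn.get_attention_map().shape)   # torch.Size([1, 12, 197, 197])
print(attn.get_attn_gradients().shape)  # torch.Size([1, 12, 197, 197])
```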
113
+
114
+
115
+ class Block(nn.Module):
116
+ def __init__(
117
+ self,
118
+ dim,
119
+ num_heads,
120
+ mlp_ratio=4.0,
121
+ qkv_bias=False,
122
+ qk_scale=None,
123
+ drop=0.0,
124
+ attn_drop=0.0,
125
+ drop_path=0.0,
126
+ act_layer=nn.GELU,
127
+ norm_layer=nn.LayerNorm,
128
+ use_grad_checkpointing=False,
129
+ ):
130
+ super().__init__()
131
+ self.norm1 = norm_layer(dim)
132
+ self.attn = Attention(
133
+ dim,
134
+ num_heads=num_heads,
135
+ qkv_bias=qkv_bias,
136
+ qk_scale=qk_scale,
137
+ attn_drop=attn_drop,
138
+ proj_drop=drop,
139
+ )
140
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
141
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
142
+ self.norm2 = norm_layer(dim)
143
+ mlp_hidden_dim = int(dim * mlp_ratio)
144
+ self.mlp = Mlp(
145
+ in_features=dim,
146
+ hidden_features=mlp_hidden_dim,
147
+ act_layer=act_layer,
148
+ drop=drop,
149
+ )
150
+
151
+ if use_grad_checkpointing:
152
+ self.attn = checkpoint_wrapper(self.attn)
153
+ self.mlp = checkpoint_wrapper(self.mlp)
154
+
155
+ def forward(self, x, register_hook=False):
156
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
157
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
158
+ return x
159
+
160
+
161
+ class VisionTransformer(nn.Module):
162
+ """Vision Transformer
163
+ A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
164
+ https://arxiv.org/abs/2010.11929
165
+ """
166
+
167
+ def __init__(
168
+ self,
169
+ img_size=224,
170
+ patch_size=16,
171
+ in_chans=3,
172
+ num_classes=1000,
173
+ embed_dim=768,
174
+ depth=12,
175
+ num_heads=12,
176
+ mlp_ratio=4.0,
177
+ qkv_bias=True,
178
+ qk_scale=None,
179
+ representation_size=None,
180
+ drop_rate=0.0,
181
+ attn_drop_rate=0.0,
182
+ drop_path_rate=0.0,
183
+ norm_layer=None,
184
+ use_grad_checkpointing=False,
185
+ ckpt_layer=0,
186
+ ):
187
+ """
188
+ Args:
189
+ img_size (int, tuple): input image size
190
+ patch_size (int, tuple): patch size
191
+ in_chans (int): number of input channels
192
+ num_classes (int): number of classes for classification head
193
+ embed_dim (int): embedding dimension
194
+ depth (int): depth of transformer
195
+ num_heads (int): number of attention heads
196
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
197
+ qkv_bias (bool): enable bias for qkv if True
198
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
199
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
200
+ drop_rate (float): dropout rate
201
+ attn_drop_rate (float): attention dropout rate
202
+ drop_path_rate (float): stochastic depth rate
203
+ norm_layer: (nn.Module): normalization layer
204
+ """
205
+ super().__init__()
206
+ self.num_features = (
207
+ self.embed_dim
208
+ ) = embed_dim # num_features for consistency with other models
209
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
210
+
211
+ self.patch_embed = PatchEmbed(
212
+ img_size=img_size,
213
+ patch_size=patch_size,
214
+ in_chans=in_chans,
215
+ embed_dim=embed_dim,
216
+ )
217
+
218
+ num_patches = self.patch_embed.num_patches
219
+
220
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
221
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
222
+ self.pos_drop = nn.Dropout(p=drop_rate)
223
+
224
+ dpr = [
225
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
226
+ ] # stochastic depth decay rule
227
+ self.blocks = nn.ModuleList(
228
+ [
229
+ Block(
230
+ dim=embed_dim,
231
+ num_heads=num_heads,
232
+ mlp_ratio=mlp_ratio,
233
+ qkv_bias=qkv_bias,
234
+ qk_scale=qk_scale,
235
+ drop=drop_rate,
236
+ attn_drop=attn_drop_rate,
237
+ drop_path=dpr[i],
238
+ norm_layer=norm_layer,
239
+ use_grad_checkpointing=(
240
+ use_grad_checkpointing and i >= depth - ckpt_layer
241
+ ),
242
+ )
243
+ for i in range(depth)
244
+ ]
245
+ )
246
+ self.norm = norm_layer(embed_dim)
247
+
248
+ trunc_normal_(self.pos_embed, std=0.02)
249
+ trunc_normal_(self.cls_token, std=0.02)
250
+ self.apply(self._init_weights)
251
+
252
+ def _init_weights(self, m):
253
+ if isinstance(m, nn.Linear):
254
+ trunc_normal_(m.weight, std=0.02)
255
+ if isinstance(m, nn.Linear) and m.bias is not None:
256
+ nn.init.constant_(m.bias, 0)
257
+ elif isinstance(m, nn.LayerNorm):
258
+ nn.init.constant_(m.bias, 0)
259
+ nn.init.constant_(m.weight, 1.0)
260
+
261
+ @torch.jit.ignore
262
+ def no_weight_decay(self):
263
+ return {"pos_embed", "cls_token"}
264
+
265
+ def forward(self, x, register_blk=-1):
266
+ B = x.shape[0]
267
+ x = self.patch_embed(x)
268
+
269
+ cls_tokens = self.cls_token.expand(
270
+ B, -1, -1
271
+ ) # stole cls_tokens impl from Phil Wang, thanks
272
+ x = torch.cat((cls_tokens, x), dim=1)
273
+
274
+ x = x + self.pos_embed[:, : x.size(1), :]
275
+ x = self.pos_drop(x)
276
+
277
+ for i, blk in enumerate(self.blocks):
278
+ x = blk(x, register_blk == i)
279
+ x = self.norm(x)
280
+
281
+ return x
282
+
283
+ @torch.jit.ignore()
284
+ def load_pretrained(self, checkpoint_path, prefix=""):
285
+ _load_weights(self, checkpoint_path, prefix)
286
+
287
+
288
+ @torch.no_grad()
289
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""):
290
+ """Load weights from .npz checkpoints for official Google Brain Flax implementation"""
291
+ import numpy as np
292
+
293
+ def _n2p(w, t=True):
294
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
295
+ w = w.flatten()
296
+ if t:
297
+ if w.ndim == 4:
298
+ w = w.transpose([3, 2, 0, 1])
299
+ elif w.ndim == 3:
300
+ w = w.transpose([2, 0, 1])
301
+ elif w.ndim == 2:
302
+ w = w.transpose([1, 0])
303
+ return torch.from_numpy(w)
304
+
305
+ w = np.load(checkpoint_path)
306
+ if not prefix and "opt/target/embedding/kernel" in w:
307
+ prefix = "opt/target/"
308
+
309
+ if hasattr(model.patch_embed, "backbone"):
310
+ # hybrid
311
+ backbone = model.patch_embed.backbone
312
+ stem_only = not hasattr(backbone, "stem")
313
+ stem = backbone if stem_only else backbone.stem
314
+ stem.conv.weight.copy_(
315
+ adapt_input_conv(
316
+ stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"])
317
+ )
318
+ )
319
+ stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"]))
320
+ stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"]))
321
+ if not stem_only:
322
+ for i, stage in enumerate(backbone.stages):
323
+ for j, block in enumerate(stage.blocks):
324
+ bp = f"{prefix}block{i + 1}/unit{j + 1}/"
325
+ for r in range(3):
326
+ getattr(block, f"conv{r + 1}").weight.copy_(
327
+ _n2p(w[f"{bp}conv{r + 1}/kernel"])
328
+ )
329
+ getattr(block, f"norm{r + 1}").weight.copy_(
330
+ _n2p(w[f"{bp}gn{r + 1}/scale"])
331
+ )
332
+ getattr(block, f"norm{r + 1}").bias.copy_(
333
+ _n2p(w[f"{bp}gn{r + 1}/bias"])
334
+ )
335
+ if block.downsample is not None:
336
+ block.downsample.conv.weight.copy_(
337
+ _n2p(w[f"{bp}conv_proj/kernel"])
338
+ )
339
+ block.downsample.norm.weight.copy_(
340
+ _n2p(w[f"{bp}gn_proj/scale"])
341
+ )
342
+ block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"]))
343
+ embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"])
344
+ else:
345
+ embed_conv_w = adapt_input_conv(
346
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"])
347
+ )
348
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
349
+ model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"]))
350
+ model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False))
351
+ pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False)
352
+ if pos_embed_w.shape != model.pos_embed.shape:
353
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
354
+ pos_embed_w,
355
+ model.pos_embed,
356
+ getattr(model, "num_tokens", 1),
357
+ model.patch_embed.grid_size,
358
+ )
359
+ model.pos_embed.copy_(pos_embed_w)
360
+ model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"]))
361
+ model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))
362
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
363
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
364
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
365
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
366
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
367
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
368
+ for i, block in enumerate(model.blocks.children()):
369
+ block_prefix = f"{prefix}Transformer/encoderblock_{i}/"
370
+ mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/"
371
+ block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"]))
372
+ block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"]))
373
+ block.attn.qkv.weight.copy_(
374
+ torch.cat(
375
+ [
376
+ _n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T
377
+ for n in ("query", "key", "value")
378
+ ]
379
+ )
380
+ )
381
+ block.attn.qkv.bias.copy_(
382
+ torch.cat(
383
+ [
384
+ _n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1)
385
+ for n in ("query", "key", "value")
386
+ ]
387
+ )
388
+ )
389
+ block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1))
390
+ block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"]))
391
+ for r in range(2):
392
+ getattr(block.mlp, f"fc{r + 1}").weight.copy_(
393
+ _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"])
394
+ )
395
+ getattr(block.mlp, f"fc{r + 1}").bias.copy_(
396
+ _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"])
397
+ )
398
+ block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"]))
399
+ block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"]))
400
+
401
+
402
+ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
403
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
404
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
405
+ print("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape)
406
+ ntok_new = posemb_new.shape[1]
407
+ if num_tokens:
408
+ posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
409
+ ntok_new -= num_tokens
410
+ else:
411
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
412
+ gs_old = int(math.sqrt(len(posemb_grid)))
413
+ if not len(gs_new): # backwards compatibility
414
+ gs_new = [int(math.sqrt(ntok_new))] * 2
415
+ assert len(gs_new) >= 2
416
+ print("Position embedding grid-size from %s to %s", [gs_old, gs_old], gs_new)
417
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
418
+ posemb_grid = F.interpolate(
419
+ posemb_grid, size=gs_new, mode="bicubic", align_corners=False
420
+ )
421
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
422
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
423
+ return posemb
424
+
425
+
426
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
427
+ # interpolate position embedding
428
+ embedding_size = pos_embed_checkpoint.shape[-1]
429
+ num_patches = visual_encoder.patch_embed.num_patches
430
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
431
+ # height (== width) for the checkpoint position embedding
432
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
433
+ # height (== width) for the new position embedding
434
+ new_size = int(num_patches**0.5)
435
+
436
+ if orig_size != new_size:
437
+ # class_token and dist_token are kept unchanged
438
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
439
+ # only the position tokens are interpolated
440
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
441
+ pos_tokens = pos_tokens.reshape(
442
+ -1, orig_size, orig_size, embedding_size
443
+ ).permute(0, 3, 1, 2)
444
+ pos_tokens = torch.nn.functional.interpolate(
445
+ pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
446
+ )
447
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
448
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
449
+ print(
450
+ "reshape position embedding from %d to %d" % (orig_size**2, new_size**2)
451
+ )
452
+
453
+ return new_pos_embed
454
+ else:
455
+ return pos_embed_checkpoint
456
+
457
+
458
+ class VisionTransformerEncoder(VisionTransformer, BaseEncoder):
459
+ @classmethod
460
+ def from_config(cls, cfg, from_pretrained=False):
461
+
462
+ vit_type = cfg.get("vit_type", "base")
463
+ image_size = cfg.get("image_size", 384)
464
+ ckpt_layer = cfg.get("vit_ckpt_layer", 0)
465
+ drop_path_rate = cfg.get("vit_drop_path_rate", 0)
466
+ norm_layer_eps = cfg.get("vit_layer_norm_epsilon", -1)
467
+ use_grad_checkpointing = cfg.get("vit_grad_ckpt", False)
468
+
469
+ if norm_layer_eps == -1:
470
+ norm_layer = None
471
+ else:
472
+ norm_layer = partial(nn.LayerNorm, eps=norm_layer_eps)
473
+
474
+ # norm_layer=partial(nn.LayerNorm, eps=1e-6),
475
+ assert vit_type in ["base", "large"], "vit parameter must be base or large"
476
+ if vit_type == "base":
477
+ vision_width = 768
478
+ visual_encoder = cls(
479
+ img_size=image_size,
480
+ patch_size=16,
481
+ embed_dim=vision_width,
482
+ depth=12,
483
+ num_heads=12,
484
+ use_grad_checkpointing=use_grad_checkpointing,
485
+ ckpt_layer=ckpt_layer,
486
+ drop_path_rate=drop_path_rate,
487
+ norm_layer=norm_layer,
488
+ )
489
+
490
+ if from_pretrained:
491
+ checkpoint = torch.hub.load_state_dict_from_url(
492
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
493
+ map_location="cpu",
494
+ check_hash=True,
495
+ )
496
+ state_dict = checkpoint["model"]
497
+ state_dict["pos_embed"] = interpolate_pos_embed(
498
+ state_dict["pos_embed"], visual_encoder
499
+ )
500
+ msg = visual_encoder.load_state_dict(state_dict, strict=False)
501
+
502
+ elif vit_type == "large":
503
+ vision_width = 1024
504
+ visual_encoder = cls(
505
+ img_size=image_size,
506
+ patch_size=16,
507
+ embed_dim=vision_width,
508
+ depth=24,
509
+ num_heads=16,
510
+ use_grad_checkpointing=use_grad_checkpointing,
511
+ ckpt_layer=ckpt_layer,
512
+ drop_path_rate=0.1,  # `0.1 or drop_path_rate` always evaluated to 0.1, so the config value was never used here
513
+ norm_layer=norm_layer,
514
+ )
515
+ if from_pretrained:
516
+ from timm.models.helpers import load_custom_pretrained
517
+ from timm.models.vision_transformer import default_cfgs
518
+
519
+ load_custom_pretrained(
520
+ visual_encoder, default_cfgs["vit_large_patch16_224_in21k"]
521
+ )
522
+
523
+ visual_encoder.vision_width = vision_width
524
+ return visual_encoder
525
+
526
+ def forward_features(self, x, register_blk=-1):
527
+ return super().forward(x, register_blk)
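A hedged end-to-end example of building this encoder from an OmegaConf-style config and extracting patch features:

```python
# Sketch: a "base" ViT at 224px produces 14x14 patch tokens plus one CLS token.
import torch
from omegaconf import OmegaConf

cfg = OmegaConf.create({"vit_type": "base", "image_size": 224})
encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=False)
features = encoder.forward_features(torch.randn(1, 3, 224, 224))
print(features.shape)  # torch.Size([1, 197, 768])
```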
bliva/processors/__init__.py ADDED
@@ -0,0 +1,38 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from bliva.processors.base_processor import BaseProcessor
9
+
10
+ from bliva.processors.blip_processors import (
11
+ BlipImageTrainProcessor,
12
+ Blip2ImageTrainProcessor,
13
+ BlipImageEvalProcessor,
14
+ BlipCaptionProcessor,
15
+ )
16
+ from bliva.processors.clip_processors import ClipImageTrainProcessor
17
+
18
+ from bliva.common.registry import registry
19
+
20
+ __all__ = [
21
+ "BaseProcessor",
22
+ "BlipImageTrainProcessor",
23
+ "Blip2ImageTrainProcessor",
24
+ "BlipImageEvalProcessor",
25
+ "BlipCaptionProcessor",
26
+ "ClipImageTrainProcessor",
27
+ ]
28
+
29
+
30
+ def load_processor(name, cfg=None):
31
+ """
32
+ Example
33
+
34
+ >>> processor = load_processor("alpro_video_train", cfg=None)
35
+ """
36
+ processor = registry.get_processor_class(name).from_config(cfg)
37
+
38
+ return processor
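For instance, any name registered below in this package can be fetched the same way (sketch):

```python
# Sketch: fetch a processor registered in blip_processors.py by its name.
processor = load_processor("blip_image_eval", cfg=None)
```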
bliva/processors/base_processor.py ADDED
@@ -0,0 +1,26 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from omegaconf import OmegaConf
9
+
10
+
11
+ class BaseProcessor:
12
+ def __init__(self):
13
+ self.transform = lambda x: x
14
+ return
15
+
16
+ def __call__(self, item):
17
+ return self.transform(item)
18
+
19
+ @classmethod
20
+ def from_config(cls, cfg=None):
21
+ return cls()
22
+
23
+ def build(self, **kwargs):
24
+ cfg = OmegaConf.create(kwargs)
25
+
26
+ return self.from_config(cfg)
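`build` is a small convenience that wraps keyword arguments into an OmegaConf config before delegating to `from_config`; `BaseProcessor.from_config` itself ignores the config, but the subclasses below read keys such as `image_size` from it (sketch):

```python
# Sketch: kwargs -> OmegaConf config -> from_config. On BaseProcessor this
# returns a plain identity processor; subclasses consume the config keys.
processor = BaseProcessor().build(image_size=224)
```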
bliva/processors/blip_processors.py ADDED
@@ -0,0 +1,239 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import re
9
+
10
+ from bliva.common.registry import registry
11
+ from bliva.processors.base_processor import BaseProcessor
12
+ from bliva.processors.randaugment import RandomAugment
13
+ from omegaconf import OmegaConf
14
+ from torchvision import transforms
15
+ from torchvision.transforms.functional import InterpolationMode
16
+
17
+
18
+ class BlipImageBaseProcessor(BaseProcessor):
19
+ def __init__(self, mean=None, std=None):
20
+ if mean is None:
21
+ mean = (0.48145466, 0.4578275, 0.40821073)
22
+ if std is None:
23
+ std = (0.26862954, 0.26130258, 0.27577711)
24
+
25
+ self.normalize = transforms.Normalize(mean, std)
26
+
27
+
28
+ @registry.register_processor("blip_caption")
29
+ class BlipCaptionProcessor(BaseProcessor):
30
+ def __init__(self, prompt="", max_words=50):
31
+ self.prompt = prompt
32
+ self.max_words = max_words
33
+
34
+ def __call__(self, caption):
35
+ caption = self.prompt + self.pre_caption(caption)
36
+
37
+ return caption
38
+
39
+ @classmethod
40
+ def from_config(cls, cfg=None):
41
+ if cfg is None:
42
+ cfg = OmegaConf.create()
43
+
44
+ prompt = cfg.get("prompt", "")
45
+ max_words = cfg.get("max_words", 50)
46
+
47
+ return cls(prompt=prompt, max_words=max_words)
48
+
49
+ def pre_caption(self, caption):
50
+ caption = re.sub(
51
+ r"([.!\"()*#:;~])",
52
+ " ",
53
+ caption.lower(),
54
+ )
55
+ caption = re.sub(
56
+ r"\s{2,}",
57
+ " ",
58
+ caption,
59
+ )
60
+ caption = caption.rstrip("\n")
61
+ caption = caption.strip(" ")
62
+
63
+ # truncate caption
64
+ caption_words = caption.split(" ")
65
+ if len(caption_words) > self.max_words:
66
+ caption = " ".join(caption_words[: self.max_words])
67
+
68
+ return caption
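A quick illustration of the normalization above: punctuation in `[.!"()*#:;~]` becomes spaces, text is lowercased, runs of spaces collapse, and long captions are truncated to `max_words`:

```python
# Illustration of pre_caption plus the prompt prefix.
proc = BlipCaptionProcessor(prompt="a photo of ")
print(proc("A Dog; running in the PARK!"))  # "a photo of a dog running in the park"
```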
69
+
70
+
71
+ @registry.register_processor("blip_question")
72
+ class BlipQuestionProcessor(BaseProcessor):
73
+ def __init__(self, max_words=50):
74
+ self.max_words = max_words
75
+
76
+ def __call__(self, question):
77
+ return self.pre_question(question)
78
+
79
+ @classmethod
80
+ def from_config(cls, cfg=None):
81
+ if cfg is None:
82
+ cfg = OmegaConf.create()
83
+
84
+ max_words = cfg.get("max_words", 50)
85
+
86
+ return cls(max_words=max_words)
87
+
88
+ def pre_question(self, question):
89
+ question = re.sub(
90
+ r"([.!\"()*#:;~])",
91
+ "",
92
+ question.lower(),
93
+ )
94
+ question = question.rstrip(" ")
95
+
96
+ # truncate question
97
+ question_words = question.split(" ")
98
+ if len(question_words) > self.max_words:
99
+ question = " ".join(question_words[: self.max_words])
100
+
101
+ return question
102
+
103
+
104
+ @registry.register_processor("blip_image_train")
105
+ class BlipImageTrainProcessor(BlipImageBaseProcessor):
106
+ def __init__(
107
+ self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
108
+ ):
109
+ super().__init__(mean=mean, std=std)
110
+
111
+ self.transform = transforms.Compose(
112
+ [
113
+ transforms.RandomResizedCrop(
114
+ image_size,
115
+ scale=(min_scale, max_scale),
116
+ interpolation=InterpolationMode.BICUBIC,
117
+ ),
118
+ transforms.RandomHorizontalFlip(),
119
+ # RandomAugment(
120
+ # 2,
121
+ # 5,
122
+ # isPIL=True,
123
+ # augs=[
124
+ # "Identity",
125
+ # "AutoContrast",
126
+ # "Brightness",
127
+ # "Sharpness",
128
+ # "Equalize",
129
+ # "ShearX",
130
+ # "ShearY",
131
+ # "TranslateX",
132
+ # "TranslateY",
133
+ # "Rotate",
134
+ # ],
135
+ # ),
136
+ transforms.ToTensor(),
137
+ self.normalize,
138
+ ]
139
+ )
140
+
141
+ def __call__(self, item):
142
+ return self.transform(item)
143
+
144
+ @classmethod
145
+ def from_config(cls, cfg=None):
146
+ if cfg is None:
147
+ cfg = OmegaConf.create()
148
+
149
+ image_size = cfg.get("image_size", 384)
150
+
151
+ mean = cfg.get("mean", None)
152
+ std = cfg.get("std", None)
153
+
154
+ min_scale = cfg.get("min_scale", 0.5)
155
+ max_scale = cfg.get("max_scale", 1.0)
156
+
157
+ return cls(
158
+ image_size=image_size,
159
+ mean=mean,
160
+ std=std,
161
+ min_scale=min_scale,
162
+ max_scale=max_scale,
163
+ )
164
+
165
+
166
+ @registry.register_processor("blip_image_eval")
167
+ class BlipImageEvalProcessor(BlipImageBaseProcessor):
168
+ def __init__(self, image_size=384, mean=None, std=None):
169
+ super().__init__(mean=mean, std=std)
170
+
171
+ self.transform = transforms.Compose(
172
+ [
173
+ transforms.Resize(
174
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
175
+ ),
176
+ transforms.ToTensor(),
177
+ self.normalize,
178
+ ]
179
+ )
180
+
181
+ def __call__(self, item):
182
+ return self.transform(item)
183
+
184
+ @classmethod
185
+ def from_config(cls, cfg=None):
186
+ if cfg is None:
187
+ cfg = OmegaConf.create()
188
+
189
+ image_size = cfg.get("image_size", 384)
190
+
191
+ mean = cfg.get("mean", None)
192
+ std = cfg.get("std", None)
193
+
194
+ return cls(image_size=image_size, mean=mean, std=std)
195
+
196
+
197
+ @registry.register_processor("blip2_image_train")
198
+ class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
199
+ def __init__(
200
+ self, image_size=364, mean=None, std=None, min_scale=0.5, max_scale=1.0
201
+ ):
202
+ super().__init__(mean=mean, std=std)
203
+
204
+ self.transform = transforms.Compose(
205
+ [
206
+ transforms.RandomResizedCrop(
207
+ image_size,
208
+ scale=(min_scale, max_scale),
209
+ interpolation=InterpolationMode.BICUBIC,
210
+ ),
211
+ transforms.RandomHorizontalFlip(),
212
+ transforms.ToTensor(),
213
+ self.normalize,
214
+ ]
215
+ )
216
+
217
+ def __call__(self, item):
218
+ return self.transform(item)
219
+
220
+ @classmethod
221
+ def from_config(cls, cfg=None):
222
+ if cfg is None:
223
+ cfg = OmegaConf.create()
224
+
225
+ image_size = cfg.get("image_size", 364)
226
+
227
+ mean = cfg.get("mean", None)
228
+ std = cfg.get("std", None)
229
+
230
+ min_scale = cfg.get("min_scale", 0.5)
231
+ max_scale = cfg.get("max_scale", 1.0)
232
+
233
+ return cls(
234
+ image_size=image_size,
235
+ mean=mean,
236
+ std=std,
237
+ min_scale=min_scale,
238
+ max_scale=max_scale,
239
+ )
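A hedged usage sketch of the training processor defined above: PIL image in, normalized float tensor out:

```python
# Sketch: random resized crop + horizontal flip + ToTensor + normalize.
from PIL import Image

proc = Blip2ImageTrainProcessor(image_size=364)
tensor = proc(Image.new("RGB", (500, 400)))
print(tensor.shape)  # torch.Size([3, 364, 364])
```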
bliva/processors/clip_processors.py ADDED
@@ -0,0 +1,92 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from bliva.common.registry import registry
9
+ from bliva.processors.blip_processors import BlipImageBaseProcessor
10
+ from omegaconf import OmegaConf
11
+ from torchvision import transforms
12
+ from torchvision.transforms.functional import InterpolationMode
13
+
14
+
15
+ def _convert_to_rgb(image):
16
+ return image.convert("RGB")
17
+
18
+
19
+ @registry.register_processor("clip_image_train")
20
+ class ClipImageTrainProcessor(BlipImageBaseProcessor):
21
+ def __init__(
22
+ self, image_size=224, mean=None, std=None, min_scale=0.9, max_scale=1.0
23
+ ):
24
+
25
+ super().__init__(mean=mean, std=std)
26
+
27
+ self.transform = transforms.Compose(
28
+ [
29
+ transforms.RandomResizedCrop(
30
+ image_size,
31
+ scale=(min_scale, max_scale),
32
+ interpolation=InterpolationMode.BICUBIC,
33
+ ),
34
+ _convert_to_rgb,
35
+ transforms.ToTensor(),
36
+ self.normalize,
37
+ ]
38
+ )
39
+
40
+ @classmethod
41
+ def from_config(cls, cfg=None):
42
+ if cfg is None:
43
+ cfg = OmegaConf.create()
44
+
45
+ image_size = cfg.get("image_size", 224)
46
+
47
+ mean = cfg.get("mean", None)
48
+ std = cfg.get("std", None)
49
+
50
+ min_scale = cfg.get("min_scale", 0.9)
51
+ max_scale = cfg.get("max_scale", 1.0)
52
+
53
+ return cls(
54
+ image_size=image_size,
55
+ mean=mean,
56
+ std=std,
57
+ min_scale=min_scale,
58
+ max_scale=max_scale,
59
+ )
60
+
61
+
62
+ @registry.register_processor("clip_image_eval")
63
+ class ClipImageEvalProcessor(BlipImageBaseProcessor):
64
+ def __init__(self, image_size=224, mean=None, std=None):
65
+
66
+ super().__init__(mean=mean, std=std)
67
+
68
+ self.transform = transforms.Compose(
69
+ [
70
+ transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
71
+ transforms.CenterCrop(image_size),
72
+ _convert_to_rgb,
73
+ transforms.ToTensor(),
74
+ self.normalize,
75
+ ]
76
+ )
77
+
78
+ @classmethod
79
+ def from_config(cls, cfg=None):
80
+ if cfg is None:
81
+ cfg = OmegaConf.create()
82
+
83
+ image_size = cfg.get("image_size", 224)
84
+
85
+ mean = cfg.get("mean", None)
86
+ std = cfg.get("std", None)
87
+
88
+ return cls(
89
+ image_size=image_size,
90
+ mean=mean,
91
+ std=std,
92
+ )
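And a sketch of eval-time preprocessing, which resizes the shorter side to `image_size` before taking a center crop:

```python
# Sketch: Resize(224) scales the shorter side, CenterCrop(224) squares it off.
from PIL import Image

proc = ClipImageEvalProcessor(image_size=224)
print(proc(Image.new("RGB", (640, 480))).shape)  # torch.Size([3, 224, 224])
```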
bliva/processors/randaugment.py ADDED
@@ -0,0 +1,398 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ import torch
12
+
13
+
14
+ ## aug functions
15
+ def identity_func(img):
16
+ return img
17
+
18
+
19
+ def autocontrast_func(img, cutoff=0):
20
+ """
21
+ same output as PIL.ImageOps.autocontrast
22
+ """
23
+ n_bins = 256
24
+
25
+ def tune_channel(ch):
26
+ n = ch.size
27
+ cut = cutoff * n // 100
28
+ if cut == 0:
29
+ high, low = ch.max(), ch.min()
30
+ else:
31
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
32
+ low = np.argwhere(np.cumsum(hist) > cut)
33
+ low = 0 if low.shape[0] == 0 else low[0]
34
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
35
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
36
+ if high <= low:
37
+ table = np.arange(n_bins)
38
+ else:
39
+ scale = (n_bins - 1) / (high - low)
40
+ offset = -low * scale
41
+ table = np.arange(n_bins) * scale + offset
42
+ table[table < 0] = 0
43
+ table[table > n_bins - 1] = n_bins - 1
44
+ table = table.clip(0, 255).astype(np.uint8)
45
+ return table[ch]
46
+
47
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
48
+ out = cv2.merge(channels)
49
+ return out
50
+
51
+
52
+ def equalize_func(img):
53
+ """
54
+ same output as PIL.ImageOps.equalize
55
+ PIL's implementation is different from cv2.equalize
56
+ """
57
+ n_bins = 256
58
+
59
+ def tune_channel(ch):
60
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
61
+ non_zero_hist = hist[hist != 0].reshape(-1)
62
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
63
+ if step == 0:
64
+ return ch
65
+ n = np.empty_like(hist)
66
+ n[0] = step // 2
67
+ n[1:] = hist[:-1]
68
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
69
+ return table[ch]
70
+
71
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
72
+ out = cv2.merge(channels)
73
+ return out
74
+
75
+
76
+ def rotate_func(img, degree, fill=(0, 0, 0)):
77
+ """
78
+ like PIL, rotate by degree, not radians
79
+ """
80
+ H, W = img.shape[0], img.shape[1]
81
+ center = W / 2, H / 2
82
+ M = cv2.getRotationMatrix2D(center, degree, 1)
83
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
84
+ return out
85
+
86
+
87
+ def solarize_func(img, thresh=128):
88
+ """
89
+ same output as PIL.ImageOps.solarize
90
+ """
91
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
92
+ table = table.clip(0, 255).astype(np.uint8)
93
+ out = table[img]
94
+ return out
95
+
96
+
97
+ def color_func(img, factor):
98
+ """
99
+ same output as PIL.ImageEnhance.Color
100
+ """
101
+ ## implementation according to PIL definition, quite slow
102
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
103
+ # out = blend(degenerate, img, factor)
104
+ # M = (
105
+ # np.eye(3) * factor
106
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
107
+ # )[np.newaxis, np.newaxis, :]
108
+ M = np.float32(
109
+ [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
110
+ ) * factor + np.float32([[0.114], [0.587], [0.299]])
111
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
112
+ return out
113
+
114
+
115
+ def contrast_func(img, factor):
116
+ """
117
+ same output as PIL.ImageEnhance.Contrast
118
+ """
119
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
120
+ table = (
121
+ np.array([(el - mean) * factor + mean for el in range(256)])
122
+ .clip(0, 255)
123
+ .astype(np.uint8)
124
+ )
125
+ out = table[img]
126
+ return out
127
+
128
+
129
+ def brightness_func(img, factor):
130
+ """
131
+ same output as PIL.ImageEnhance.Brightness
132
+ """
133
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
134
+ out = table[img]
135
+ return out
136
+
137
+
138
+ def sharpness_func(img, factor):
139
+ """
140
+ The differences between this result and PIL are all on the 4 boundaries; the center
141
+ areas are the same
142
+ """
143
+ kernel = np.ones((3, 3), dtype=np.float32)
144
+ kernel[1][1] = 5
145
+ kernel /= 13
146
+ degenerate = cv2.filter2D(img, -1, kernel)
147
+ if factor == 0.0:
148
+ out = degenerate
149
+ elif factor == 1.0:
150
+ out = img
151
+ else:
152
+ out = img.astype(np.float32)
153
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
154
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
155
+ out = out.astype(np.uint8)
156
+ return out
157
+
158
+
159
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
160
+ H, W = img.shape[0], img.shape[1]
161
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
162
+ out = cv2.warpAffine(
163
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
164
+ ).astype(np.uint8)
165
+ return out
166
+
167
+
168
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
169
+ """
170
+ same output as PIL.Image.transform
171
+ """
172
+ H, W = img.shape[0], img.shape[1]
173
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
174
+ out = cv2.warpAffine(
175
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
176
+ ).astype(np.uint8)
177
+ return out
178
+
179
+
180
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
181
+ """
182
+ same output as PIL.Image.transform
183
+ """
184
+ H, W = img.shape[0], img.shape[1]
185
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
186
+ out = cv2.warpAffine(
187
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
188
+ ).astype(np.uint8)
189
+ return out
190
+
191
+
192
+ def posterize_func(img, bits):
193
+ """
194
+ same output as PIL.ImageOps.posterize
195
+ """
196
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
197
+ return out
198
+
199
+
200
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
201
+ H, W = img.shape[0], img.shape[1]
202
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
203
+ out = cv2.warpAffine(
204
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
205
+ ).astype(np.uint8)
206
+ return out
207
+
208
+
209
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
210
+ replace = np.array(replace, dtype=np.uint8)
211
+ H, W = img.shape[0], img.shape[1]
212
+ rh, rw = np.random.random(2)
213
+ pad_size = pad_size // 2
214
+ ch, cw = int(rh * H), int(rw * W)
215
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
216
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
217
+ out = img.copy()
218
+ out[x1:x2, y1:y2, :] = replace
219
+ return out
220
+
221
+
222
+ ### level to args
223
+ def enhance_level_to_args(MAX_LEVEL):
224
+ def level_to_args(level):
225
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
226
+
227
+ return level_to_args
228
+
229
+
230
+ def shear_level_to_args(MAX_LEVEL, replace_value):
231
+ def level_to_args(level):
232
+ level = (level / MAX_LEVEL) * 0.3
233
+ if np.random.random() > 0.5:
234
+ level = -level
235
+ return (level, replace_value)
236
+
237
+ return level_to_args
238
+
239
+
240
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
241
+ def level_to_args(level):
242
+ level = (level / MAX_LEVEL) * float(translate_const)
243
+ if np.random.random() > 0.5:
244
+ level = -level
245
+ return (level, replace_value)
246
+
247
+ return level_to_args
248
+
249
+
250
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
251
+ def level_to_args(level):
252
+ level = int((level / MAX_LEVEL) * cutout_const)
253
+ return (level, replace_value)
254
+
255
+ return level_to_args
256
+
257
+
258
+ def solarize_level_to_args(MAX_LEVEL):
259
+ def level_to_args(level):
260
+ level = int((level / MAX_LEVEL) * 256)
261
+ return (level,)
262
+
263
+ return level_to_args
264
+
265
+
266
+ def none_level_to_args(level):
267
+ return ()
268
+
269
+
270
+ def posterize_level_to_args(MAX_LEVEL):
271
+ def level_to_args(level):
272
+ level = int((level / MAX_LEVEL) * 4)
273
+ return (level,)
274
+
275
+ return level_to_args
276
+
277
+
278
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
279
+ def level_to_args(level):
280
+ level = (level / MAX_LEVEL) * 30
281
+ if np.random.random() < 0.5:
282
+ level = -level
283
+ return (level, replace_value)
284
+
285
+ return level_to_args
286
+
287
+
288
+ func_dict = {
289
+ "Identity": identity_func,
290
+ "AutoContrast": autocontrast_func,
291
+ "Equalize": equalize_func,
292
+ "Rotate": rotate_func,
293
+ "Solarize": solarize_func,
294
+ "Color": color_func,
295
+ "Contrast": contrast_func,
296
+ "Brightness": brightness_func,
297
+ "Sharpness": sharpness_func,
298
+ "ShearX": shear_x_func,
299
+ "TranslateX": translate_x_func,
300
+ "TranslateY": translate_y_func,
301
+ "Posterize": posterize_func,
302
+ "ShearY": shear_y_func,
303
+ }
304
+
305
+ translate_const = 10
306
+ MAX_LEVEL = 10
307
+ replace_value = (128, 128, 128)
308
+ arg_dict = {
309
+ "Identity": none_level_to_args,
310
+ "AutoContrast": none_level_to_args,
311
+ "Equalize": none_level_to_args,
312
+ "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
313
+ "Solarize": solarize_level_to_args(MAX_LEVEL),
314
+ "Color": enhance_level_to_args(MAX_LEVEL),
315
+ "Contrast": enhance_level_to_args(MAX_LEVEL),
316
+ "Brightness": enhance_level_to_args(MAX_LEVEL),
317
+ "Sharpness": enhance_level_to_args(MAX_LEVEL),
318
+ "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
319
+ "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
320
+ "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
321
+ "Posterize": posterize_level_to_args(MAX_LEVEL),
322
+ "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
323
+ }
324
+
325
+
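Note: a quick illustration of the level-to-args machinery above, using the MAX_LEVEL=10 and replace_value=(128, 128, 128) constants defined in this file; img here is a made-up uint8 input:

    import numpy as np
    img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
    arg_dict["Rotate"](10)      # -> (30.0, (128, 128, 128)), negated with probability 0.5
    arg_dict["Brightness"](10)  # -> (1.9,), i.e. (10 / 10) * 1.8 + 0.1
    out = func_dict["Rotate"](img, *arg_dict["Rotate"](10))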
326
+ class RandomAugment(object):
327
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
328
+ self.N = N
329
+ self.M = M
330
+ self.isPIL = isPIL
331
+ if augs:
332
+ self.augs = augs
333
+ else:
334
+ self.augs = list(arg_dict.keys())
335
+
336
+ def get_random_ops(self):
337
+ sampled_ops = np.random.choice(self.augs, self.N)
338
+ return [(op, 0.5, self.M) for op in sampled_ops]
339
+
340
+ def __call__(self, img):
341
+ if self.isPIL:
342
+ img = np.array(img)
343
+ ops = self.get_random_ops()
344
+ for name, prob, level in ops:
345
+ if np.random.random() > prob:
346
+ continue
347
+ args = arg_dict[name](level)
348
+ img = func_dict[name](img, *args)
349
+ return img
350
+
351
+
352
+ class VideoRandomAugment(object):
353
+ def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
354
+ self.N = N
355
+ self.M = M
356
+ self.p = p
357
+ self.tensor_in_tensor_out = tensor_in_tensor_out
358
+ if augs:
359
+ self.augs = augs
360
+ else:
361
+ self.augs = list(arg_dict.keys())
362
+
363
+ def get_random_ops(self):
364
+ sampled_ops = np.random.choice(self.augs, self.N, replace=False)
365
+ return [(op, self.M) for op in sampled_ops]
366
+
367
+ def __call__(self, frames):
368
+ assert (
369
+ frames.shape[-1] == 3
370
+ ), "Expecting the last dimension to be 3-channel RGB (b, h, w, c)."
371
+
372
+ if self.tensor_in_tensor_out:
373
+ frames = frames.numpy().astype(np.uint8)
374
+
375
+ num_frames = frames.shape[0]
376
+
377
+ ops = num_frames * [self.get_random_ops()]
378
+ apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
379
+
380
+ frames = torch.stack(
381
+ list(map(self._aug, frames, ops, apply_or_not)), dim=0
382
+ ).float()
383
+
384
+ return frames
385
+
386
+ def _aug(self, img, ops, apply_or_not):
387
+ for i, (name, level) in enumerate(ops):
388
+ if not apply_or_not[i]:
389
+ continue
390
+ args = arg_dict[name](level)
391
+ img = func_dict[name](img, *args)
392
+ return torch.from_numpy(img)
393
+
394
+
395
+ if __name__ == "__main__":
396
+ a = RandomAugment()
397
+ img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)  # aug funcs index uint8 lookup tables
398
+ a(img)
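Note: a minimal sketch of how RandomAugment is typically wired into an image-processor pipeline (the transform chain, crop size, and 7-op list below follow common BLIP-style usage and are assumptions, not taken from this diff):

    from torchvision import transforms
    from bliva.processors.randaugment import RandomAugment

    # isPIL=True makes RandomAugment convert the incoming PIL image to a numpy array
    aug = RandomAugment(
        2, 5, isPIL=True,
        augs=["Identity", "AutoContrast", "Brightness", "Sharpness",
              "Equalize", "ShearX", "ShearY"],
    )
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),  # illustrative size
        aug,                                # returns a uint8 numpy array
        transforms.ToTensor(),
    ])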
bliva_vicuna7b.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:512b0b5d6d98391e570ff6ee7778d7f39b267650e43098f8e6cc60539d56d037
3
+ size 15812679883
evaluate.py ADDED
@@ -0,0 +1,93 @@
1
+ import argparse
2
+ import numpy as np
3
+
4
+ from PIL import Image
5
+ from bliva.models import load_model_and_preprocess
6
+
7
+ def disable_torch_init():
8
+ """
9
+ Disable the redundant torch default initialization to accelerate model creation.
10
+ """
11
+ import torch
12
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
13
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
14
+
15
+ def parse_args():
16
+ """
17
+ Parse arguments from command line.
18
+ """
19
+ parser = argparse.ArgumentParser(description="Arguments for Evaluation")
20
+ parser.add_argument(
21
+ "--answer_mc",
22
+ action="store_true",
23
+ default=False,
24
+ help="Whether to evaluate a multiple choice question with candidates."
25
+ )
26
+ parser.add_argument(
27
+ "--answer_qs",
28
+ action="store_true",
29
+ default=False,
30
+ help="Whether to evaluate a single open-ended question about the image."
31
+ )
32
+
33
+ parser.add_argument("--model_name", type=str, default="bliva_vicuna")
34
+ parser.add_argument("--device", type=str, default="cuda:0", help="Specify which gpu device to use.")
35
+ parser.add_argument("--img_path", type=str, required=True, help="the path to the image")
36
+ parser.add_argument("--question", type=str, required=True, help="the question to ask")
37
+ parser.add_argument("--candidates", type=str, help="comma-separated list of choices for the multiple choice question")
38
+
39
+ args = parser.parse_args()
40
+ return args
41
+
42
+ def eval_one(image, question, model):
43
+ """
44
+ Evaluate one question
45
+ """
46
+ outputs = model.generate({"image": image, "prompt": question})
47
+ print("=====================================")
48
+ print("Question:", question[0])
49
+ print("-------------------------------------")
50
+ print("Outputs: ", outputs[0])
51
+
52
+
53
+ def eval_candidates(image, question, candidates, model):
54
+ """
55
+ Evaluate with candidates
56
+ """
57
+ outputs = model.predict_class({"image": image, "prompt": question}, candidates)
58
+ print("=====================================")
59
+ print("Question:", question[0])
60
+ print("-------------------------------------")
61
+ print("Candidates:", candidates)
62
+ print("-------------------------------------")
63
+ print("Outputs: ", candidates[outputs[0][0]])
64
+
65
+
66
+
67
+ def main(args):
68
+ np.random.seed(0)
69
+
70
+ disable_torch_init()
71
+
72
+ if args.model_name == "bliva_vicuna":
73
+ model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="vicuna7b", is_eval=True, device=args.device)
74
+ elif args.model_name == "bliva_flant5":
75
+ model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="flant5xxl", is_eval=True, device=args.device)
76
+ vis_processor = vis_processors["eval"]
77
+
78
+ image = Image.open(args.img_path).convert('RGB')
79
+
80
+ question = [args.question]
81
+
82
+ image = vis_processor(image).unsqueeze(0).to(args.device)
83
+
84
+ if args.answer_qs:
85
+ eval_one(image, question, model)
86
+ elif args.answer_mc:
87
+ candidates = [candidate.strip() for candidate in args.candidates.split(",")]
88
+ eval_candidates(image, question, candidates, model)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ args = parse_args()
93
+ main(args)
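Note: hypothetical invocations of this script (image path, question, and candidates are illustrative only; --candidates is split on commas, as shown above):

    # single open-ended question
    python evaluate.py --answer_qs --model_name bliva_vicuna \
        --img_path images/example.jpg --question "what is unusual about this image?"

    # multiple-choice question
    python evaluate.py --answer_mc --model_name bliva_vicuna \
        --img_path images/example.jpg --question "what animal is shown?" \
        --candidates "cat, dog, bird"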
hf_vicuna_7b/config.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "_name_or_path": "../data/hf_llama_7b/",
3
+ "architectures": [
4
+ "LlamaForCausalLM"
5
+ ],
6
+ "bos_token_id": 1,
7
+ "eos_token_id": 2,
8
+ "hidden_act": "silu",
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 11008,
12
+ "max_position_embeddings": 2048,
13
+ "model_type": "llama",
14
+ "num_attention_heads": 32,
15
+ "num_hidden_layers": 32,
16
+ "pad_token_id": 0,
17
+ "rms_norm_eps": 1e-06,
18
+ "tie_word_embeddings": false,
19
+ "torch_dtype": "float16",
20
+ "transformers_version": "4.28.1",
21
+ "use_cache": true,
22
+ "vocab_size": 32000
23
+ }
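Note: a minimal sketch of loading this checkpoint directory with Hugging Face transformers (assuming the hf_vicuna_7b/ folder in this repo; float16 follows the "torch_dtype" field recorded in the config above):

    import torch
    from transformers import LlamaForCausalLM

    model = LlamaForCausalLM.from_pretrained("hf_vicuna_7b", torch_dtype=torch.float16)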
hf_vicuna_7b/generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.28.1"
7
+ }
hf_vicuna_7b/pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ed572be140240b212049a3e791271eed8a04c40bc732c91a4da4b1469db23b1
3
+ size 9976634558
hf_vicuna_7b/pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9382da358f0ec38c4fca3bcf1e3e65274ae4c78090e0775b4bb4dea6a518e08
3
+ size 3500315539
hf_vicuna_7b/pytorch_model.bin.index.json ADDED
@@ -0,0 +1,330 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 13476839424
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "pytorch_model-00002-of-00002.bin",
7
+ "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
8
+ "model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
9
+ "model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
10
+ "model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
11
+ "model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
12
+ "model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
13
+ "model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
14
+ "model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
15
+ "model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
16
+ "model.layers.0.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
17
+ "model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
18
+ "model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
19
+ "model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
20
+ "model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
21
+ "model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
22
+ "model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
23
+ "model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
24
+ "model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
25
+ "model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
26
+ "model.layers.1.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
27
+ "model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
28
+ "model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
29
+ "model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
30
+ "model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
31
+ "model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
32
+ "model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
33
+ "model.layers.10.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
34
+ "model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
35
+ "model.layers.10.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
36
+ "model.layers.10.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
37
+ "model.layers.10.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
38
+ "model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
39
+ "model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
40
+ "model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
41
+ "model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
42
+ "model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
43
+ "model.layers.11.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
44
+ "model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
45
+ "model.layers.11.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
46
+ "model.layers.11.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
47
+ "model.layers.11.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
48
+ "model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
49
+ "model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
50
+ "model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
51
+ "model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
52
+ "model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
53
+ "model.layers.12.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
54
+ "model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
55
+ "model.layers.12.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
56
+ "model.layers.12.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
57
+ "model.layers.12.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
58
+ "model.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
59
+ "model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
60
+ "model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
61
+ "model.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
62
+ "model.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
63
+ "model.layers.13.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
64
+ "model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
65
+ "model.layers.13.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
66
+ "model.layers.13.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
67
+ "model.layers.13.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
68
+ "model.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
69
+ "model.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
70
+ "model.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
71
+ "model.layers.14.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
72
+ "model.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
73
+ "model.layers.14.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
74
+ "model.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
75
+ "model.layers.14.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
76
+ "model.layers.14.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
77
+ "model.layers.14.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
78
+ "model.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
79
+ "model.layers.15.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
80
+ "model.layers.15.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
81
+ "model.layers.15.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
82
+ "model.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
83
+ "model.layers.15.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
84
+ "model.layers.15.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
85
+ "model.layers.15.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
86
+ "model.layers.15.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
87
+ "model.layers.15.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
88
+ "model.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
89
+ "model.layers.16.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
90
+ "model.layers.16.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
91
+ "model.layers.16.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
92
+ "model.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
93
+ "model.layers.16.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
94
+ "model.layers.16.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
95
+ "model.layers.16.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
96
+ "model.layers.16.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
97
+ "model.layers.16.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
98
+ "model.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
99
+ "model.layers.17.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
100
+ "model.layers.17.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
101
+ "model.layers.17.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
102
+ "model.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
103
+ "model.layers.17.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
104
+ "model.layers.17.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
105
+ "model.layers.17.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
106
+ "model.layers.17.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
107
+ "model.layers.17.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
108
+ "model.layers.18.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
109
+ "model.layers.18.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
110
+ "model.layers.18.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
111
+ "model.layers.18.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
112
+ "model.layers.18.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
113
+ "model.layers.18.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
114
+ "model.layers.18.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
115
+ "model.layers.18.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
116
+ "model.layers.18.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
117
+ "model.layers.18.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
118
+ "model.layers.19.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
119
+ "model.layers.19.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
120
+ "model.layers.19.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
121
+ "model.layers.19.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
122
+ "model.layers.19.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
123
+ "model.layers.19.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
124
+ "model.layers.19.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
125
+ "model.layers.19.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
126
+ "model.layers.19.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
127
+ "model.layers.19.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
128
+ "model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
129
+ "model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
130
+ "model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
131
+ "model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
132
+ "model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
133
+ "model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
134
+ "model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
135
+ "model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
136
+ "model.layers.2.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
137
+ "model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
138
+ "model.layers.20.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
139
+ "model.layers.20.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
140
+ "model.layers.20.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
141
+ "model.layers.20.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
142
+ "model.layers.20.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
143
+ "model.layers.20.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
144
+ "model.layers.20.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
145
+ "model.layers.20.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
146
+ "model.layers.20.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
147
+ "model.layers.20.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
148
+ "model.layers.21.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
149
+ "model.layers.21.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
150
+ "model.layers.21.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
151
+ "model.layers.21.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
152
+ "model.layers.21.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
153
+ "model.layers.21.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
154
+ "model.layers.21.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
155
+ "model.layers.21.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
156
+ "model.layers.21.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
157
+ "model.layers.21.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
158
+ "model.layers.22.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
159
+ "model.layers.22.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
160
+ "model.layers.22.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
161
+ "model.layers.22.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
162
+ "model.layers.22.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
163
+ "model.layers.22.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
164
+ "model.layers.22.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
165
+ "model.layers.22.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
166
+ "model.layers.22.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
167
+ "model.layers.22.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
168
+ "model.layers.23.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
169
+ "model.layers.23.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
170
+ "model.layers.23.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
171
+ "model.layers.23.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
172
+ "model.layers.23.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
173
+ "model.layers.23.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
174
+ "model.layers.23.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
175
+ "model.layers.23.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
176
+ "model.layers.23.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
177
+ "model.layers.23.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
178
+ "model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
179
+ "model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
180
+ "model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
181
+ "model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
182
+ "model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
183
+ "model.layers.24.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
184
+ "model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
185
+ "model.layers.24.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
186
+ "model.layers.24.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
187
+ "model.layers.24.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
188
+ "model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
189
+ "model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
190
+ "model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
191
+ "model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
192
+ "model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
193
+ "model.layers.25.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
194
+ "model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
195
+ "model.layers.25.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
196
+ "model.layers.25.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
197
+ "model.layers.25.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
198
+ "model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
199
+ "model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
200
+ "model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
201
+ "model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
202
+ "model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
203
+ "model.layers.26.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
204
+ "model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
205
+ "model.layers.26.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
206
+ "model.layers.26.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
207
+ "model.layers.26.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
208
+ "model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
209
+ "model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
210
+ "model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
211
+ "model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
212
+ "model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
213
+ "model.layers.27.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
214
+ "model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
215
+ "model.layers.27.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
216
+ "model.layers.27.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
217
+ "model.layers.27.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
218
+ "model.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
219
+ "model.layers.28.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
220
+ "model.layers.28.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
221
+ "model.layers.28.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
222
+ "model.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
223
+ "model.layers.28.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
224
+ "model.layers.28.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
225
+ "model.layers.28.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
226
+ "model.layers.28.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
227
+ "model.layers.28.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
228
+ "model.layers.29.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
229
+ "model.layers.29.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
230
+ "model.layers.29.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
231
+ "model.layers.29.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
232
+ "model.layers.29.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
233
+ "model.layers.29.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
234
+ "model.layers.29.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
235
+ "model.layers.29.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
236
+ "model.layers.29.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
237
+ "model.layers.29.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
238
+ "model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
239
+ "model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
240
+ "model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
241
+ "model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
242
+ "model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
243
+ "model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
244
+ "model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
245
+ "model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
246
+ "model.layers.3.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
247
+ "model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
248
+ "model.layers.30.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
249
+ "model.layers.30.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
250
+ "model.layers.30.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
251
+ "model.layers.30.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
252
+ "model.layers.30.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
253
+ "model.layers.30.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
254
+ "model.layers.30.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
255
+ "model.layers.30.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
256
+ "model.layers.30.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
257
+ "model.layers.30.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
258
+ "model.layers.31.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
259
+ "model.layers.31.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
260
+ "model.layers.31.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
261
+ "model.layers.31.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
262
+ "model.layers.31.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
263
+ "model.layers.31.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
264
+ "model.layers.31.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
265
+ "model.layers.31.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
266
+ "model.layers.31.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
267
+ "model.layers.31.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
268
+ "model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
269
+ "model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
270
+ "model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
271
+ "model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
272
+ "model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
273
+ "model.layers.4.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
274
+ "model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
275
+ "model.layers.4.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
276
+ "model.layers.4.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
277
+ "model.layers.4.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
278
+ "model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
279
+ "model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
280
+ "model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
281
+ "model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
282
+ "model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
283
+ "model.layers.5.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
284
+ "model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
285
+ "model.layers.5.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
286
+ "model.layers.5.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
287
+ "model.layers.5.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
288
+ "model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
289
+ "model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
290
+ "model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
291
+ "model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
292
+ "model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
293
+ "model.layers.6.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
294
+ "model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
295
+ "model.layers.6.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
296
+ "model.layers.6.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
297
+ "model.layers.6.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
298
+ "model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
299
+ "model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
300
+ "model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
301
+ "model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
302
+ "model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
303
+ "model.layers.7.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
304
+ "model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
305
+ "model.layers.7.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
306
+ "model.layers.7.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
307
+ "model.layers.7.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
308
+ "model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
309
+ "model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
310
+ "model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
311
+ "model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
312
+ "model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
313
+ "model.layers.8.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
314
+ "model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
315
+ "model.layers.8.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
316
+ "model.layers.8.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
317
+ "model.layers.8.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
318
+ "model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
319
+ "model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
320
+ "model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
321
+ "model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
322
+ "model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
323
+ "model.layers.9.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
324
+ "model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
325
+ "model.layers.9.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
326
+ "model.layers.9.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
327
+ "model.layers.9.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
328
+ "model.norm.weight": "pytorch_model-00002-of-00002.bin"
329
+ }
330
+ }
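Note: this index file tells transformers which shard holds each tensor (layers 0-23 and the embeddings sit in shard 1; layers 24-31, the final norm, and lm_head sit in shard 2). A quick way to inspect it directly:

    import json

    with open("hf_vicuna_7b/pytorch_model.bin.index.json") as f:
        index = json.load(f)
    print(index["metadata"]["total_size"])  # 13476839424 bytes
    print(index["weight_map"]["model.layers.31.mlp.up_proj.weight"])
    # -> pytorch_model-00002-of-00002.bin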
hf_vicuna_7b/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
hf_vicuna_7b/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
hf_vicuna_7b/tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "</s>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "model_max_length": 1000000000000000019884624838656,
22
+ "pad_token": null,
23
+ "sp_model_kwargs": {},
24
+ "tokenizer_class": "LlamaTokenizer",
25
+ "unk_token": {
26
+ "__type": "AddedToken",
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": true,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
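Note: a small sketch of what the tokenizer settings above imply at load time (assuming this hf_vicuna_7b/ directory; with add_bos_token=true and add_eos_token=false, encoded sequences start with <s> (id 1) and get no trailing </s>):

    from transformers import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained("hf_vicuna_7b")
    ids = tokenizer("hello").input_ids
    print(ids[0])  # -> 1, the <s> token prepended by add_bos_token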
images/example.jpg ADDED
images/img1.jpg ADDED
images/img2.jpg ADDED
images/img3.jpg ADDED
images/img4.jpg ADDED
images/img5.jpg ADDED
images/img6.jpg ADDED