leadmaister Vision-CAIR commited on
Commit
7a1fc84
0 Parent(s):

Duplicate from Vision-CAIR/minigpt4

Browse files

Co-authored-by: visioncairgroup <[email protected]>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +35 -0
  2. CODEOWNERS +2 -0
  3. LICENSE.txt +14 -0
  4. MANIFEST.in +7 -0
  5. README.md +14 -0
  6. app.py +154 -0
  7. blip2_pretrained_flant5xxl.pth +3 -0
  8. checkpoint.pth +3 -0
  9. eval_configs/minigpt4.yaml +30 -0
  10. minigpt4/__init__.py +31 -0
  11. minigpt4/common/__init__.py +0 -0
  12. minigpt4/common/config.py +468 -0
  13. minigpt4/common/dist_utils.py +137 -0
  14. minigpt4/common/gradcam.py +24 -0
  15. minigpt4/common/logger.py +195 -0
  16. minigpt4/common/optims.py +119 -0
  17. minigpt4/common/registry.py +329 -0
  18. minigpt4/common/utils.py +424 -0
  19. minigpt4/configs/datasets/cc_combine/align.yaml +16 -0
  20. minigpt4/configs/datasets/cc_combine/defaults.yaml +11 -0
  21. minigpt4/configs/datasets/laion/defaults.yaml +13 -0
  22. minigpt4/configs/default.yaml +10 -0
  23. minigpt4/configs/models/minigpt4.yaml +39 -0
  24. minigpt4/conversation/__init__.py +0 -0
  25. minigpt4/conversation/conversation.py +199 -0
  26. minigpt4/datasets/__init__.py +0 -0
  27. minigpt4/datasets/builders/__init__.py +72 -0
  28. minigpt4/datasets/builders/base_dataset_builder.py +235 -0
  29. minigpt4/datasets/builders/image_text_pair_builder.py +86 -0
  30. minigpt4/datasets/data_utils.py +196 -0
  31. minigpt4/datasets/datasets/__init__.py +0 -0
  32. minigpt4/datasets/datasets/base_dataset.py +68 -0
  33. minigpt4/datasets/datasets/caption_datasets.py +85 -0
  34. minigpt4/datasets/datasets/cc_combine_dataset.py +53 -0
  35. minigpt4/datasets/datasets/dataloader_utils.py +162 -0
  36. minigpt4/datasets/datasets/laion_dataset.py +31 -0
  37. minigpt4/models/Qformer.py +1216 -0
  38. minigpt4/models/__init__.py +200 -0
  39. minigpt4/models/base_model.py +247 -0
  40. minigpt4/models/blip2.py +221 -0
  41. minigpt4/models/blip2_outputs.py +110 -0
  42. minigpt4/models/eva_vit.py +442 -0
  43. minigpt4/models/mini_gpt4.py +263 -0
  44. minigpt4/models/modeling_llama.py +772 -0
  45. minigpt4/processors/__init__.py +33 -0
  46. minigpt4/processors/base_processor.py +26 -0
  47. minigpt4/processors/blip_processors.py +141 -0
  48. minigpt4/processors/randaugment.py +398 -0
  49. minigpt4/runners/__init__.py +10 -0
  50. minigpt4/runners/runner_base.py +658 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ MiniGPT_4.pdf filter=lfs diff=lfs merge=lfs -text
CODEOWNERS ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
2
+ #ECCN:Open Source
LICENSE.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2022 Salesforce, Inc.
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9
+
10
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11
+
12
+ 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MANIFEST.in ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ recursive-include minigpt4/configs *.yaml *.json
2
+ recursive-include minigpt4/projects *.yaml *.json
3
+
4
+ recursive-exclude minigpt4/datasets/download_scripts *
5
+ recursive-exclude minigpt4/output *
6
+
7
+ include requirements.txt
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MiniGPT-4
3
+ emoji: 🚀
4
+ colorFrom: purple
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.27
8
+ app_file: app.py
9
+ pinned: false
10
+ license: other
11
+ duplicated_from: Vision-CAIR/minigpt4
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.backends.cudnn as cudnn
8
+ import gradio as gr
9
+
10
+ from minigpt4.common.config import Config
11
+ from minigpt4.common.dist_utils import get_rank
12
+ from minigpt4.common.registry import registry
13
+ from minigpt4.conversation.conversation import Chat, CONV_VISION
14
+
15
+ # imports modules for registration
16
+ from minigpt4.datasets.builders import *
17
+ from minigpt4.models import *
18
+ from minigpt4.processors import *
19
+ from minigpt4.runners import *
20
+ from minigpt4.tasks import *
21
+
22
+ def parse_args():
23
+ parser = argparse.ArgumentParser(description="Demo")
24
+ parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.")
25
+ parser.add_argument(
26
+ "--options",
27
+ nargs="+",
28
+ help="override some settings in the used config, the key-value pair "
29
+ "in xxx=yyy format will be merged into config file (deprecate), "
30
+ "change to --cfg-options instead.",
31
+ )
32
+ args = parser.parse_args()
33
+ return args
34
+
35
+
36
+ def setup_seeds(config):
37
+ seed = config.run_cfg.seed + get_rank()
38
+
39
+ random.seed(seed)
40
+ np.random.seed(seed)
41
+ torch.manual_seed(seed)
42
+
43
+ cudnn.benchmark = False
44
+ cudnn.deterministic = True
45
+
46
+ # ========================================
47
+ # Model Initialization
48
+ # ========================================
49
+
50
+ SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
51
+
52
+ You can duplicate and use it with a paid private GPU.
53
+
54
+ <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
55
+
56
+ Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
57
+ '''
58
+
59
+ print('Initializing Chat')
60
+ cfg = Config(parse_args())
61
+
62
+ model_config = cfg.model_cfg
63
+ model_cls = registry.get_model_class(model_config.arch)
64
+ model = model_cls.from_config(model_config).to('cuda:0')
65
+
66
+ vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
67
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
68
+ chat = Chat(model, vis_processor)
69
+ print('Initialization Finished')
70
+
71
+ # ========================================
72
+ # Gradio Setting
73
+ # ========================================
74
+
75
+ def gradio_reset(chat_state, img_list):
76
+ if chat_state is not None:
77
+ chat_state.messages = []
78
+ if img_list is not None:
79
+ img_list = []
80
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
81
+
82
+ def upload_img(gr_img, text_input, chat_state):
83
+ if gr_img is None:
84
+ return None, None, gr.update(interactive=True), chat_state, None
85
+ chat_state = CONV_VISION.copy()
86
+ img_list = []
87
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
88
+ return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
89
+
90
+ def gradio_ask(user_message, chatbot, chat_state):
91
+ if len(user_message) == 0:
92
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
93
+ chat.ask(user_message, chat_state)
94
+ chatbot = chatbot + [[user_message, None]]
95
+ return '', chatbot, chat_state
96
+
97
+
98
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
99
+ llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature, max_length=2000)[0]
100
+ chatbot[-1][1] = llm_message
101
+ return chatbot, chat_state, img_list
102
+
103
+ title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
104
+ description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
105
+ article = """<div style='display:flex; gap: 0.25rem; '><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
106
+ """
107
+
108
+ #TODO show examples below
109
+
110
+ with gr.Blocks() as demo:
111
+ gr.Markdown(title)
112
+ gr.Markdown(SHARED_UI_WARNING)
113
+ gr.Markdown(description)
114
+ gr.Markdown(article)
115
+
116
+ with gr.Row():
117
+ with gr.Column(scale=0.5):
118
+ image = gr.Image(type="pil")
119
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
120
+ clear = gr.Button("Restart")
121
+
122
+ num_beams = gr.Slider(
123
+ minimum=1,
124
+ maximum=5,
125
+ value=1,
126
+ step=1,
127
+ interactive=True,
128
+ label="beam search numbers)",
129
+ )
130
+
131
+ temperature = gr.Slider(
132
+ minimum=0.1,
133
+ maximum=2.0,
134
+ value=1.0,
135
+ step=0.1,
136
+ interactive=True,
137
+ label="Temperature",
138
+ )
139
+
140
+
141
+ with gr.Column():
142
+ chat_state = gr.State()
143
+ img_list = gr.State()
144
+ chatbot = gr.Chatbot(label='MiniGPT-4')
145
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
146
+
147
+ upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
148
+
149
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
150
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
151
+ )
152
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
153
+
154
+ demo.launch(enable_queue=True)
blip2_pretrained_flant5xxl.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b3839ea6c617f315ead9bf4036bbb0f0cf6bf62695ecfc14968ea626af03a29
3
+ size 433481467
checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e7b8a3c21f146654c21a1a29a577dab2c3bd1aa3b1bc902f39e86954357a811
3
+ size 47369169
eval_configs/minigpt4.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ model:
7
+ arch: mini_gpt4
8
+ model_type: pretrain_vicuna
9
+ freeze_vit: True
10
+ freeze_qformer: True
11
+ max_txt_len: 160
12
+ end_sym: "###"
13
+ prompt_path: "prompts/alignment.txt"
14
+ prompt_template: '###Human: {} ###Assistant: '
15
+ ckpt: 'checkpoint.pth'
16
+
17
+
18
+ datasets:
19
+ cc_align:
20
+ vis_processor:
21
+ train:
22
+ name: "blip2_image_eval"
23
+ image_size: 224
24
+ text_processor:
25
+ train:
26
+ name: "blip_caption"
27
+
28
+ run:
29
+ task: image_text_pretrain
30
+
minigpt4/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ import sys
10
+
11
+ from omegaconf import OmegaConf
12
+
13
+ from minigpt4.common.registry import registry
14
+
15
+ from minigpt4.datasets.builders import *
16
+ from minigpt4.models import *
17
+ from minigpt4.processors import *
18
+ from minigpt4.tasks import *
19
+
20
+
21
+ root_dir = os.path.dirname(os.path.abspath(__file__))
22
+ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23
+
24
+ registry.register_path("library_root", root_dir)
25
+ repo_root = os.path.join(root_dir, "..")
26
+ registry.register_path("repo_root", repo_root)
27
+ cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28
+ registry.register_path("cache_root", cache_root)
29
+
30
+ registry.register("MAX_INT", sys.maxsize)
31
+ registry.register("SPLIT_NAMES", ["train", "val", "test"])
minigpt4/common/__init__.py ADDED
File without changes
minigpt4/common/config.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from typing import Dict
11
+
12
+ from omegaconf import OmegaConf
13
+ from minigpt4.common.registry import registry
14
+
15
+
16
+ class Config:
17
+ def __init__(self, args):
18
+ self.config = {}
19
+
20
+ self.args = args
21
+
22
+ # Register the config and configuration for setup
23
+ registry.register("configuration", self)
24
+
25
+ user_config = self._build_opt_list(self.args.options)
26
+
27
+ config = OmegaConf.load(self.args.cfg_path)
28
+
29
+ runner_config = self.build_runner_config(config)
30
+ model_config = self.build_model_config(config, **user_config)
31
+ dataset_config = self.build_dataset_config(config)
32
+
33
+ # Validate the user-provided runner configuration
34
+ # model and dataset configuration are supposed to be validated by the respective classes
35
+ # [TODO] validate the model/dataset configuration
36
+ # self._validate_runner_config(runner_config)
37
+
38
+ # Override the default configuration with user options.
39
+ self.config = OmegaConf.merge(
40
+ runner_config, model_config, dataset_config, user_config
41
+ )
42
+
43
+ def _validate_runner_config(self, runner_config):
44
+ """
45
+ This method validates the configuration, such that
46
+ 1) all the user specified options are valid;
47
+ 2) no type mismatches between the user specified options and the config.
48
+ """
49
+ runner_config_validator = create_runner_config_validator()
50
+ runner_config_validator.validate(runner_config)
51
+
52
+ def _build_opt_list(self, opts):
53
+ opts_dot_list = self._convert_to_dot_list(opts)
54
+ return OmegaConf.from_dotlist(opts_dot_list)
55
+
56
+ @staticmethod
57
+ def build_model_config(config, **kwargs):
58
+ model = config.get("model", None)
59
+ assert model is not None, "Missing model configuration file."
60
+
61
+ model_cls = registry.get_model_class(model.arch)
62
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
63
+
64
+ model_type = kwargs.get("model.model_type", None)
65
+ if not model_type:
66
+ model_type = model.get("model_type", None)
67
+ # else use the model type selected by user.
68
+
69
+ assert model_type is not None, "Missing model_type."
70
+
71
+ model_config_path = model_cls.default_config_path(model_type=model_type)
72
+
73
+ model_config = OmegaConf.create()
74
+ # hiararchy override, customized config > default config
75
+ model_config = OmegaConf.merge(
76
+ model_config,
77
+ OmegaConf.load(model_config_path),
78
+ {"model": config["model"]},
79
+ )
80
+
81
+ return model_config
82
+
83
+ @staticmethod
84
+ def build_runner_config(config):
85
+ return {"run": config.run}
86
+
87
+ @staticmethod
88
+ def build_dataset_config(config):
89
+ datasets = config.get("datasets", None)
90
+ if datasets is None:
91
+ raise KeyError(
92
+ "Expecting 'datasets' as the root key for dataset configuration."
93
+ )
94
+
95
+ dataset_config = OmegaConf.create()
96
+
97
+ for dataset_name in datasets:
98
+ builder_cls = registry.get_builder_class(dataset_name)
99
+
100
+ dataset_config_type = datasets[dataset_name].get("type", "default")
101
+ dataset_config_path = builder_cls.default_config_path(
102
+ type=dataset_config_type
103
+ )
104
+
105
+ # hiararchy override, customized config > default config
106
+ dataset_config = OmegaConf.merge(
107
+ dataset_config,
108
+ OmegaConf.load(dataset_config_path),
109
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
110
+ )
111
+
112
+ return dataset_config
113
+
114
+ def _convert_to_dot_list(self, opts):
115
+ if opts is None:
116
+ opts = []
117
+
118
+ if len(opts) == 0:
119
+ return opts
120
+
121
+ has_equal = opts[0].find("=") != -1
122
+
123
+ if has_equal:
124
+ return opts
125
+
126
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
127
+
128
+ def get_config(self):
129
+ return self.config
130
+
131
+ @property
132
+ def run_cfg(self):
133
+ return self.config.run
134
+
135
+ @property
136
+ def datasets_cfg(self):
137
+ return self.config.datasets
138
+
139
+ @property
140
+ def model_cfg(self):
141
+ return self.config.model
142
+
143
+ def pretty_print(self):
144
+ logging.info("\n===== Running Parameters =====")
145
+ logging.info(self._convert_node_to_json(self.config.run))
146
+
147
+ logging.info("\n====== Dataset Attributes ======")
148
+ datasets = self.config.datasets
149
+
150
+ for dataset in datasets:
151
+ if dataset in self.config.datasets:
152
+ logging.info(f"\n======== {dataset} =======")
153
+ dataset_config = self.config.datasets[dataset]
154
+ logging.info(self._convert_node_to_json(dataset_config))
155
+ else:
156
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
157
+
158
+ logging.info(f"\n====== Model Attributes ======")
159
+ logging.info(self._convert_node_to_json(self.config.model))
160
+
161
+ def _convert_node_to_json(self, node):
162
+ container = OmegaConf.to_container(node, resolve=True)
163
+ return json.dumps(container, indent=4, sort_keys=True)
164
+
165
+ def to_dict(self):
166
+ return OmegaConf.to_container(self.config)
167
+
168
+
169
+ def node_to_dict(node):
170
+ return OmegaConf.to_container(node)
171
+
172
+
173
+ class ConfigValidator:
174
+ """
175
+ This is a preliminary implementation to centralize and validate the configuration.
176
+ May be altered in the future.
177
+
178
+ A helper class to validate configurations from yaml file.
179
+
180
+ This serves the following purposes:
181
+ 1. Ensure all the options in the yaml are defined, raise error if not.
182
+ 2. when type mismatches are found, the validator will raise an error.
183
+ 3. a central place to store and display helpful messages for supported configurations.
184
+
185
+ """
186
+
187
+ class _Argument:
188
+ def __init__(self, name, choices=None, type=None, help=None):
189
+ self.name = name
190
+ self.val = None
191
+ self.choices = choices
192
+ self.type = type
193
+ self.help = help
194
+
195
+ def __str__(self):
196
+ s = f"{self.name}={self.val}"
197
+ if self.type is not None:
198
+ s += f", ({self.type})"
199
+ if self.choices is not None:
200
+ s += f", choices: {self.choices}"
201
+ if self.help is not None:
202
+ s += f", ({self.help})"
203
+ return s
204
+
205
+ def __init__(self, description):
206
+ self.description = description
207
+
208
+ self.arguments = dict()
209
+
210
+ self.parsed_args = None
211
+
212
+ def __getitem__(self, key):
213
+ assert self.parsed_args is not None, "No arguments parsed yet."
214
+
215
+ return self.parsed_args[key]
216
+
217
+ def __str__(self) -> str:
218
+ return self.format_help()
219
+
220
+ def add_argument(self, *args, **kwargs):
221
+ """
222
+ Assume the first argument is the name of the argument.
223
+ """
224
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
225
+
226
+ def validate(self, config=None):
227
+ """
228
+ Convert yaml config (dict-like) to list, required by argparse.
229
+ """
230
+ for k, v in config.items():
231
+ assert (
232
+ k in self.arguments
233
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
234
+
235
+ if self.arguments[k].type is not None:
236
+ try:
237
+ self.arguments[k].val = self.arguments[k].type(v)
238
+ except ValueError:
239
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
240
+
241
+ if self.arguments[k].choices is not None:
242
+ assert (
243
+ v in self.arguments[k].choices
244
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
245
+
246
+ return config
247
+
248
+ def format_arguments(self):
249
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
250
+
251
+ def format_help(self):
252
+ # description + key-value pair string for each argument
253
+ help_msg = str(self.description)
254
+ return help_msg + ", available arguments: " + self.format_arguments()
255
+
256
+ def print_help(self):
257
+ # display help message
258
+ print(self.format_help())
259
+
260
+
261
+ def create_runner_config_validator():
262
+ validator = ConfigValidator(description="Runner configurations")
263
+
264
+ validator.add_argument(
265
+ "runner",
266
+ type=str,
267
+ choices=["runner_base", "runner_iter"],
268
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
269
+ runner runs based on iters. Default: runner_base""",
270
+ )
271
+ # add argumetns for training dataset ratios
272
+ validator.add_argument(
273
+ "train_dataset_ratios",
274
+ type=Dict[str, float],
275
+ help="""Ratios of training dataset. This is used in iteration-based runner.
276
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
277
+ Default: None""",
278
+ )
279
+ validator.add_argument(
280
+ "max_iters",
281
+ type=float,
282
+ help="Maximum number of iterations to run.",
283
+ )
284
+ validator.add_argument(
285
+ "max_epoch",
286
+ type=int,
287
+ help="Maximum number of epochs to run.",
288
+ )
289
+ # add arguments for iters_per_inner_epoch
290
+ validator.add_argument(
291
+ "iters_per_inner_epoch",
292
+ type=float,
293
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
294
+ )
295
+ lr_scheds_choices = registry.list_lr_schedulers()
296
+ validator.add_argument(
297
+ "lr_sched",
298
+ type=str,
299
+ choices=lr_scheds_choices,
300
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
301
+ )
302
+ task_choices = registry.list_tasks()
303
+ validator.add_argument(
304
+ "task",
305
+ type=str,
306
+ choices=task_choices,
307
+ help="Task to use, from {}".format(task_choices),
308
+ )
309
+ # add arguments for init_lr
310
+ validator.add_argument(
311
+ "init_lr",
312
+ type=float,
313
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
314
+ )
315
+ # add arguments for min_lr
316
+ validator.add_argument(
317
+ "min_lr",
318
+ type=float,
319
+ help="Minimum learning rate (after decay).",
320
+ )
321
+ # add arguments for warmup_lr
322
+ validator.add_argument(
323
+ "warmup_lr",
324
+ type=float,
325
+ help="Starting learning rate for warmup.",
326
+ )
327
+ # add arguments for learning rate decay rate
328
+ validator.add_argument(
329
+ "lr_decay_rate",
330
+ type=float,
331
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
332
+ )
333
+ # add arguments for weight decay
334
+ validator.add_argument(
335
+ "weight_decay",
336
+ type=float,
337
+ help="Weight decay rate.",
338
+ )
339
+ # add arguments for training batch size
340
+ validator.add_argument(
341
+ "batch_size_train",
342
+ type=int,
343
+ help="Training batch size.",
344
+ )
345
+ # add arguments for evaluation batch size
346
+ validator.add_argument(
347
+ "batch_size_eval",
348
+ type=int,
349
+ help="Evaluation batch size, including validation and testing.",
350
+ )
351
+ # add arguments for number of workers for data loading
352
+ validator.add_argument(
353
+ "num_workers",
354
+ help="Number of workers for data loading.",
355
+ )
356
+ # add arguments for warm up steps
357
+ validator.add_argument(
358
+ "warmup_steps",
359
+ type=int,
360
+ help="Number of warmup steps. Required if a warmup schedule is used.",
361
+ )
362
+ # add arguments for random seed
363
+ validator.add_argument(
364
+ "seed",
365
+ type=int,
366
+ help="Random seed.",
367
+ )
368
+ # add arguments for output directory
369
+ validator.add_argument(
370
+ "output_dir",
371
+ type=str,
372
+ help="Output directory to save checkpoints and logs.",
373
+ )
374
+ # add arguments for whether only use evaluation
375
+ validator.add_argument(
376
+ "evaluate",
377
+ help="Whether to only evaluate the model. If true, training will not be performed.",
378
+ )
379
+ # add arguments for splits used for training, e.g. ["train", "val"]
380
+ validator.add_argument(
381
+ "train_splits",
382
+ type=list,
383
+ help="Splits to use for training.",
384
+ )
385
+ # add arguments for splits used for validation, e.g. ["val"]
386
+ validator.add_argument(
387
+ "valid_splits",
388
+ type=list,
389
+ help="Splits to use for validation. If not provided, will skip the validation.",
390
+ )
391
+ # add arguments for splits used for testing, e.g. ["test"]
392
+ validator.add_argument(
393
+ "test_splits",
394
+ type=list,
395
+ help="Splits to use for testing. If not provided, will skip the testing.",
396
+ )
397
+ # add arguments for accumulating gradient for iterations
398
+ validator.add_argument(
399
+ "accum_grad_iters",
400
+ type=int,
401
+ help="Number of iterations to accumulate gradient for.",
402
+ )
403
+
404
+ # ====== distributed training ======
405
+ validator.add_argument(
406
+ "device",
407
+ type=str,
408
+ choices=["cpu", "cuda"],
409
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
410
+ )
411
+ validator.add_argument(
412
+ "world_size",
413
+ type=int,
414
+ help="Number of processes participating in the job.",
415
+ )
416
+ validator.add_argument("dist_url", type=str)
417
+ validator.add_argument("distributed", type=bool)
418
+ # add arguments to opt using distributed sampler during evaluation or not
419
+ validator.add_argument(
420
+ "use_dist_eval_sampler",
421
+ type=bool,
422
+ help="Whether to use distributed sampler during evaluation or not.",
423
+ )
424
+
425
+ # ====== task specific ======
426
+ # generation task specific arguments
427
+ # add arguments for maximal length of text output
428
+ validator.add_argument(
429
+ "max_len",
430
+ type=int,
431
+ help="Maximal length of text output.",
432
+ )
433
+ # add arguments for minimal length of text output
434
+ validator.add_argument(
435
+ "min_len",
436
+ type=int,
437
+ help="Minimal length of text output.",
438
+ )
439
+ # add arguments number of beams
440
+ validator.add_argument(
441
+ "num_beams",
442
+ type=int,
443
+ help="Number of beams used for beam search.",
444
+ )
445
+
446
+ # vqa task specific arguments
447
+ # add arguments for number of answer candidates
448
+ validator.add_argument(
449
+ "num_ans_candidates",
450
+ type=int,
451
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
452
+ )
453
+ # add arguments for inference method
454
+ validator.add_argument(
455
+ "inference_method",
456
+ type=str,
457
+ choices=["genearte", "rank"],
458
+ help="""Inference method to use for question answering. If rank, requires a answer list.""",
459
+ )
460
+
461
+ # ====== model specific ======
462
+ validator.add_argument(
463
+ "k_test",
464
+ type=int,
465
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
466
+ )
467
+
468
+ return validator
minigpt4/common/dist_utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import functools
10
+ import os
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
15
+
16
+
17
+ def setup_for_distributed(is_master):
18
+ """
19
+ This function disables printing when not in master process
20
+ """
21
+ import builtins as __builtin__
22
+
23
+ builtin_print = __builtin__.print
24
+
25
+ def print(*args, **kwargs):
26
+ force = kwargs.pop("force", False)
27
+ if is_master or force:
28
+ builtin_print(*args, **kwargs)
29
+
30
+ __builtin__.print = print
31
+
32
+
33
+ def is_dist_avail_and_initialized():
34
+ if not dist.is_available():
35
+ return False
36
+ if not dist.is_initialized():
37
+ return False
38
+ return True
39
+
40
+
41
+ def get_world_size():
42
+ if not is_dist_avail_and_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank():
48
+ if not is_dist_avail_and_initialized():
49
+ return 0
50
+ return dist.get_rank()
51
+
52
+
53
+ def is_main_process():
54
+ return get_rank() == 0
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59
+ args.rank = int(os.environ["RANK"])
60
+ args.world_size = int(os.environ["WORLD_SIZE"])
61
+ args.gpu = int(os.environ["LOCAL_RANK"])
62
+ elif "SLURM_PROCID" in os.environ:
63
+ args.rank = int(os.environ["SLURM_PROCID"])
64
+ args.gpu = args.rank % torch.cuda.device_count()
65
+ else:
66
+ print("Not using distributed mode")
67
+ args.distributed = False
68
+ return
69
+
70
+ args.distributed = True
71
+
72
+ torch.cuda.set_device(args.gpu)
73
+ args.dist_backend = "nccl"
74
+ print(
75
+ "| distributed init (rank {}, world {}): {}".format(
76
+ args.rank, args.world_size, args.dist_url
77
+ ),
78
+ flush=True,
79
+ )
80
+ torch.distributed.init_process_group(
81
+ backend=args.dist_backend,
82
+ init_method=args.dist_url,
83
+ world_size=args.world_size,
84
+ rank=args.rank,
85
+ timeout=datetime.timedelta(
86
+ days=365
87
+ ), # allow auto-downloading and de-compressing
88
+ )
89
+ torch.distributed.barrier()
90
+ setup_for_distributed(args.rank == 0)
91
+
92
+
93
+ def get_dist_info():
94
+ if torch.__version__ < "1.0":
95
+ initialized = dist._initialized
96
+ else:
97
+ initialized = dist.is_initialized()
98
+ if initialized:
99
+ rank = dist.get_rank()
100
+ world_size = dist.get_world_size()
101
+ else: # non-distributed training
102
+ rank = 0
103
+ world_size = 1
104
+ return rank, world_size
105
+
106
+
107
+ def main_process(func):
108
+ @functools.wraps(func)
109
+ def wrapper(*args, **kwargs):
110
+ rank, _ = get_dist_info()
111
+ if rank == 0:
112
+ return func(*args, **kwargs)
113
+
114
+ return wrapper
115
+
116
+
117
+ def download_cached_file(url, check_hash=True, progress=False):
118
+ """
119
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
120
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
121
+ """
122
+
123
+ def get_cached_file_path():
124
+ # a hack to sync the file path across processes
125
+ parts = torch.hub.urlparse(url)
126
+ filename = os.path.basename(parts.path)
127
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128
+
129
+ return cached_file
130
+
131
+ if is_main_process():
132
+ timm_hub.download_cached_file(url, check_hash, progress)
133
+
134
+ if is_dist_avail_and_initialized():
135
+ dist.barrier()
136
+
137
+ return get_cached_file_path()
minigpt4/common/gradcam.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from matplotlib import pyplot as plt
3
+ from scipy.ndimage import filters
4
+ from skimage import transform as skimage_transform
5
+
6
+
7
+ def getAttMap(img, attMap, blur=True, overlap=True):
8
+ attMap -= attMap.min()
9
+ if attMap.max() > 0:
10
+ attMap /= attMap.max()
11
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12
+ if blur:
13
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14
+ attMap -= attMap.min()
15
+ attMap /= attMap.max()
16
+ cmap = plt.get_cmap("jet")
17
+ attMapV = cmap(attMap)
18
+ attMapV = np.delete(attMapV, 3, 2)
19
+ if overlap:
20
+ attMap = (
21
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23
+ )
24
+ return attMap
minigpt4/common/logger.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import time
11
+ from collections import defaultdict, deque
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from minigpt4.common import dist_utils
17
+
18
+
19
+ class SmoothedValue(object):
20
+ """Track a series of values and provide access to smoothed values over a
21
+ window or the global series average.
22
+ """
23
+
24
+ def __init__(self, window_size=20, fmt=None):
25
+ if fmt is None:
26
+ fmt = "{median:.4f} ({global_avg:.4f})"
27
+ self.deque = deque(maxlen=window_size)
28
+ self.total = 0.0
29
+ self.count = 0
30
+ self.fmt = fmt
31
+
32
+ def update(self, value, n=1):
33
+ self.deque.append(value)
34
+ self.count += n
35
+ self.total += value * n
36
+
37
+ def synchronize_between_processes(self):
38
+ """
39
+ Warning: does not synchronize the deque!
40
+ """
41
+ if not dist_utils.is_dist_avail_and_initialized():
42
+ return
43
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44
+ dist.barrier()
45
+ dist.all_reduce(t)
46
+ t = t.tolist()
47
+ self.count = int(t[0])
48
+ self.total = t[1]
49
+
50
+ @property
51
+ def median(self):
52
+ d = torch.tensor(list(self.deque))
53
+ return d.median().item()
54
+
55
+ @property
56
+ def avg(self):
57
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
58
+ return d.mean().item()
59
+
60
+ @property
61
+ def global_avg(self):
62
+ return self.total / self.count
63
+
64
+ @property
65
+ def max(self):
66
+ return max(self.deque)
67
+
68
+ @property
69
+ def value(self):
70
+ return self.deque[-1]
71
+
72
+ def __str__(self):
73
+ return self.fmt.format(
74
+ median=self.median,
75
+ avg=self.avg,
76
+ global_avg=self.global_avg,
77
+ max=self.max,
78
+ value=self.value,
79
+ )
80
+
81
+
82
+ class MetricLogger(object):
83
+ def __init__(self, delimiter="\t"):
84
+ self.meters = defaultdict(SmoothedValue)
85
+ self.delimiter = delimiter
86
+
87
+ def update(self, **kwargs):
88
+ for k, v in kwargs.items():
89
+ if isinstance(v, torch.Tensor):
90
+ v = v.item()
91
+ assert isinstance(v, (float, int))
92
+ self.meters[k].update(v)
93
+
94
+ def __getattr__(self, attr):
95
+ if attr in self.meters:
96
+ return self.meters[attr]
97
+ if attr in self.__dict__:
98
+ return self.__dict__[attr]
99
+ raise AttributeError(
100
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101
+ )
102
+
103
+ def __str__(self):
104
+ loss_str = []
105
+ for name, meter in self.meters.items():
106
+ loss_str.append("{}: {}".format(name, str(meter)))
107
+ return self.delimiter.join(loss_str)
108
+
109
+ def global_avg(self):
110
+ loss_str = []
111
+ for name, meter in self.meters.items():
112
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113
+ return self.delimiter.join(loss_str)
114
+
115
+ def synchronize_between_processes(self):
116
+ for meter in self.meters.values():
117
+ meter.synchronize_between_processes()
118
+
119
+ def add_meter(self, name, meter):
120
+ self.meters[name] = meter
121
+
122
+ def log_every(self, iterable, print_freq, header=None):
123
+ i = 0
124
+ if not header:
125
+ header = ""
126
+ start_time = time.time()
127
+ end = time.time()
128
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
129
+ data_time = SmoothedValue(fmt="{avg:.4f}")
130
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131
+ log_msg = [
132
+ header,
133
+ "[{0" + space_fmt + "}/{1}]",
134
+ "eta: {eta}",
135
+ "{meters}",
136
+ "time: {time}",
137
+ "data: {data}",
138
+ ]
139
+ if torch.cuda.is_available():
140
+ log_msg.append("max mem: {memory:.0f}")
141
+ log_msg = self.delimiter.join(log_msg)
142
+ MB = 1024.0 * 1024.0
143
+ for obj in iterable:
144
+ data_time.update(time.time() - end)
145
+ yield obj
146
+ iter_time.update(time.time() - end)
147
+ if i % print_freq == 0 or i == len(iterable) - 1:
148
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
149
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150
+ if torch.cuda.is_available():
151
+ print(
152
+ log_msg.format(
153
+ i,
154
+ len(iterable),
155
+ eta=eta_string,
156
+ meters=str(self),
157
+ time=str(iter_time),
158
+ data=str(data_time),
159
+ memory=torch.cuda.max_memory_allocated() / MB,
160
+ )
161
+ )
162
+ else:
163
+ print(
164
+ log_msg.format(
165
+ i,
166
+ len(iterable),
167
+ eta=eta_string,
168
+ meters=str(self),
169
+ time=str(iter_time),
170
+ data=str(data_time),
171
+ )
172
+ )
173
+ i += 1
174
+ end = time.time()
175
+ total_time = time.time() - start_time
176
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177
+ print(
178
+ "{} Total time: {} ({:.4f} s / it)".format(
179
+ header, total_time_str, total_time / len(iterable)
180
+ )
181
+ )
182
+
183
+
184
+ class AttrDict(dict):
185
+ def __init__(self, *args, **kwargs):
186
+ super(AttrDict, self).__init__(*args, **kwargs)
187
+ self.__dict__ = self
188
+
189
+
190
+ def setup_logger():
191
+ logging.basicConfig(
192
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193
+ format="%(asctime)s [%(levelname)s] %(message)s",
194
+ handlers=[logging.StreamHandler()],
195
+ )
minigpt4/common/optims.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ from minigpt4.common.registry import registry
11
+
12
+
13
+ @registry.register_lr_scheduler("linear_warmup_step_lr")
14
+ class LinearWarmupStepLRScheduler:
15
+ def __init__(
16
+ self,
17
+ optimizer,
18
+ max_epoch,
19
+ min_lr,
20
+ init_lr,
21
+ decay_rate=1,
22
+ warmup_start_lr=-1,
23
+ warmup_steps=0,
24
+ **kwargs
25
+ ):
26
+ self.optimizer = optimizer
27
+
28
+ self.max_epoch = max_epoch
29
+ self.min_lr = min_lr
30
+
31
+ self.decay_rate = decay_rate
32
+
33
+ self.init_lr = init_lr
34
+ self.warmup_steps = warmup_steps
35
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36
+
37
+ def step(self, cur_epoch, cur_step):
38
+ if cur_epoch == 0:
39
+ warmup_lr_schedule(
40
+ step=cur_step,
41
+ optimizer=self.optimizer,
42
+ max_step=self.warmup_steps,
43
+ init_lr=self.warmup_start_lr,
44
+ max_lr=self.init_lr,
45
+ )
46
+ else:
47
+ step_lr_schedule(
48
+ epoch=cur_epoch,
49
+ optimizer=self.optimizer,
50
+ init_lr=self.init_lr,
51
+ min_lr=self.min_lr,
52
+ decay_rate=self.decay_rate,
53
+ )
54
+
55
+
56
+ @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57
+ class LinearWarmupCosineLRScheduler:
58
+ def __init__(
59
+ self,
60
+ optimizer,
61
+ max_epoch,
62
+ iters_per_epoch,
63
+ min_lr,
64
+ init_lr,
65
+ warmup_steps=0,
66
+ warmup_start_lr=-1,
67
+ **kwargs
68
+ ):
69
+ self.optimizer = optimizer
70
+
71
+ self.max_epoch = max_epoch
72
+ self.iters_per_epoch = iters_per_epoch
73
+ self.min_lr = min_lr
74
+
75
+ self.init_lr = init_lr
76
+ self.warmup_steps = warmup_steps
77
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78
+
79
+ def step(self, cur_epoch, cur_step):
80
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81
+ if total_cur_step < self.warmup_steps:
82
+ warmup_lr_schedule(
83
+ step=cur_step,
84
+ optimizer=self.optimizer,
85
+ max_step=self.warmup_steps,
86
+ init_lr=self.warmup_start_lr,
87
+ max_lr=self.init_lr,
88
+ )
89
+ else:
90
+ cosine_lr_schedule(
91
+ epoch=total_cur_step,
92
+ optimizer=self.optimizer,
93
+ max_epoch=self.max_epoch * self.iters_per_epoch,
94
+ init_lr=self.init_lr,
95
+ min_lr=self.min_lr,
96
+ )
97
+
98
+
99
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100
+ """Decay the learning rate"""
101
+ lr = (init_lr - min_lr) * 0.5 * (
102
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
103
+ ) + min_lr
104
+ for param_group in optimizer.param_groups:
105
+ param_group["lr"] = lr
106
+
107
+
108
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109
+ """Warmup the learning rate"""
110
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111
+ for param_group in optimizer.param_groups:
112
+ param_group["lr"] = lr
113
+
114
+
115
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116
+ """Decay the learning rate"""
117
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
118
+ for param_group in optimizer.param_groups:
119
+ param_group["lr"] = lr
minigpt4/common/registry.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+
9
+ class Registry:
10
+ mapping = {
11
+ "builder_name_mapping": {},
12
+ "task_name_mapping": {},
13
+ "processor_name_mapping": {},
14
+ "model_name_mapping": {},
15
+ "lr_scheduler_name_mapping": {},
16
+ "runner_name_mapping": {},
17
+ "state": {},
18
+ "paths": {},
19
+ }
20
+
21
+ @classmethod
22
+ def register_builder(cls, name):
23
+ r"""Register a dataset builder to registry with key 'name'
24
+
25
+ Args:
26
+ name: Key with which the builder will be registered.
27
+
28
+ Usage:
29
+
30
+ from minigpt4.common.registry import registry
31
+ from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
32
+ """
33
+
34
+ def wrap(builder_cls):
35
+ from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36
+
37
+ assert issubclass(
38
+ builder_cls, BaseDatasetBuilder
39
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40
+ builder_cls
41
+ )
42
+ if name in cls.mapping["builder_name_mapping"]:
43
+ raise KeyError(
44
+ "Name '{}' already registered for {}.".format(
45
+ name, cls.mapping["builder_name_mapping"][name]
46
+ )
47
+ )
48
+ cls.mapping["builder_name_mapping"][name] = builder_cls
49
+ return builder_cls
50
+
51
+ return wrap
52
+
53
+ @classmethod
54
+ def register_task(cls, name):
55
+ r"""Register a task to registry with key 'name'
56
+
57
+ Args:
58
+ name: Key with which the task will be registered.
59
+
60
+ Usage:
61
+
62
+ from minigpt4.common.registry import registry
63
+ """
64
+
65
+ def wrap(task_cls):
66
+ from minigpt4.tasks.base_task import BaseTask
67
+
68
+ assert issubclass(
69
+ task_cls, BaseTask
70
+ ), "All tasks must inherit BaseTask class"
71
+ if name in cls.mapping["task_name_mapping"]:
72
+ raise KeyError(
73
+ "Name '{}' already registered for {}.".format(
74
+ name, cls.mapping["task_name_mapping"][name]
75
+ )
76
+ )
77
+ cls.mapping["task_name_mapping"][name] = task_cls
78
+ return task_cls
79
+
80
+ return wrap
81
+
82
+ @classmethod
83
+ def register_model(cls, name):
84
+ r"""Register a task to registry with key 'name'
85
+
86
+ Args:
87
+ name: Key with which the task will be registered.
88
+
89
+ Usage:
90
+
91
+ from minigpt4.common.registry import registry
92
+ """
93
+
94
+ def wrap(model_cls):
95
+ from minigpt4.models import BaseModel
96
+
97
+ assert issubclass(
98
+ model_cls, BaseModel
99
+ ), "All models must inherit BaseModel class"
100
+ if name in cls.mapping["model_name_mapping"]:
101
+ raise KeyError(
102
+ "Name '{}' already registered for {}.".format(
103
+ name, cls.mapping["model_name_mapping"][name]
104
+ )
105
+ )
106
+ cls.mapping["model_name_mapping"][name] = model_cls
107
+ return model_cls
108
+
109
+ return wrap
110
+
111
+ @classmethod
112
+ def register_processor(cls, name):
113
+ r"""Register a processor to registry with key 'name'
114
+
115
+ Args:
116
+ name: Key with which the task will be registered.
117
+
118
+ Usage:
119
+
120
+ from minigpt4.common.registry import registry
121
+ """
122
+
123
+ def wrap(processor_cls):
124
+ from minigpt4.processors import BaseProcessor
125
+
126
+ assert issubclass(
127
+ processor_cls, BaseProcessor
128
+ ), "All processors must inherit BaseProcessor class"
129
+ if name in cls.mapping["processor_name_mapping"]:
130
+ raise KeyError(
131
+ "Name '{}' already registered for {}.".format(
132
+ name, cls.mapping["processor_name_mapping"][name]
133
+ )
134
+ )
135
+ cls.mapping["processor_name_mapping"][name] = processor_cls
136
+ return processor_cls
137
+
138
+ return wrap
139
+
140
+ @classmethod
141
+ def register_lr_scheduler(cls, name):
142
+ r"""Register a model to registry with key 'name'
143
+
144
+ Args:
145
+ name: Key with which the task will be registered.
146
+
147
+ Usage:
148
+
149
+ from minigpt4.common.registry import registry
150
+ """
151
+
152
+ def wrap(lr_sched_cls):
153
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
154
+ raise KeyError(
155
+ "Name '{}' already registered for {}.".format(
156
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
157
+ )
158
+ )
159
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
160
+ return lr_sched_cls
161
+
162
+ return wrap
163
+
164
+ @classmethod
165
+ def register_runner(cls, name):
166
+ r"""Register a model to registry with key 'name'
167
+
168
+ Args:
169
+ name: Key with which the task will be registered.
170
+
171
+ Usage:
172
+
173
+ from minigpt4.common.registry import registry
174
+ """
175
+
176
+ def wrap(runner_cls):
177
+ if name in cls.mapping["runner_name_mapping"]:
178
+ raise KeyError(
179
+ "Name '{}' already registered for {}.".format(
180
+ name, cls.mapping["runner_name_mapping"][name]
181
+ )
182
+ )
183
+ cls.mapping["runner_name_mapping"][name] = runner_cls
184
+ return runner_cls
185
+
186
+ return wrap
187
+
188
+ @classmethod
189
+ def register_path(cls, name, path):
190
+ r"""Register a path to registry with key 'name'
191
+
192
+ Args:
193
+ name: Key with which the path will be registered.
194
+
195
+ Usage:
196
+
197
+ from minigpt4.common.registry import registry
198
+ """
199
+ assert isinstance(path, str), "All path must be str."
200
+ if name in cls.mapping["paths"]:
201
+ raise KeyError("Name '{}' already registered.".format(name))
202
+ cls.mapping["paths"][name] = path
203
+
204
+ @classmethod
205
+ def register(cls, name, obj):
206
+ r"""Register an item to registry with key 'name'
207
+
208
+ Args:
209
+ name: Key with which the item will be registered.
210
+
211
+ Usage::
212
+
213
+ from minigpt4.common.registry import registry
214
+
215
+ registry.register("config", {})
216
+ """
217
+ path = name.split(".")
218
+ current = cls.mapping["state"]
219
+
220
+ for part in path[:-1]:
221
+ if part not in current:
222
+ current[part] = {}
223
+ current = current[part]
224
+
225
+ current[path[-1]] = obj
226
+
227
+ # @classmethod
228
+ # def get_trainer_class(cls, name):
229
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
230
+
231
+ @classmethod
232
+ def get_builder_class(cls, name):
233
+ return cls.mapping["builder_name_mapping"].get(name, None)
234
+
235
+ @classmethod
236
+ def get_model_class(cls, name):
237
+ return cls.mapping["model_name_mapping"].get(name, None)
238
+
239
+ @classmethod
240
+ def get_task_class(cls, name):
241
+ return cls.mapping["task_name_mapping"].get(name, None)
242
+
243
+ @classmethod
244
+ def get_processor_class(cls, name):
245
+ return cls.mapping["processor_name_mapping"].get(name, None)
246
+
247
+ @classmethod
248
+ def get_lr_scheduler_class(cls, name):
249
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
250
+
251
+ @classmethod
252
+ def get_runner_class(cls, name):
253
+ return cls.mapping["runner_name_mapping"].get(name, None)
254
+
255
+ @classmethod
256
+ def list_runners(cls):
257
+ return sorted(cls.mapping["runner_name_mapping"].keys())
258
+
259
+ @classmethod
260
+ def list_models(cls):
261
+ return sorted(cls.mapping["model_name_mapping"].keys())
262
+
263
+ @classmethod
264
+ def list_tasks(cls):
265
+ return sorted(cls.mapping["task_name_mapping"].keys())
266
+
267
+ @classmethod
268
+ def list_processors(cls):
269
+ return sorted(cls.mapping["processor_name_mapping"].keys())
270
+
271
+ @classmethod
272
+ def list_lr_schedulers(cls):
273
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
274
+
275
+ @classmethod
276
+ def list_datasets(cls):
277
+ return sorted(cls.mapping["builder_name_mapping"].keys())
278
+
279
+ @classmethod
280
+ def get_path(cls, name):
281
+ return cls.mapping["paths"].get(name, None)
282
+
283
+ @classmethod
284
+ def get(cls, name, default=None, no_warning=False):
285
+ r"""Get an item from registry with key 'name'
286
+
287
+ Args:
288
+ name (string): Key whose value needs to be retrieved.
289
+ default: If passed and key is not in registry, default value will
290
+ be returned with a warning. Default: None
291
+ no_warning (bool): If passed as True, warning when key doesn't exist
292
+ will not be generated. Useful for MMF's
293
+ internal operations. Default: False
294
+ """
295
+ original_name = name
296
+ name = name.split(".")
297
+ value = cls.mapping["state"]
298
+ for subname in name:
299
+ value = value.get(subname, default)
300
+ if value is default:
301
+ break
302
+
303
+ if (
304
+ "writer" in cls.mapping["state"]
305
+ and value == default
306
+ and no_warning is False
307
+ ):
308
+ cls.mapping["state"]["writer"].warning(
309
+ "Key {} is not present in registry, returning default value "
310
+ "of {}".format(original_name, default)
311
+ )
312
+ return value
313
+
314
+ @classmethod
315
+ def unregister(cls, name):
316
+ r"""Remove an item from registry with key 'name'
317
+
318
+ Args:
319
+ name: Key which needs to be removed.
320
+ Usage::
321
+
322
+ from mmf.common.registry import registry
323
+
324
+ config = registry.unregister("config")
325
+ """
326
+ return cls.mapping["state"].pop(name, None)
327
+
328
+
329
+ registry = Registry()
minigpt4/common/utils.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from minigpt4.common.registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading an URL from google drive requires confirmation when
112
+ the file of the size is too big (google drive notifies that
113
+ anti-viral checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), ""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create the symlinks for a given file1 to file2.
264
+ Useful during model checkpointing to symlinks to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json
280
+ Specifically for .json, users have the option to either append (default)
281
+ or rewrite by passing in Boolean value to append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .pkl, .pickle, .npy, .json
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If the mmap_mode of reading is not successful, we load data without the
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a url. look for http(s):// and ignoring the case
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, get the size of file in MB
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
minigpt4/configs/datasets/cc_combine/align.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ cc_align:
8
+ data_type: images
9
+ build_info:
10
+ # Be careful not to append minus sign (-) before split to avoid itemizing
11
+ annotations:
12
+ train:
13
+ url: placeholder
14
+ storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/filter_cap.json
15
+ images:
16
+ storage: /ibex/project/c2133/blip_dataset/image_alignment_cc/
minigpt4/configs/datasets/cc_combine/defaults.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ cc_combine:
8
+ data_type: images
9
+ build_info:
10
+ # Be careful not to append minus sign (-) before split to avoid itemizing
11
+ storage: /ibex/project/c2133/blip_dataset/cc3m/cc3m_cc12m_sbu/{00000..01255}.tar
minigpt4/configs/datasets/laion/defaults.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ laion:
8
+
9
+ data_type: images
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ storage: /ibex/project/c2133/blip_dataset/laion_1b/laion_gpu/{00000..10488}.tar
minigpt4/configs/default.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ env:
7
+ # For default users
8
+ # cache_root: "cache"
9
+ # For internal use with persistent storage
10
+ cache_root: "/export/home/.cache/minigpt4"
minigpt4/configs/models/minigpt4.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ model:
7
+ arch: mini_gpt4
8
+
9
+ # vit encoder
10
+ image_size: 224
11
+ drop_path_rate: 0
12
+ use_grad_checkpoint: False
13
+ vit_precision: "fp16"
14
+ freeze_vit: True
15
+ freeze_qformer: True
16
+
17
+ # Q-Former
18
+ num_query_token: 32
19
+
20
+ # Vicuna
21
+ llama_model: "vicuna"
22
+
23
+ # generation configs
24
+ prompt: ""
25
+
26
+
27
+ preprocess:
28
+ vis_processor:
29
+ train:
30
+ name: "blip2_image_train"
31
+ image_size: 224
32
+ eval:
33
+ name: "blip2_image_eval"
34
+ image_size: 224
35
+ text_processor:
36
+ train:
37
+ name: "blip_caption"
38
+ eval:
39
+ name: "blip_caption"
minigpt4/conversation/__init__.py ADDED
File without changes
minigpt4/conversation/conversation.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import time
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
7
+ from transformers import StoppingCriteria, StoppingCriteriaList
8
+
9
+ import dataclasses
10
+ from enum import auto, Enum
11
+ from typing import List, Tuple, Any
12
+
13
+ from minigpt4.common.registry import registry
14
+
15
+
16
+ class SeparatorStyle(Enum):
17
+ """Different separator style."""
18
+ SINGLE = auto()
19
+ TWO = auto()
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class Conversation:
24
+ """A class that keeps all conversation history."""
25
+ system: str
26
+ roles: List[str]
27
+ messages: List[List[str]]
28
+ offset: int
29
+ # system_img: List[Image.Image] = []
30
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
31
+ sep: str = "###"
32
+ sep2: str = None
33
+
34
+ skip_next: bool = False
35
+ conv_id: Any = None
36
+
37
+ def get_prompt(self):
38
+ if self.sep_style == SeparatorStyle.SINGLE:
39
+ ret = self.system + self.sep
40
+ for role, message in self.messages:
41
+ if message:
42
+ ret += role + ": " + message + self.sep
43
+ else:
44
+ ret += role + ":"
45
+ return ret
46
+ elif self.sep_style == SeparatorStyle.TWO:
47
+ seps = [self.sep, self.sep2]
48
+ ret = self.system + seps[0]
49
+ for i, (role, message) in enumerate(self.messages):
50
+ if message:
51
+ ret += role + ": " + message + seps[i % 2]
52
+ else:
53
+ ret += role + ":"
54
+ return ret
55
+ else:
56
+ raise ValueError(f"Invalid style: {self.sep_style}")
57
+
58
+ def append_message(self, role, message):
59
+ self.messages.append([role, message])
60
+
61
+ def to_gradio_chatbot(self):
62
+ ret = []
63
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
64
+ if i % 2 == 0:
65
+ ret.append([msg, None])
66
+ else:
67
+ ret[-1][-1] = msg
68
+ return ret
69
+
70
+ def copy(self):
71
+ return Conversation(
72
+ system=self.system,
73
+ # system_img=self.system_img,
74
+ roles=self.roles,
75
+ messages=[[x, y] for x, y in self.messages],
76
+ offset=self.offset,
77
+ sep_style=self.sep_style,
78
+ sep=self.sep,
79
+ sep2=self.sep2,
80
+ conv_id=self.conv_id)
81
+
82
+ def dict(self):
83
+ return {
84
+ "system": self.system,
85
+ # "system_img": self.system_img,
86
+ "roles": self.roles,
87
+ "messages": self.messages,
88
+ "offset": self.offset,
89
+ "sep": self.sep,
90
+ "sep2": self.sep2,
91
+ "conv_id": self.conv_id,
92
+ }
93
+
94
+
95
+ class StoppingCriteriaSub(StoppingCriteria):
96
+
97
+ def __init__(self, stops=[], encounters=1):
98
+ super().__init__()
99
+ self.stops = stops
100
+
101
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
102
+ for stop in self.stops:
103
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
104
+ return True
105
+
106
+ return False
107
+
108
+
109
+ CONV_VISION = Conversation(
110
+ system="Give the following image: <Img>ImageContent</Img>. "
111
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
112
+ roles=("Human", "Assistant"),
113
+ messages=[],
114
+ offset=2,
115
+ sep_style=SeparatorStyle.SINGLE,
116
+ sep="###",
117
+ )
118
+
119
+
120
+
121
+ class Chat:
122
+ def __init__(self, model, vis_processor, device='cuda:0'):
123
+ self.device = device
124
+ self.model = model
125
+ self.vis_processor = vis_processor
126
+ stop_words_ids = [torch.tensor([835]).to(self.device),
127
+ torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
128
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
129
+
130
+ def ask(self, text, conv):
131
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
132
+ and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
133
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
134
+ else:
135
+ conv.append_message(conv.roles[0], text)
136
+
137
+ def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
138
+ repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
139
+ conv.append_message(conv.roles[1], None)
140
+ embs = self.get_context_emb(conv, img_list)
141
+
142
+ # current_max_len = embs.shape[1] + max_new_tokens + 100
143
+ # begin_idx = max(0, current_max_len - max_length)
144
+ # embs = embs[:, begin_idx:]
145
+ outputs = self.model.llama_model.generate(
146
+ inputs_embeds=embs,
147
+ max_new_tokens=max_new_tokens,
148
+ stopping_criteria=self.stopping_criteria,
149
+ num_beams=num_beams,
150
+ min_length=min_length,
151
+ top_p=top_p,
152
+ repetition_penalty=repetition_penalty,
153
+ length_penalty=length_penalty,
154
+ temperature=temperature,
155
+ )
156
+ output_token = outputs[0]
157
+ if output_token[0] == 0:
158
+ output_token = output_token[1:]
159
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
160
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
161
+ output_text = output_text.split('Assistant:')[-1].strip()
162
+ conv.messages[-1][1] = output_text
163
+ return output_text, output_token.cpu().numpy()
164
+
165
+ def upload_img(self, image, conv, img_list):
166
+ if isinstance(image, str): # is a image path
167
+ raw_image = Image.open(image).convert('RGB')
168
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
169
+ elif isinstance(image, Image.Image):
170
+ raw_image = image
171
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
172
+ elif isinstance(image, torch.Tensor):
173
+ if len(image.shape) == 3:
174
+ image = image.unsqueeze(0)
175
+ image = image.to(self.device)
176
+
177
+ image_emb, _ = self.model.encode_img(image)
178
+ img_list.append(image_emb)
179
+ conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
180
+ msg = "Received."
181
+ # self.conv.append_message(self.conv.roles[1], msg)
182
+ return msg
183
+
184
+ def get_context_emb(self, conv, img_list):
185
+ prompt = conv.get_prompt()
186
+ prompt_segs = prompt.split('<ImageHere>')
187
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
188
+ seg_tokens = [
189
+ self.model.llama_tokenizer(
190
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
191
+ # only add bos to the first seg
192
+ for i, seg in enumerate(prompt_segs)
193
+ ]
194
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
195
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
196
+ mixed_embs = torch.cat(mixed_embs, dim=1)
197
+ return mixed_embs
198
+
199
+
minigpt4/datasets/__init__.py ADDED
File without changes
minigpt4/datasets/builders/__init__.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
9
+ from minigpt4.datasets.builders.image_text_pair_builder import (
10
+ CCCombineBuilder,
11
+ LaionBuilder,
12
+ CCAlignBuilder
13
+ )
14
+ from minigpt4.common.registry import registry
15
+
16
+ __all__ = [
17
+ "CCCombineBuilder",
18
+ "LaionBuilder",
19
+ "CCAlignBuilder"
20
+ ]
21
+
22
+
23
+ def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
24
+ """
25
+ Example
26
+
27
+ >>> dataset = load_dataset("coco_caption", cfg=None)
28
+ >>> splits = dataset.keys()
29
+ >>> print([len(dataset[split]) for split in splits])
30
+
31
+ """
32
+ if cfg_path is None:
33
+ cfg = None
34
+ else:
35
+ cfg = load_dataset_config(cfg_path)
36
+
37
+ try:
38
+ builder = registry.get_builder_class(name)(cfg)
39
+ except TypeError:
40
+ print(
41
+ f"Dataset {name} not found. Available datasets:\n"
42
+ + ", ".join([str(k) for k in dataset_zoo.get_names()])
43
+ )
44
+ exit(1)
45
+
46
+ if vis_path is not None:
47
+ if data_type is None:
48
+ # use default data type in the config
49
+ data_type = builder.config.data_type
50
+
51
+ assert (
52
+ data_type in builder.config.build_info
53
+ ), f"Invalid data_type {data_type} for {name}."
54
+
55
+ builder.config.build_info.get(data_type).storage = vis_path
56
+
57
+ dataset = builder.build_datasets()
58
+ return dataset
59
+
60
+
61
+ class DatasetZoo:
62
+ def __init__(self) -> None:
63
+ self.dataset_zoo = {
64
+ k: list(v.DATASET_CONFIG_DICT.keys())
65
+ for k, v in sorted(registry.mapping["builder_name_mapping"].items())
66
+ }
67
+
68
+ def get_names(self):
69
+ return list(self.dataset_zoo.keys())
70
+
71
+
72
+ dataset_zoo = DatasetZoo()
minigpt4/datasets/builders/base_dataset_builder.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import shutil
11
+ import warnings
12
+
13
+ from omegaconf import OmegaConf
14
+ import torch.distributed as dist
15
+ from torchvision.datasets.utils import download_url
16
+
17
+ import minigpt4.common.utils as utils
18
+ from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
19
+ from minigpt4.common.registry import registry
20
+ from minigpt4.processors.base_processor import BaseProcessor
21
+
22
+
23
+
24
+ class BaseDatasetBuilder:
25
+ train_dataset_cls, eval_dataset_cls = None, None
26
+
27
+ def __init__(self, cfg=None):
28
+ super().__init__()
29
+
30
+ if cfg is None:
31
+ # help to create datasets from default config.
32
+ self.config = load_dataset_config(self.default_config_path())
33
+ elif isinstance(cfg, str):
34
+ self.config = load_dataset_config(cfg)
35
+ else:
36
+ # when called from task.build_dataset()
37
+ self.config = cfg
38
+
39
+ self.data_type = self.config.data_type
40
+
41
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
42
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
43
+
44
+ def build_datasets(self):
45
+ # download, split, etc...
46
+ # only called on 1 GPU/TPU in distributed
47
+
48
+ if is_main_process():
49
+ self._download_data()
50
+
51
+ if is_dist_avail_and_initialized():
52
+ dist.barrier()
53
+
54
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
55
+ logging.info("Building datasets...")
56
+ datasets = self.build() # dataset['train'/'val'/'test']
57
+
58
+ return datasets
59
+
60
+ def build_processors(self):
61
+ vis_proc_cfg = self.config.get("vis_processor")
62
+ txt_proc_cfg = self.config.get("text_processor")
63
+
64
+ if vis_proc_cfg is not None:
65
+ vis_train_cfg = vis_proc_cfg.get("train")
66
+ vis_eval_cfg = vis_proc_cfg.get("eval")
67
+
68
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
69
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
70
+
71
+ if txt_proc_cfg is not None:
72
+ txt_train_cfg = txt_proc_cfg.get("train")
73
+ txt_eval_cfg = txt_proc_cfg.get("eval")
74
+
75
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
76
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
77
+
78
+ @staticmethod
79
+ def _build_proc_from_cfg(cfg):
80
+ return (
81
+ registry.get_processor_class(cfg.name).from_config(cfg)
82
+ if cfg is not None
83
+ else None
84
+ )
85
+
86
+ @classmethod
87
+ def default_config_path(cls, type="default"):
88
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
89
+
90
+ def _download_data(self):
91
+ self._download_ann()
92
+ self._download_vis()
93
+
94
+ def _download_ann(self):
95
+ """
96
+ Download annotation files if necessary.
97
+ All the vision-language datasets should have annotations of unified format.
98
+
99
+ storage_path can be:
100
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
101
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
102
+
103
+ Local annotation paths should be relative.
104
+ """
105
+ anns = self.config.build_info.annotations
106
+
107
+ splits = anns.keys()
108
+
109
+ cache_root = registry.get_path("cache_root")
110
+
111
+ for split in splits:
112
+ info = anns[split]
113
+
114
+ urls, storage_paths = info.get("url", None), info.storage
115
+
116
+ if isinstance(urls, str):
117
+ urls = [urls]
118
+ if isinstance(storage_paths, str):
119
+ storage_paths = [storage_paths]
120
+
121
+ assert len(urls) == len(storage_paths)
122
+
123
+ for url_or_filename, storage_path in zip(urls, storage_paths):
124
+ # if storage_path is relative, make it full by prefixing with cache_root.
125
+ if not os.path.isabs(storage_path):
126
+ storage_path = os.path.join(cache_root, storage_path)
127
+
128
+ dirname = os.path.dirname(storage_path)
129
+ if not os.path.exists(dirname):
130
+ os.makedirs(dirname)
131
+
132
+ if os.path.isfile(url_or_filename):
133
+ src, dst = url_or_filename, storage_path
134
+ if not os.path.exists(dst):
135
+ shutil.copyfile(src=src, dst=dst)
136
+ else:
137
+ logging.info("Using existing file {}.".format(dst))
138
+ else:
139
+ if os.path.isdir(storage_path):
140
+ # if only dirname is provided, suffix with basename of URL.
141
+ raise ValueError(
142
+ "Expecting storage_path to be a file path, got directory {}".format(
143
+ storage_path
144
+ )
145
+ )
146
+ else:
147
+ filename = os.path.basename(storage_path)
148
+
149
+ download_url(url=url_or_filename, root=dirname, filename=filename)
150
+
151
+ def _download_vis(self):
152
+
153
+ storage_path = self.config.build_info.get(self.data_type).storage
154
+ storage_path = utils.get_cache_path(storage_path)
155
+
156
+ if not os.path.exists(storage_path):
157
+ warnings.warn(
158
+ f"""
159
+ The specified path {storage_path} for visual inputs does not exist.
160
+ Please provide a correct path to the visual inputs or
161
+ refer to datasets/download_scripts/README.md for downloading instructions.
162
+ """
163
+ )
164
+
165
+ def build(self):
166
+ """
167
+ Create by split datasets inheriting torch.utils.data.Datasets.
168
+
169
+ # build() can be dataset-specific. Overwrite to customize.
170
+ """
171
+ self.build_processors()
172
+
173
+ build_info = self.config.build_info
174
+
175
+ ann_info = build_info.annotations
176
+ vis_info = build_info.get(self.data_type)
177
+
178
+ datasets = dict()
179
+ for split in ann_info.keys():
180
+ if split not in ["train", "val", "test"]:
181
+ continue
182
+
183
+ is_train = split == "train"
184
+
185
+ # processors
186
+ vis_processor = (
187
+ self.vis_processors["train"]
188
+ if is_train
189
+ else self.vis_processors["eval"]
190
+ )
191
+ text_processor = (
192
+ self.text_processors["train"]
193
+ if is_train
194
+ else self.text_processors["eval"]
195
+ )
196
+
197
+ # annotation path
198
+ ann_paths = ann_info.get(split).storage
199
+ if isinstance(ann_paths, str):
200
+ ann_paths = [ann_paths]
201
+
202
+ abs_ann_paths = []
203
+ for ann_path in ann_paths:
204
+ if not os.path.isabs(ann_path):
205
+ ann_path = utils.get_cache_path(ann_path)
206
+ abs_ann_paths.append(ann_path)
207
+ ann_paths = abs_ann_paths
208
+
209
+ # visual data storage path
210
+ vis_path = os.path.join(vis_info.storage, split)
211
+
212
+ if not os.path.isabs(vis_path):
213
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
214
+ vis_path = utils.get_cache_path(vis_path)
215
+
216
+ if not os.path.exists(vis_path):
217
+ warnings.warn("storage path {} does not exist.".format(vis_path))
218
+
219
+ # create datasets
220
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
221
+ datasets[split] = dataset_cls(
222
+ vis_processor=vis_processor,
223
+ text_processor=text_processor,
224
+ ann_paths=ann_paths,
225
+ vis_root=vis_path,
226
+ )
227
+
228
+ return datasets
229
+
230
+
231
+ def load_dataset_config(cfg_path):
232
+ cfg = OmegaConf.load(cfg_path).datasets
233
+ cfg = cfg[list(cfg.keys())[0]]
234
+
235
+ return cfg
minigpt4/datasets/builders/image_text_pair_builder.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+
10
+ from minigpt4.common.registry import registry
11
+ from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
12
+ from minigpt4.datasets.datasets.laion_dataset import LaionDataset
13
+ from minigpt4.datasets.datasets.cc_combine_dataset import CCCombineDataset, CCAlignDataset
14
+
15
+
16
+ @registry.register_builder("cc_combine")
17
+ class CCCombineBuilder(BaseDatasetBuilder):
18
+ train_dataset_cls = CCCombineDataset
19
+
20
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_combine/defaults.yaml"}
21
+
22
+ def _download_ann(self):
23
+ pass
24
+
25
+ def _download_vis(self):
26
+ pass
27
+
28
+ def build(self):
29
+ self.build_processors()
30
+
31
+ build_info = self.config.build_info
32
+
33
+ datasets = dict()
34
+ split = "train"
35
+
36
+ # create datasets
37
+ # [NOTE] return inner_datasets (wds.DataPipeline)
38
+ dataset_cls = self.train_dataset_cls
39
+ datasets[split] = dataset_cls(
40
+ vis_processor=self.vis_processors[split],
41
+ text_processor=self.text_processors[split],
42
+ location=build_info.storage,
43
+ ).inner_dataset
44
+
45
+ return datasets
46
+
47
+
48
+ @registry.register_builder("laion")
49
+ class LaionBuilder(BaseDatasetBuilder):
50
+ train_dataset_cls = LaionDataset
51
+
52
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
53
+
54
+ def _download_ann(self):
55
+ pass
56
+
57
+ def _download_vis(self):
58
+ pass
59
+
60
+ def build(self):
61
+ self.build_processors()
62
+
63
+ build_info = self.config.build_info
64
+
65
+ datasets = dict()
66
+ split = "train"
67
+
68
+ # create datasets
69
+ # [NOTE] return inner_datasets (wds.DataPipeline)
70
+ dataset_cls = self.train_dataset_cls
71
+ datasets[split] = dataset_cls(
72
+ vis_processor=self.vis_processors[split],
73
+ text_processor=self.text_processors[split],
74
+ location=build_info.storage,
75
+ ).inner_dataset
76
+
77
+ return datasets
78
+
79
+
80
+ @registry.register_builder("cc_align")
81
+ class CCAlignBuilder(BaseDatasetBuilder):
82
+ train_dataset_cls = CCAlignDataset
83
+
84
+ DATASET_CONFIG_DICT = {
85
+ "default": "configs/datasets/cc_combine/align.yaml",
86
+ }
minigpt4/datasets/data_utils.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import gzip
9
+ import logging
10
+ import os
11
+ import random as rnd
12
+ import tarfile
13
+ import zipfile
14
+ import random
15
+ from typing import List
16
+ from tqdm import tqdm
17
+
18
+ import decord
19
+ from decord import VideoReader
20
+ import webdataset as wds
21
+ import numpy as np
22
+ import torch
23
+ from torch.utils.data.dataset import IterableDataset
24
+
25
+ from minigpt4.common.registry import registry
26
+ from minigpt4.datasets.datasets.base_dataset import ConcatDataset
27
+
28
+
29
+ decord.bridge.set_bridge("torch")
30
+ MAX_INT = registry.get("MAX_INT")
31
+
32
+
33
+ class ChainDataset(wds.DataPipeline):
34
+ r"""Dataset for chaining multiple :class:`DataPipeline` s.
35
+
36
+ This class is useful to assemble different existing dataset streams. The
37
+ chaining operation is done on-the-fly, so concatenating large-scale
38
+ datasets with this class will be efficient.
39
+
40
+ Args:
41
+ datasets (iterable of IterableDataset): datasets to be chained together
42
+ """
43
+ def __init__(self, datasets: List[wds.DataPipeline]) -> None:
44
+ super().__init__()
45
+ self.datasets = datasets
46
+ self.prob = []
47
+ self.names = []
48
+ for dataset in self.datasets:
49
+ if hasattr(dataset, 'name'):
50
+ self.names.append(dataset.name)
51
+ else:
52
+ self.names.append('Unknown')
53
+ if hasattr(dataset, 'sample_ratio'):
54
+ self.prob.append(dataset.sample_ratio)
55
+ else:
56
+ self.prob.append(1)
57
+ logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
58
+
59
+ def __iter__(self):
60
+ datastreams = [iter(dataset) for dataset in self.datasets]
61
+ while True:
62
+ select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
63
+ yield next(select_datastream)
64
+
65
+
66
+ def apply_to_sample(f, sample):
67
+ if len(sample) == 0:
68
+ return {}
69
+
70
+ def _apply(x):
71
+ if torch.is_tensor(x):
72
+ return f(x)
73
+ elif isinstance(x, dict):
74
+ return {key: _apply(value) for key, value in x.items()}
75
+ elif isinstance(x, list):
76
+ return [_apply(x) for x in x]
77
+ else:
78
+ return x
79
+
80
+ return _apply(sample)
81
+
82
+
83
+ def move_to_cuda(sample):
84
+ def _move_to_cuda(tensor):
85
+ return tensor.cuda()
86
+
87
+ return apply_to_sample(_move_to_cuda, sample)
88
+
89
+
90
+ def prepare_sample(samples, cuda_enabled=True):
91
+ if cuda_enabled:
92
+ samples = move_to_cuda(samples)
93
+
94
+ # TODO fp16 support
95
+
96
+ return samples
97
+
98
+
99
+ def reorg_datasets_by_split(datasets):
100
+ """
101
+ Organizes datasets by split.
102
+
103
+ Args:
104
+ datasets: dict of torch.utils.data.Dataset objects by name.
105
+
106
+ Returns:
107
+ Dict of datasets by split {split_name: List[Datasets]}.
108
+ """
109
+ # if len(datasets) == 1:
110
+ # return datasets[list(datasets.keys())[0]]
111
+ # else:
112
+ reorg_datasets = dict()
113
+
114
+ # reorganize by split
115
+ for _, dataset in datasets.items():
116
+ for split_name, dataset_split in dataset.items():
117
+ if split_name not in reorg_datasets:
118
+ reorg_datasets[split_name] = [dataset_split]
119
+ else:
120
+ reorg_datasets[split_name].append(dataset_split)
121
+
122
+ return reorg_datasets
123
+
124
+
125
+ def concat_datasets(datasets):
126
+ """
127
+ Concatenates multiple datasets into a single dataset.
128
+
129
+ It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support
130
+ generic IterableDataset because it requires creating separate samplers.
131
+
132
+ Now only supports conctenating training datasets and assuming validation and testing
133
+ have only a single dataset. This is because metrics should not be computed on the concatenated
134
+ datasets.
135
+
136
+ Args:
137
+ datasets: dict of torch.utils.data.Dataset objects by split.
138
+
139
+ Returns:
140
+ Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
141
+ "val" and "test" remain the same.
142
+
143
+ If the input training datasets contain both map-style and DataPipeline datasets, returns
144
+ a tuple, where the first element is a concatenated map-style dataset and the second
145
+ element is a chained DataPipeline dataset.
146
+
147
+ """
148
+ # concatenate datasets in the same split
149
+ for split_name in datasets:
150
+ if split_name != "train":
151
+ assert (
152
+ len(datasets[split_name]) == 1
153
+ ), "Do not support multiple {} datasets.".format(split_name)
154
+ datasets[split_name] = datasets[split_name][0]
155
+ else:
156
+ iterable_datasets, map_datasets = [], []
157
+ for dataset in datasets[split_name]:
158
+ if isinstance(dataset, wds.DataPipeline):
159
+ logging.info(
160
+ "Dataset {} is IterableDataset, can't be concatenated.".format(
161
+ dataset
162
+ )
163
+ )
164
+ iterable_datasets.append(dataset)
165
+ elif isinstance(dataset, IterableDataset):
166
+ raise NotImplementedError(
167
+ "Do not support concatenation of generic IterableDataset."
168
+ )
169
+ else:
170
+ map_datasets.append(dataset)
171
+
172
+ # if len(iterable_datasets) > 0:
173
+ # concatenate map-style datasets and iterable-style datasets separately
174
+ if len(iterable_datasets) > 1:
175
+ chained_datasets = (
176
+ ChainDataset(iterable_datasets)
177
+ )
178
+ elif len(iterable_datasets) == 1:
179
+ chained_datasets = iterable_datasets[0]
180
+ else:
181
+ chained_datasets = None
182
+
183
+ concat_datasets = (
184
+ ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
185
+ )
186
+
187
+ train_datasets = concat_datasets, chained_datasets
188
+ train_datasets = tuple([x for x in train_datasets if x is not None])
189
+ train_datasets = (
190
+ train_datasets[0] if len(train_datasets) == 1 else train_datasets
191
+ )
192
+
193
+ datasets[split_name] = train_datasets
194
+
195
+ return datasets
196
+
minigpt4/datasets/datasets/__init__.py ADDED
File without changes
minigpt4/datasets/datasets/base_dataset.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import json
9
+ from typing import Iterable
10
+
11
+ from torch.utils.data import Dataset, ConcatDataset
12
+ from torch.utils.data.dataloader import default_collate
13
+
14
+
15
+ class BaseDataset(Dataset):
16
+ def __init__(
17
+ self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
18
+ ):
19
+ """
20
+ vis_root (string): Root directory of images (e.g. coco/images/)
21
+ ann_root (string): directory to store the annotation file
22
+ """
23
+ self.vis_root = vis_root
24
+
25
+ self.annotation = []
26
+ for ann_path in ann_paths:
27
+ self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
28
+
29
+ self.vis_processor = vis_processor
30
+ self.text_processor = text_processor
31
+
32
+ self._add_instance_ids()
33
+
34
+ def __len__(self):
35
+ return len(self.annotation)
36
+
37
+ def collater(self, samples):
38
+ return default_collate(samples)
39
+
40
+ def set_processors(self, vis_processor, text_processor):
41
+ self.vis_processor = vis_processor
42
+ self.text_processor = text_processor
43
+
44
+ def _add_instance_ids(self, key="instance_id"):
45
+ for idx, ann in enumerate(self.annotation):
46
+ ann[key] = str(idx)
47
+
48
+
49
+ class ConcatDataset(ConcatDataset):
50
+ def __init__(self, datasets: Iterable[Dataset]) -> None:
51
+ super().__init__(datasets)
52
+
53
+ def collater(self, samples):
54
+ # TODO For now only supports datasets with same underlying collater implementations
55
+
56
+ all_keys = set()
57
+ for s in samples:
58
+ all_keys.update(s)
59
+
60
+ shared_keys = all_keys
61
+ for s in samples:
62
+ shared_keys = shared_keys & set(s.keys())
63
+
64
+ samples_shared_keys = []
65
+ for s in samples:
66
+ samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
67
+
68
+ return self.datasets[0].collater(samples_shared_keys)
minigpt4/datasets/datasets/caption_datasets.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ from collections import OrderedDict
10
+
11
+ from minigpt4.datasets.datasets.base_dataset import BaseDataset
12
+ from PIL import Image
13
+
14
+
15
+ class __DisplMixin:
16
+ def displ_item(self, index):
17
+ sample, ann = self.__getitem__(index), self.annotation[index]
18
+
19
+ return OrderedDict(
20
+ {
21
+ "file": ann["image"],
22
+ "caption": ann["caption"],
23
+ "image": sample["image"],
24
+ }
25
+ )
26
+
27
+
28
+ class CaptionDataset(BaseDataset, __DisplMixin):
29
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
30
+ """
31
+ vis_root (string): Root directory of images (e.g. coco/images/)
32
+ ann_root (string): directory to store the annotation file
33
+ """
34
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
35
+
36
+ self.img_ids = {}
37
+ n = 0
38
+ for ann in self.annotation:
39
+ img_id = ann["image_id"]
40
+ if img_id not in self.img_ids.keys():
41
+ self.img_ids[img_id] = n
42
+ n += 1
43
+
44
+ def __getitem__(self, index):
45
+
46
+ # TODO this assumes image input, not general enough
47
+ ann = self.annotation[index]
48
+
49
+ img_file = '{:0>12}.jpg'.format(ann["image_id"])
50
+ image_path = os.path.join(self.vis_root, img_file)
51
+ image = Image.open(image_path).convert("RGB")
52
+
53
+ image = self.vis_processor(image)
54
+ caption = self.text_processor(ann["caption"])
55
+
56
+ return {
57
+ "image": image,
58
+ "text_input": caption,
59
+ "image_id": self.img_ids[ann["image_id"]],
60
+ }
61
+
62
+
63
+ class CaptionEvalDataset(BaseDataset, __DisplMixin):
64
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
65
+ """
66
+ vis_root (string): Root directory of images (e.g. coco/images/)
67
+ ann_root (string): directory to store the annotation file
68
+ split (string): val or test
69
+ """
70
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
71
+
72
+ def __getitem__(self, index):
73
+
74
+ ann = self.annotation[index]
75
+
76
+ image_path = os.path.join(self.vis_root, ann["image"])
77
+ image = Image.open(image_path).convert("RGB")
78
+
79
+ image = self.vis_processor(image)
80
+
81
+ return {
82
+ "image": image,
83
+ "image_id": ann["image_id"],
84
+ "instance_id": ann["instance_id"],
85
+ }
minigpt4/datasets/datasets/cc_combine_dataset.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import os
8
+ from PIL import Image
9
+ import webdataset as wds
10
+ from minigpt4.datasets.datasets.base_dataset import BaseDataset
11
+ from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
12
+
13
+
14
+ class CCCombineDataset(BaseDataset):
15
+ def __init__(self, vis_processor, text_processor, location):
16
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
17
+
18
+ self.inner_dataset = wds.DataPipeline(
19
+ wds.ResampledShards(location),
20
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
21
+ wds.shuffle(1000, handler=wds.warn_and_continue),
22
+ wds.decode("pilrgb", handler=wds.warn_and_continue),
23
+ wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
24
+ wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
25
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
26
+ )
27
+
28
+ def to_dict(self, sample):
29
+ return {
30
+ "image": sample[0],
31
+ "text_input": self.text_processor(sample[1]["caption"]),
32
+ }
33
+
34
+
35
+ class CCAlignDataset(CaptionDataset):
36
+
37
+ def __getitem__(self, index):
38
+
39
+ # TODO this assumes image input, not general enough
40
+ ann = self.annotation[index]
41
+
42
+ img_file = '{}.jpg'.format(ann["image_id"])
43
+ image_path = os.path.join(self.vis_root, img_file)
44
+ image = Image.open(image_path).convert("RGB")
45
+
46
+ image = self.vis_processor(image)
47
+ caption = ann["caption"]
48
+
49
+ return {
50
+ "image": image,
51
+ "text_input": caption,
52
+ "image_id": self.img_ids[ann["image_id"]],
53
+ }
minigpt4/datasets/datasets/dataloader_utils.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import time
9
+ import random
10
+ import torch
11
+ from minigpt4.datasets.data_utils import move_to_cuda
12
+ from torch.utils.data import DataLoader
13
+
14
+
15
+ class MultiIterLoader:
16
+ """
17
+ A simple wrapper for iterating over multiple iterators.
18
+
19
+ Args:
20
+ loaders (List[Loader]): List of Iterator loaders.
21
+ ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
22
+ """
23
+
24
+ def __init__(self, loaders, ratios=None):
25
+ # assert all loaders has __next__ method
26
+ for loader in loaders:
27
+ assert hasattr(
28
+ loader, "__next__"
29
+ ), "Loader {} has no __next__ method.".format(loader)
30
+
31
+ if ratios is None:
32
+ ratios = [1.0] * len(loaders)
33
+ else:
34
+ assert len(ratios) == len(loaders)
35
+ ratios = [float(ratio) / sum(ratios) for ratio in ratios]
36
+
37
+ self.loaders = loaders
38
+ self.ratios = ratios
39
+
40
+ def __next__(self):
41
+ # random sample from each loader by ratio
42
+ loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
43
+ return next(self.loaders[loader_idx])
44
+
45
+
46
+ class PrefetchLoader(object):
47
+ """
48
+ Modified from https://github.com/ChenRocks/UNITER.
49
+
50
+ overlap compute and cuda data transfer
51
+ (copied and then modified from nvidia apex)
52
+ """
53
+
54
+ def __init__(self, loader):
55
+ self.loader = loader
56
+ self.stream = torch.cuda.Stream()
57
+
58
+ def __iter__(self):
59
+ loader_it = iter(self.loader)
60
+ self.preload(loader_it)
61
+ batch = self.next(loader_it)
62
+ while batch is not None:
63
+ is_tuple = isinstance(batch, tuple)
64
+ if is_tuple:
65
+ task, batch = batch
66
+
67
+ if is_tuple:
68
+ yield task, batch
69
+ else:
70
+ yield batch
71
+ batch = self.next(loader_it)
72
+
73
+ def __len__(self):
74
+ return len(self.loader)
75
+
76
+ def preload(self, it):
77
+ try:
78
+ self.batch = next(it)
79
+ except StopIteration:
80
+ self.batch = None
81
+ return
82
+ # if record_stream() doesn't work, another option is to make sure
83
+ # device inputs are created on the main stream.
84
+ # self.next_input_gpu = torch.empty_like(self.next_input,
85
+ # device='cuda')
86
+ # self.next_target_gpu = torch.empty_like(self.next_target,
87
+ # device='cuda')
88
+ # Need to make sure the memory allocated for next_* is not still in use
89
+ # by the main stream at the time we start copying to next_*:
90
+ # self.stream.wait_stream(torch.cuda.current_stream())
91
+ with torch.cuda.stream(self.stream):
92
+ self.batch = move_to_cuda(self.batch)
93
+ # more code for the alternative if record_stream() doesn't work:
94
+ # copy_ will record the use of the pinned source tensor in this
95
+ # side stream.
96
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
97
+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
98
+ # self.next_input = self.next_input_gpu
99
+ # self.next_target = self.next_target_gpu
100
+
101
+ def next(self, it):
102
+ torch.cuda.current_stream().wait_stream(self.stream)
103
+ batch = self.batch
104
+ if batch is not None:
105
+ record_cuda_stream(batch)
106
+ self.preload(it)
107
+ return batch
108
+
109
+ def __getattr__(self, name):
110
+ method = self.loader.__getattribute__(name)
111
+ return method
112
+
113
+
114
+ def record_cuda_stream(batch):
115
+ if isinstance(batch, torch.Tensor):
116
+ batch.record_stream(torch.cuda.current_stream())
117
+ elif isinstance(batch, list) or isinstance(batch, tuple):
118
+ for t in batch:
119
+ record_cuda_stream(t)
120
+ elif isinstance(batch, dict):
121
+ for t in batch.values():
122
+ record_cuda_stream(t)
123
+ else:
124
+ pass
125
+
126
+
127
+ class IterLoader:
128
+ """
129
+ A wrapper to convert DataLoader as an infinite iterator.
130
+
131
+ Modified from:
132
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
133
+ """
134
+
135
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
136
+ self._dataloader = dataloader
137
+ self.iter_loader = iter(self._dataloader)
138
+ self._use_distributed = use_distributed
139
+ self._epoch = 0
140
+
141
+ @property
142
+ def epoch(self) -> int:
143
+ return self._epoch
144
+
145
+ def __next__(self):
146
+ try:
147
+ data = next(self.iter_loader)
148
+ except StopIteration:
149
+ self._epoch += 1
150
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
151
+ self._dataloader.sampler.set_epoch(self._epoch)
152
+ time.sleep(2) # Prevent possible deadlock during epoch transition
153
+ self.iter_loader = iter(self._dataloader)
154
+ data = next(self.iter_loader)
155
+
156
+ return data
157
+
158
+ def __iter__(self):
159
+ return self
160
+
161
+ def __len__(self):
162
+ return len(self._dataloader)
minigpt4/datasets/datasets/laion_dataset.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import webdataset as wds
9
+ from minigpt4.datasets.datasets.base_dataset import BaseDataset
10
+
11
+
12
+ class LaionDataset(BaseDataset):
13
+ def __init__(self, vis_processor, text_processor, location):
14
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
15
+
16
+ self.inner_dataset = wds.DataPipeline(
17
+ wds.ResampledShards(location),
18
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
19
+ wds.shuffle(1000, handler=wds.warn_and_continue),
20
+ wds.decode("pilrgb", handler=wds.warn_and_continue),
21
+ wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
22
+ wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
23
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
24
+ )
25
+
26
+ def to_dict(self, sample):
27
+ return {
28
+ "image": sample[0],
29
+ "text_input": self.text_processor(sample[1]["caption"]),
30
+ }
31
+
minigpt4/models/Qformer.py ADDED
@@ -0,0 +1,1216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class BertEmbeddings(nn.Module):
52
+ """Construct the embeddings from word and position embeddings."""
53
+
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.word_embeddings = nn.Embedding(
57
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
58
+ )
59
+ self.position_embeddings = nn.Embedding(
60
+ config.max_position_embeddings, config.hidden_size
61
+ )
62
+
63
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
64
+ # any TensorFlow checkpoint file
65
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
66
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
67
+
68
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
69
+ self.register_buffer(
70
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
71
+ )
72
+ self.position_embedding_type = getattr(
73
+ config, "position_embedding_type", "absolute"
74
+ )
75
+
76
+ self.config = config
77
+
78
+ def forward(
79
+ self,
80
+ input_ids=None,
81
+ position_ids=None,
82
+ query_embeds=None,
83
+ past_key_values_length=0,
84
+ ):
85
+ if input_ids is not None:
86
+ seq_length = input_ids.size()[1]
87
+ else:
88
+ seq_length = 0
89
+
90
+ if position_ids is None:
91
+ position_ids = self.position_ids[
92
+ :, past_key_values_length : seq_length + past_key_values_length
93
+ ].clone()
94
+
95
+ if input_ids is not None:
96
+ embeddings = self.word_embeddings(input_ids)
97
+ if self.position_embedding_type == "absolute":
98
+ position_embeddings = self.position_embeddings(position_ids)
99
+ embeddings = embeddings + position_embeddings
100
+
101
+ if query_embeds is not None:
102
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
103
+ else:
104
+ embeddings = query_embeds
105
+
106
+ embeddings = self.LayerNorm(embeddings)
107
+ embeddings = self.dropout(embeddings)
108
+ return embeddings
109
+
110
+
111
+ class BertSelfAttention(nn.Module):
112
+ def __init__(self, config, is_cross_attention):
113
+ super().__init__()
114
+ self.config = config
115
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
116
+ config, "embedding_size"
117
+ ):
118
+ raise ValueError(
119
+ "The hidden size (%d) is not a multiple of the number of attention "
120
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
121
+ )
122
+
123
+ self.num_attention_heads = config.num_attention_heads
124
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
125
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
126
+
127
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
128
+ if is_cross_attention:
129
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
130
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
131
+ else:
132
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
133
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
134
+
135
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
136
+ self.position_embedding_type = getattr(
137
+ config, "position_embedding_type", "absolute"
138
+ )
139
+ if (
140
+ self.position_embedding_type == "relative_key"
141
+ or self.position_embedding_type == "relative_key_query"
142
+ ):
143
+ self.max_position_embeddings = config.max_position_embeddings
144
+ self.distance_embedding = nn.Embedding(
145
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
146
+ )
147
+ self.save_attention = False
148
+
149
+ def save_attn_gradients(self, attn_gradients):
150
+ self.attn_gradients = attn_gradients
151
+
152
+ def get_attn_gradients(self):
153
+ return self.attn_gradients
154
+
155
+ def save_attention_map(self, attention_map):
156
+ self.attention_map = attention_map
157
+
158
+ def get_attention_map(self):
159
+ return self.attention_map
160
+
161
+ def transpose_for_scores(self, x):
162
+ new_x_shape = x.size()[:-1] + (
163
+ self.num_attention_heads,
164
+ self.attention_head_size,
165
+ )
166
+ x = x.view(*new_x_shape)
167
+ return x.permute(0, 2, 1, 3)
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states,
172
+ attention_mask=None,
173
+ head_mask=None,
174
+ encoder_hidden_states=None,
175
+ encoder_attention_mask=None,
176
+ past_key_value=None,
177
+ output_attentions=False,
178
+ ):
179
+
180
+ # If this is instantiated as a cross-attention module, the keys
181
+ # and values come from an encoder; the attention mask needs to be
182
+ # such that the encoder's padding tokens are not attended to.
183
+ is_cross_attention = encoder_hidden_states is not None
184
+
185
+ if is_cross_attention:
186
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
187
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
188
+ attention_mask = encoder_attention_mask
189
+ elif past_key_value is not None:
190
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
191
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
192
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
193
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
194
+ else:
195
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
196
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
197
+
198
+ mixed_query_layer = self.query(hidden_states)
199
+
200
+ query_layer = self.transpose_for_scores(mixed_query_layer)
201
+
202
+ past_key_value = (key_layer, value_layer)
203
+
204
+ # Take the dot product between "query" and "key" to get the raw attention scores.
205
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
206
+
207
+ if (
208
+ self.position_embedding_type == "relative_key"
209
+ or self.position_embedding_type == "relative_key_query"
210
+ ):
211
+ seq_length = hidden_states.size()[1]
212
+ position_ids_l = torch.arange(
213
+ seq_length, dtype=torch.long, device=hidden_states.device
214
+ ).view(-1, 1)
215
+ position_ids_r = torch.arange(
216
+ seq_length, dtype=torch.long, device=hidden_states.device
217
+ ).view(1, -1)
218
+ distance = position_ids_l - position_ids_r
219
+ positional_embedding = self.distance_embedding(
220
+ distance + self.max_position_embeddings - 1
221
+ )
222
+ positional_embedding = positional_embedding.to(
223
+ dtype=query_layer.dtype
224
+ ) # fp16 compatibility
225
+
226
+ if self.position_embedding_type == "relative_key":
227
+ relative_position_scores = torch.einsum(
228
+ "bhld,lrd->bhlr", query_layer, positional_embedding
229
+ )
230
+ attention_scores = attention_scores + relative_position_scores
231
+ elif self.position_embedding_type == "relative_key_query":
232
+ relative_position_scores_query = torch.einsum(
233
+ "bhld,lrd->bhlr", query_layer, positional_embedding
234
+ )
235
+ relative_position_scores_key = torch.einsum(
236
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
237
+ )
238
+ attention_scores = (
239
+ attention_scores
240
+ + relative_position_scores_query
241
+ + relative_position_scores_key
242
+ )
243
+
244
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
245
+ if attention_mask is not None:
246
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
247
+ attention_scores = attention_scores + attention_mask
248
+
249
+ # Normalize the attention scores to probabilities.
250
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
251
+
252
+ if is_cross_attention and self.save_attention:
253
+ self.save_attention_map(attention_probs)
254
+ attention_probs.register_hook(self.save_attn_gradients)
255
+
256
+ # This is actually dropping out entire tokens to attend to, which might
257
+ # seem a bit unusual, but is taken from the original Transformer paper.
258
+ attention_probs_dropped = self.dropout(attention_probs)
259
+
260
+ # Mask heads if we want to
261
+ if head_mask is not None:
262
+ attention_probs_dropped = attention_probs_dropped * head_mask
263
+
264
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
265
+
266
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
267
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
268
+ context_layer = context_layer.view(*new_context_layer_shape)
269
+
270
+ outputs = (
271
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
272
+ )
273
+
274
+ outputs = outputs + (past_key_value,)
275
+ return outputs
276
+
277
+
278
+ class BertSelfOutput(nn.Module):
279
+ def __init__(self, config):
280
+ super().__init__()
281
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
282
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
283
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
284
+
285
+ def forward(self, hidden_states, input_tensor):
286
+ hidden_states = self.dense(hidden_states)
287
+ hidden_states = self.dropout(hidden_states)
288
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
289
+ return hidden_states
290
+
291
+
292
+ class BertAttention(nn.Module):
293
+ def __init__(self, config, is_cross_attention=False):
294
+ super().__init__()
295
+ self.self = BertSelfAttention(config, is_cross_attention)
296
+ self.output = BertSelfOutput(config)
297
+ self.pruned_heads = set()
298
+
299
+ def prune_heads(self, heads):
300
+ if len(heads) == 0:
301
+ return
302
+ heads, index = find_pruneable_heads_and_indices(
303
+ heads,
304
+ self.self.num_attention_heads,
305
+ self.self.attention_head_size,
306
+ self.pruned_heads,
307
+ )
308
+
309
+ # Prune linear layers
310
+ self.self.query = prune_linear_layer(self.self.query, index)
311
+ self.self.key = prune_linear_layer(self.self.key, index)
312
+ self.self.value = prune_linear_layer(self.self.value, index)
313
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
314
+
315
+ # Update hyper params and store pruned heads
316
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
317
+ self.self.all_head_size = (
318
+ self.self.attention_head_size * self.self.num_attention_heads
319
+ )
320
+ self.pruned_heads = self.pruned_heads.union(heads)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states,
325
+ attention_mask=None,
326
+ head_mask=None,
327
+ encoder_hidden_states=None,
328
+ encoder_attention_mask=None,
329
+ past_key_value=None,
330
+ output_attentions=False,
331
+ ):
332
+ self_outputs = self.self(
333
+ hidden_states,
334
+ attention_mask,
335
+ head_mask,
336
+ encoder_hidden_states,
337
+ encoder_attention_mask,
338
+ past_key_value,
339
+ output_attentions,
340
+ )
341
+ attention_output = self.output(self_outputs[0], hidden_states)
342
+
343
+ outputs = (attention_output,) + self_outputs[
344
+ 1:
345
+ ] # add attentions if we output them
346
+ return outputs
347
+
348
+
349
+ class BertIntermediate(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
353
+ if isinstance(config.hidden_act, str):
354
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
355
+ else:
356
+ self.intermediate_act_fn = config.hidden_act
357
+
358
+ def forward(self, hidden_states):
359
+ hidden_states = self.dense(hidden_states)
360
+ hidden_states = self.intermediate_act_fn(hidden_states)
361
+ return hidden_states
362
+
363
+
364
+ class BertOutput(nn.Module):
365
+ def __init__(self, config):
366
+ super().__init__()
367
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
368
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
369
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
370
+
371
+ def forward(self, hidden_states, input_tensor):
372
+ hidden_states = self.dense(hidden_states)
373
+ hidden_states = self.dropout(hidden_states)
374
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
375
+ return hidden_states
376
+
377
+
378
+ class BertLayer(nn.Module):
379
+ def __init__(self, config, layer_num):
380
+ super().__init__()
381
+ self.config = config
382
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
383
+ self.seq_len_dim = 1
384
+ self.attention = BertAttention(config)
385
+ self.layer_num = layer_num
386
+ if (
387
+ self.config.add_cross_attention
388
+ and layer_num % self.config.cross_attention_freq == 0
389
+ ):
390
+ self.crossattention = BertAttention(
391
+ config, is_cross_attention=self.config.add_cross_attention
392
+ )
393
+ self.has_cross_attention = True
394
+ else:
395
+ self.has_cross_attention = False
396
+ self.intermediate = BertIntermediate(config)
397
+ self.output = BertOutput(config)
398
+
399
+ self.intermediate_query = BertIntermediate(config)
400
+ self.output_query = BertOutput(config)
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states,
405
+ attention_mask=None,
406
+ head_mask=None,
407
+ encoder_hidden_states=None,
408
+ encoder_attention_mask=None,
409
+ past_key_value=None,
410
+ output_attentions=False,
411
+ query_length=0,
412
+ ):
413
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
414
+ self_attn_past_key_value = (
415
+ past_key_value[:2] if past_key_value is not None else None
416
+ )
417
+ self_attention_outputs = self.attention(
418
+ hidden_states,
419
+ attention_mask,
420
+ head_mask,
421
+ output_attentions=output_attentions,
422
+ past_key_value=self_attn_past_key_value,
423
+ )
424
+ attention_output = self_attention_outputs[0]
425
+ outputs = self_attention_outputs[1:-1]
426
+
427
+ present_key_value = self_attention_outputs[-1]
428
+
429
+ if query_length > 0:
430
+ query_attention_output = attention_output[:, :query_length, :]
431
+
432
+ if self.has_cross_attention:
433
+ assert (
434
+ encoder_hidden_states is not None
435
+ ), "encoder_hidden_states must be given for cross-attention layers"
436
+ cross_attention_outputs = self.crossattention(
437
+ query_attention_output,
438
+ attention_mask,
439
+ head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ output_attentions=output_attentions,
443
+ )
444
+ query_attention_output = cross_attention_outputs[0]
445
+ outputs = (
446
+ outputs + cross_attention_outputs[1:-1]
447
+ ) # add cross attentions if we output attention weights
448
+
449
+ layer_output = apply_chunking_to_forward(
450
+ self.feed_forward_chunk_query,
451
+ self.chunk_size_feed_forward,
452
+ self.seq_len_dim,
453
+ query_attention_output,
454
+ )
455
+ if attention_output.shape[1] > query_length:
456
+ layer_output_text = apply_chunking_to_forward(
457
+ self.feed_forward_chunk,
458
+ self.chunk_size_feed_forward,
459
+ self.seq_len_dim,
460
+ attention_output[:, query_length:, :],
461
+ )
462
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
463
+ else:
464
+ layer_output = apply_chunking_to_forward(
465
+ self.feed_forward_chunk,
466
+ self.chunk_size_feed_forward,
467
+ self.seq_len_dim,
468
+ attention_output,
469
+ )
470
+ outputs = (layer_output,) + outputs
471
+
472
+ outputs = outputs + (present_key_value,)
473
+
474
+ return outputs
475
+
476
+ def feed_forward_chunk(self, attention_output):
477
+ intermediate_output = self.intermediate(attention_output)
478
+ layer_output = self.output(intermediate_output, attention_output)
479
+ return layer_output
480
+
481
+ def feed_forward_chunk_query(self, attention_output):
482
+ intermediate_output = self.intermediate_query(attention_output)
483
+ layer_output = self.output_query(intermediate_output, attention_output)
484
+ return layer_output
485
+
486
+
487
+ class BertEncoder(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.config = config
491
+ self.layer = nn.ModuleList(
492
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
493
+ )
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states,
498
+ attention_mask=None,
499
+ head_mask=None,
500
+ encoder_hidden_states=None,
501
+ encoder_attention_mask=None,
502
+ past_key_values=None,
503
+ use_cache=None,
504
+ output_attentions=False,
505
+ output_hidden_states=False,
506
+ return_dict=True,
507
+ query_length=0,
508
+ ):
509
+ all_hidden_states = () if output_hidden_states else None
510
+ all_self_attentions = () if output_attentions else None
511
+ all_cross_attentions = (
512
+ () if output_attentions and self.config.add_cross_attention else None
513
+ )
514
+
515
+ next_decoder_cache = () if use_cache else None
516
+
517
+ for i in range(self.config.num_hidden_layers):
518
+ layer_module = self.layer[i]
519
+ if output_hidden_states:
520
+ all_hidden_states = all_hidden_states + (hidden_states,)
521
+
522
+ layer_head_mask = head_mask[i] if head_mask is not None else None
523
+ past_key_value = past_key_values[i] if past_key_values is not None else None
524
+
525
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
526
+
527
+ if use_cache:
528
+ logger.warn(
529
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
530
+ )
531
+ use_cache = False
532
+
533
+ def create_custom_forward(module):
534
+ def custom_forward(*inputs):
535
+ return module(
536
+ *inputs, past_key_value, output_attentions, query_length
537
+ )
538
+
539
+ return custom_forward
540
+
541
+ layer_outputs = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(layer_module),
543
+ hidden_states,
544
+ attention_mask,
545
+ layer_head_mask,
546
+ encoder_hidden_states,
547
+ encoder_attention_mask,
548
+ )
549
+ else:
550
+ layer_outputs = layer_module(
551
+ hidden_states,
552
+ attention_mask,
553
+ layer_head_mask,
554
+ encoder_hidden_states,
555
+ encoder_attention_mask,
556
+ past_key_value,
557
+ output_attentions,
558
+ query_length,
559
+ )
560
+
561
+ hidden_states = layer_outputs[0]
562
+ if use_cache:
563
+ next_decoder_cache += (layer_outputs[-1],)
564
+ if output_attentions:
565
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
566
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
567
+
568
+ if output_hidden_states:
569
+ all_hidden_states = all_hidden_states + (hidden_states,)
570
+
571
+ if not return_dict:
572
+ return tuple(
573
+ v
574
+ for v in [
575
+ hidden_states,
576
+ next_decoder_cache,
577
+ all_hidden_states,
578
+ all_self_attentions,
579
+ all_cross_attentions,
580
+ ]
581
+ if v is not None
582
+ )
583
+ return BaseModelOutputWithPastAndCrossAttentions(
584
+ last_hidden_state=hidden_states,
585
+ past_key_values=next_decoder_cache,
586
+ hidden_states=all_hidden_states,
587
+ attentions=all_self_attentions,
588
+ cross_attentions=all_cross_attentions,
589
+ )
590
+
591
+
592
+ class BertPooler(nn.Module):
593
+ def __init__(self, config):
594
+ super().__init__()
595
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
596
+ self.activation = nn.Tanh()
597
+
598
+ def forward(self, hidden_states):
599
+ # We "pool" the model by simply taking the hidden state corresponding
600
+ # to the first token.
601
+ first_token_tensor = hidden_states[:, 0]
602
+ pooled_output = self.dense(first_token_tensor)
603
+ pooled_output = self.activation(pooled_output)
604
+ return pooled_output
605
+
606
+
607
+ class BertPredictionHeadTransform(nn.Module):
608
+ def __init__(self, config):
609
+ super().__init__()
610
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
611
+ if isinstance(config.hidden_act, str):
612
+ self.transform_act_fn = ACT2FN[config.hidden_act]
613
+ else:
614
+ self.transform_act_fn = config.hidden_act
615
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
616
+
617
+ def forward(self, hidden_states):
618
+ hidden_states = self.dense(hidden_states)
619
+ hidden_states = self.transform_act_fn(hidden_states)
620
+ hidden_states = self.LayerNorm(hidden_states)
621
+ return hidden_states
622
+
623
+
624
+ class BertLMPredictionHead(nn.Module):
625
+ def __init__(self, config):
626
+ super().__init__()
627
+ self.transform = BertPredictionHeadTransform(config)
628
+
629
+ # The output weights are the same as the input embeddings, but there is
630
+ # an output-only bias for each token.
631
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
632
+
633
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
634
+
635
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
636
+ self.decoder.bias = self.bias
637
+
638
+ def forward(self, hidden_states):
639
+ hidden_states = self.transform(hidden_states)
640
+ hidden_states = self.decoder(hidden_states)
641
+ return hidden_states
642
+
643
+
644
+ class BertOnlyMLMHead(nn.Module):
645
+ def __init__(self, config):
646
+ super().__init__()
647
+ self.predictions = BertLMPredictionHead(config)
648
+
649
+ def forward(self, sequence_output):
650
+ prediction_scores = self.predictions(sequence_output)
651
+ return prediction_scores
652
+
653
+
654
+ class BertPreTrainedModel(PreTrainedModel):
655
+ """
656
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
657
+ models.
658
+ """
659
+
660
+ config_class = BertConfig
661
+ base_model_prefix = "bert"
662
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
663
+
664
+ def _init_weights(self, module):
665
+ """Initialize the weights"""
666
+ if isinstance(module, (nn.Linear, nn.Embedding)):
667
+ # Slightly different from the TF version which uses truncated_normal for initialization
668
+ # cf https://github.com/pytorch/pytorch/pull/5617
669
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
670
+ elif isinstance(module, nn.LayerNorm):
671
+ module.bias.data.zero_()
672
+ module.weight.data.fill_(1.0)
673
+ if isinstance(module, nn.Linear) and module.bias is not None:
674
+ module.bias.data.zero_()
675
+
676
+
677
+ class BertModel(BertPreTrainedModel):
678
+ """
679
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
680
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
681
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
682
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
683
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
684
+ input to the forward pass.
685
+ """
686
+
687
+ def __init__(self, config, add_pooling_layer=False):
688
+ super().__init__(config)
689
+ self.config = config
690
+
691
+ self.embeddings = BertEmbeddings(config)
692
+
693
+ self.encoder = BertEncoder(config)
694
+
695
+ self.pooler = BertPooler(config) if add_pooling_layer else None
696
+
697
+ self.init_weights()
698
+
699
+ def get_input_embeddings(self):
700
+ return self.embeddings.word_embeddings
701
+
702
+ def set_input_embeddings(self, value):
703
+ self.embeddings.word_embeddings = value
704
+
705
+ def _prune_heads(self, heads_to_prune):
706
+ """
707
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
708
+ class PreTrainedModel
709
+ """
710
+ for layer, heads in heads_to_prune.items():
711
+ self.encoder.layer[layer].attention.prune_heads(heads)
712
+
713
+ def get_extended_attention_mask(
714
+ self,
715
+ attention_mask: Tensor,
716
+ input_shape: Tuple[int],
717
+ device: device,
718
+ is_decoder: bool,
719
+ has_query: bool = False,
720
+ ) -> Tensor:
721
+ """
722
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
723
+
724
+ Arguments:
725
+ attention_mask (:obj:`torch.Tensor`):
726
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
727
+ input_shape (:obj:`Tuple[int]`):
728
+ The shape of the input to the model.
729
+ device: (:obj:`torch.device`):
730
+ The device of the input to the model.
731
+
732
+ Returns:
733
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
734
+ """
735
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
736
+ # ourselves in which case we just need to make it broadcastable to all heads.
737
+ if attention_mask.dim() == 3:
738
+ extended_attention_mask = attention_mask[:, None, :, :]
739
+ elif attention_mask.dim() == 2:
740
+ # Provided a padding mask of dimensions [batch_size, seq_length]
741
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
742
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
743
+ if is_decoder:
744
+ batch_size, seq_length = input_shape
745
+
746
+ seq_ids = torch.arange(seq_length, device=device)
747
+ causal_mask = (
748
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
749
+ <= seq_ids[None, :, None]
750
+ )
751
+
752
+ # add a prefix ones mask to the causal mask
753
+ # causal and attention masks must have same type with pytorch version < 1.3
754
+ causal_mask = causal_mask.to(attention_mask.dtype)
755
+
756
+ if causal_mask.shape[1] < attention_mask.shape[1]:
757
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
758
+ if has_query: # UniLM style attention mask
759
+ causal_mask = torch.cat(
760
+ [
761
+ torch.zeros(
762
+ (batch_size, prefix_seq_len, seq_length),
763
+ device=device,
764
+ dtype=causal_mask.dtype,
765
+ ),
766
+ causal_mask,
767
+ ],
768
+ axis=1,
769
+ )
770
+ causal_mask = torch.cat(
771
+ [
772
+ torch.ones(
773
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
774
+ device=device,
775
+ dtype=causal_mask.dtype,
776
+ ),
777
+ causal_mask,
778
+ ],
779
+ axis=-1,
780
+ )
781
+ extended_attention_mask = (
782
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
783
+ )
784
+ else:
785
+ extended_attention_mask = attention_mask[:, None, None, :]
786
+ else:
787
+ raise ValueError(
788
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
789
+ input_shape, attention_mask.shape
790
+ )
791
+ )
792
+
793
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
794
+ # masked positions, this operation will create a tensor which is 0.0 for
795
+ # positions we want to attend and -10000.0 for masked positions.
796
+ # Since we are adding it to the raw scores before the softmax, this is
797
+ # effectively the same as removing these entirely.
798
+ extended_attention_mask = extended_attention_mask.to(
799
+ dtype=self.dtype
800
+ ) # fp16 compatibility
801
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
802
+ return extended_attention_mask
803
+
804
+ def forward(
805
+ self,
806
+ input_ids=None,
807
+ attention_mask=None,
808
+ position_ids=None,
809
+ head_mask=None,
810
+ query_embeds=None,
811
+ encoder_hidden_states=None,
812
+ encoder_attention_mask=None,
813
+ past_key_values=None,
814
+ use_cache=None,
815
+ output_attentions=None,
816
+ output_hidden_states=None,
817
+ return_dict=None,
818
+ is_decoder=False,
819
+ ):
820
+ r"""
821
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
822
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
823
+ the model is configured as a decoder.
824
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
825
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
826
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
827
+ - 1 for tokens that are **not masked**,
828
+ - 0 for tokens that are **masked**.
829
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
830
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
831
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
832
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
833
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
834
+ use_cache (:obj:`bool`, `optional`):
835
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
836
+ decoding (see :obj:`past_key_values`).
837
+ """
838
+ output_attentions = (
839
+ output_attentions
840
+ if output_attentions is not None
841
+ else self.config.output_attentions
842
+ )
843
+ output_hidden_states = (
844
+ output_hidden_states
845
+ if output_hidden_states is not None
846
+ else self.config.output_hidden_states
847
+ )
848
+ return_dict = (
849
+ return_dict if return_dict is not None else self.config.use_return_dict
850
+ )
851
+
852
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
853
+
854
+ if input_ids is None:
855
+ assert (
856
+ query_embeds is not None
857
+ ), "You have to specify query_embeds when input_ids is None"
858
+
859
+ # past_key_values_length
860
+ past_key_values_length = (
861
+ past_key_values[0][0].shape[2] - self.config.query_length
862
+ if past_key_values is not None
863
+ else 0
864
+ )
865
+
866
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
867
+
868
+ embedding_output = self.embeddings(
869
+ input_ids=input_ids,
870
+ position_ids=position_ids,
871
+ query_embeds=query_embeds,
872
+ past_key_values_length=past_key_values_length,
873
+ )
874
+
875
+ input_shape = embedding_output.size()[:-1]
876
+ batch_size, seq_length = input_shape
877
+ device = embedding_output.device
878
+
879
+ if attention_mask is None:
880
+ attention_mask = torch.ones(
881
+ ((batch_size, seq_length + past_key_values_length)), device=device
882
+ )
883
+
884
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
885
+ # ourselves in which case we just need to make it broadcastable to all heads.
886
+ if is_decoder:
887
+ extended_attention_mask = self.get_extended_attention_mask(
888
+ attention_mask,
889
+ input_ids.shape,
890
+ device,
891
+ is_decoder,
892
+ has_query=(query_embeds is not None),
893
+ )
894
+ else:
895
+ extended_attention_mask = self.get_extended_attention_mask(
896
+ attention_mask, input_shape, device, is_decoder
897
+ )
898
+
899
+ # If a 2D or 3D attention mask is provided for the cross-attention
900
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
901
+ if encoder_hidden_states is not None:
902
+ if type(encoder_hidden_states) == list:
903
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
904
+ 0
905
+ ].size()
906
+ else:
907
+ (
908
+ encoder_batch_size,
909
+ encoder_sequence_length,
910
+ _,
911
+ ) = encoder_hidden_states.size()
912
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
913
+
914
+ if type(encoder_attention_mask) == list:
915
+ encoder_extended_attention_mask = [
916
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
917
+ ]
918
+ elif encoder_attention_mask is None:
919
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
920
+ encoder_extended_attention_mask = self.invert_attention_mask(
921
+ encoder_attention_mask
922
+ )
923
+ else:
924
+ encoder_extended_attention_mask = self.invert_attention_mask(
925
+ encoder_attention_mask
926
+ )
927
+ else:
928
+ encoder_extended_attention_mask = None
929
+
930
+ # Prepare head mask if needed
931
+ # 1.0 in head_mask indicate we keep the head
932
+ # attention_probs has shape bsz x n_heads x N x N
933
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
934
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
935
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
936
+
937
+ encoder_outputs = self.encoder(
938
+ embedding_output,
939
+ attention_mask=extended_attention_mask,
940
+ head_mask=head_mask,
941
+ encoder_hidden_states=encoder_hidden_states,
942
+ encoder_attention_mask=encoder_extended_attention_mask,
943
+ past_key_values=past_key_values,
944
+ use_cache=use_cache,
945
+ output_attentions=output_attentions,
946
+ output_hidden_states=output_hidden_states,
947
+ return_dict=return_dict,
948
+ query_length=query_length,
949
+ )
950
+ sequence_output = encoder_outputs[0]
951
+ pooled_output = (
952
+ self.pooler(sequence_output) if self.pooler is not None else None
953
+ )
954
+
955
+ if not return_dict:
956
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
957
+
958
+ return BaseModelOutputWithPoolingAndCrossAttentions(
959
+ last_hidden_state=sequence_output,
960
+ pooler_output=pooled_output,
961
+ past_key_values=encoder_outputs.past_key_values,
962
+ hidden_states=encoder_outputs.hidden_states,
963
+ attentions=encoder_outputs.attentions,
964
+ cross_attentions=encoder_outputs.cross_attentions,
965
+ )
966
+
967
+
968
+ class BertLMHeadModel(BertPreTrainedModel):
969
+
970
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
971
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
972
+
973
+ def __init__(self, config):
974
+ super().__init__(config)
975
+
976
+ self.bert = BertModel(config, add_pooling_layer=False)
977
+ self.cls = BertOnlyMLMHead(config)
978
+
979
+ self.init_weights()
980
+
981
+ def get_output_embeddings(self):
982
+ return self.cls.predictions.decoder
983
+
984
+ def set_output_embeddings(self, new_embeddings):
985
+ self.cls.predictions.decoder = new_embeddings
986
+
987
+ def forward(
988
+ self,
989
+ input_ids=None,
990
+ attention_mask=None,
991
+ position_ids=None,
992
+ head_mask=None,
993
+ query_embeds=None,
994
+ encoder_hidden_states=None,
995
+ encoder_attention_mask=None,
996
+ labels=None,
997
+ past_key_values=None,
998
+ use_cache=True,
999
+ output_attentions=None,
1000
+ output_hidden_states=None,
1001
+ return_dict=None,
1002
+ return_logits=False,
1003
+ is_decoder=True,
1004
+ reduction="mean",
1005
+ ):
1006
+ r"""
1007
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1008
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1009
+ the model is configured as a decoder.
1010
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1011
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1012
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1013
+ - 1 for tokens that are **not masked**,
1014
+ - 0 for tokens that are **masked**.
1015
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1016
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1017
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1018
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
1019
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1020
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1021
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1022
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1023
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1024
+ use_cache (:obj:`bool`, `optional`):
1025
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1026
+ decoding (see :obj:`past_key_values`).
1027
+ Returns:
1028
+ Example::
1029
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1030
+ >>> import torch
1031
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1032
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1033
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1034
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1035
+ >>> outputs = model(**inputs)
1036
+ >>> prediction_logits = outputs.logits
1037
+ """
1038
+ return_dict = (
1039
+ return_dict if return_dict is not None else self.config.use_return_dict
1040
+ )
1041
+ if labels is not None:
1042
+ use_cache = False
1043
+ if past_key_values is not None:
1044
+ query_embeds = None
1045
+
1046
+ outputs = self.bert(
1047
+ input_ids,
1048
+ attention_mask=attention_mask,
1049
+ position_ids=position_ids,
1050
+ head_mask=head_mask,
1051
+ query_embeds=query_embeds,
1052
+ encoder_hidden_states=encoder_hidden_states,
1053
+ encoder_attention_mask=encoder_attention_mask,
1054
+ past_key_values=past_key_values,
1055
+ use_cache=use_cache,
1056
+ output_attentions=output_attentions,
1057
+ output_hidden_states=output_hidden_states,
1058
+ return_dict=return_dict,
1059
+ is_decoder=is_decoder,
1060
+ )
1061
+
1062
+ sequence_output = outputs[0]
1063
+ if query_embeds is not None:
1064
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1065
+
1066
+ prediction_scores = self.cls(sequence_output)
1067
+
1068
+ if return_logits:
1069
+ return prediction_scores[:, :-1, :].contiguous()
1070
+
1071
+ lm_loss = None
1072
+ if labels is not None:
1073
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1074
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1075
+ labels = labels[:, 1:].contiguous()
1076
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1077
+ lm_loss = loss_fct(
1078
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1079
+ labels.view(-1),
1080
+ )
1081
+ if reduction == "none":
1082
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1083
+
1084
+ if not return_dict:
1085
+ output = (prediction_scores,) + outputs[2:]
1086
+ return ((lm_loss,) + output) if lm_loss is not None else output
1087
+
1088
+ return CausalLMOutputWithCrossAttentions(
1089
+ loss=lm_loss,
1090
+ logits=prediction_scores,
1091
+ past_key_values=outputs.past_key_values,
1092
+ hidden_states=outputs.hidden_states,
1093
+ attentions=outputs.attentions,
1094
+ cross_attentions=outputs.cross_attentions,
1095
+ )
1096
+
1097
+ def prepare_inputs_for_generation(
1098
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1099
+ ):
1100
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1101
+ if attention_mask is None:
1102
+ attention_mask = input_ids.new_ones(input_ids.shape)
1103
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1104
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1105
+
1106
+ # cut decoder_input_ids if past is used
1107
+ if past is not None:
1108
+ input_ids = input_ids[:, -1:]
1109
+
1110
+ return {
1111
+ "input_ids": input_ids,
1112
+ "query_embeds": query_embeds,
1113
+ "attention_mask": attention_mask,
1114
+ "past_key_values": past,
1115
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1116
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1117
+ "is_decoder": True,
1118
+ }
1119
+
1120
+ def _reorder_cache(self, past, beam_idx):
1121
+ reordered_past = ()
1122
+ for layer_past in past:
1123
+ reordered_past += (
1124
+ tuple(
1125
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1126
+ ),
1127
+ )
1128
+ return reordered_past
1129
+
1130
+
1131
+ class BertForMaskedLM(BertPreTrainedModel):
1132
+
1133
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1134
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1135
+
1136
+ def __init__(self, config):
1137
+ super().__init__(config)
1138
+
1139
+ self.bert = BertModel(config, add_pooling_layer=False)
1140
+ self.cls = BertOnlyMLMHead(config)
1141
+
1142
+ self.init_weights()
1143
+
1144
+ def get_output_embeddings(self):
1145
+ return self.cls.predictions.decoder
1146
+
1147
+ def set_output_embeddings(self, new_embeddings):
1148
+ self.cls.predictions.decoder = new_embeddings
1149
+
1150
+ def forward(
1151
+ self,
1152
+ input_ids=None,
1153
+ attention_mask=None,
1154
+ position_ids=None,
1155
+ head_mask=None,
1156
+ query_embeds=None,
1157
+ encoder_hidden_states=None,
1158
+ encoder_attention_mask=None,
1159
+ labels=None,
1160
+ output_attentions=None,
1161
+ output_hidden_states=None,
1162
+ return_dict=None,
1163
+ return_logits=False,
1164
+ is_decoder=False,
1165
+ ):
1166
+ r"""
1167
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1168
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1169
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1170
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1171
+ """
1172
+
1173
+ return_dict = (
1174
+ return_dict if return_dict is not None else self.config.use_return_dict
1175
+ )
1176
+
1177
+ outputs = self.bert(
1178
+ input_ids,
1179
+ attention_mask=attention_mask,
1180
+ position_ids=position_ids,
1181
+ head_mask=head_mask,
1182
+ query_embeds=query_embeds,
1183
+ encoder_hidden_states=encoder_hidden_states,
1184
+ encoder_attention_mask=encoder_attention_mask,
1185
+ output_attentions=output_attentions,
1186
+ output_hidden_states=output_hidden_states,
1187
+ return_dict=return_dict,
1188
+ is_decoder=is_decoder,
1189
+ )
1190
+
1191
+ if query_embeds is not None:
1192
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1193
+ prediction_scores = self.cls(sequence_output)
1194
+
1195
+ if return_logits:
1196
+ return prediction_scores
1197
+
1198
+ masked_lm_loss = None
1199
+ if labels is not None:
1200
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1201
+ masked_lm_loss = loss_fct(
1202
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1203
+ )
1204
+
1205
+ if not return_dict:
1206
+ output = (prediction_scores,) + outputs[2:]
1207
+ return (
1208
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1209
+ )
1210
+
1211
+ return MaskedLMOutput(
1212
+ loss=masked_lm_loss,
1213
+ logits=prediction_scores,
1214
+ hidden_states=outputs.hidden_states,
1215
+ attentions=outputs.attentions,
1216
+ )
minigpt4/models/__init__.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+
12
+ from minigpt4.common.registry import registry
13
+ from minigpt4.models.base_model import BaseModel
14
+ from minigpt4.models.blip2 import Blip2Base
15
+ from minigpt4.models.mini_gpt4 import MiniGPT4
16
+ from minigpt4.processors.base_processor import BaseProcessor
17
+
18
+
19
+ __all__ = [
20
+ "load_model",
21
+ "BaseModel",
22
+ "Blip2Base",
23
+ "MiniGPT4",
24
+ ]
25
+
26
+
27
+ def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
28
+ """
29
+ Load supported models.
30
+
31
+ To list all available models and types in registry:
32
+ >>> from minigpt4.models import model_zoo
33
+ >>> print(model_zoo)
34
+
35
+ Args:
36
+ name (str): name of the model.
37
+ model_type (str): type of the model.
38
+ is_eval (bool): whether the model is in eval mode. Default: False.
39
+ device (str): device to use. Default: "cpu".
40
+ checkpoint (str): path or to checkpoint. Default: None.
41
+ Note that expecting the checkpoint to have the same keys in state_dict as the model.
42
+
43
+ Returns:
44
+ model (torch.nn.Module): model.
45
+ """
46
+
47
+ model = registry.get_model_class(name).from_pretrained(model_type=model_type)
48
+
49
+ if checkpoint is not None:
50
+ model.load_checkpoint(checkpoint)
51
+
52
+ if is_eval:
53
+ model.eval()
54
+
55
+ if device == "cpu":
56
+ model = model.float()
57
+
58
+ return model.to(device)
59
+
60
+
61
+ def load_preprocess(config):
62
+ """
63
+ Load preprocessor configs and construct preprocessors.
64
+
65
+ If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
66
+
67
+ Args:
68
+ config (dict): preprocessor configs.
69
+
70
+ Returns:
71
+ vis_processors (dict): preprocessors for visual inputs.
72
+ txt_processors (dict): preprocessors for text inputs.
73
+
74
+ Key is "train" or "eval" for processors used in training and evaluation respectively.
75
+ """
76
+
77
+ def _build_proc_from_cfg(cfg):
78
+ return (
79
+ registry.get_processor_class(cfg.name).from_config(cfg)
80
+ if cfg is not None
81
+ else BaseProcessor()
82
+ )
83
+
84
+ vis_processors = dict()
85
+ txt_processors = dict()
86
+
87
+ vis_proc_cfg = config.get("vis_processor")
88
+ txt_proc_cfg = config.get("text_processor")
89
+
90
+ if vis_proc_cfg is not None:
91
+ vis_train_cfg = vis_proc_cfg.get("train")
92
+ vis_eval_cfg = vis_proc_cfg.get("eval")
93
+ else:
94
+ vis_train_cfg = None
95
+ vis_eval_cfg = None
96
+
97
+ vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
98
+ vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
99
+
100
+ if txt_proc_cfg is not None:
101
+ txt_train_cfg = txt_proc_cfg.get("train")
102
+ txt_eval_cfg = txt_proc_cfg.get("eval")
103
+ else:
104
+ txt_train_cfg = None
105
+ txt_eval_cfg = None
106
+
107
+ txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
108
+ txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
109
+
110
+ return vis_processors, txt_processors
111
+
112
+
113
+ def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
114
+ """
115
+ Load model and its related preprocessors.
116
+
117
+ List all available models and types in registry:
118
+ >>> from minigpt4.models import model_zoo
119
+ >>> print(model_zoo)
120
+
121
+ Args:
122
+ name (str): name of the model.
123
+ model_type (str): type of the model.
124
+ is_eval (bool): whether the model is in eval mode. Default: False.
125
+ device (str): device to use. Default: "cpu".
126
+
127
+ Returns:
128
+ model (torch.nn.Module): model.
129
+ vis_processors (dict): preprocessors for visual inputs.
130
+ txt_processors (dict): preprocessors for text inputs.
131
+ """
132
+ model_cls = registry.get_model_class(name)
133
+
134
+ # load model
135
+ model = model_cls.from_pretrained(model_type=model_type)
136
+
137
+ if is_eval:
138
+ model.eval()
139
+
140
+ # load preprocess
141
+ cfg = OmegaConf.load(model_cls.default_config_path(model_type))
142
+ if cfg is not None:
143
+ preprocess_cfg = cfg.preprocess
144
+
145
+ vis_processors, txt_processors = load_preprocess(preprocess_cfg)
146
+ else:
147
+ vis_processors, txt_processors = None, None
148
+ logging.info(
149
+ f"""No default preprocess for model {name} ({model_type}).
150
+ This can happen if the model is not finetuned on downstream datasets,
151
+ or it is not intended for direct use without finetuning.
152
+ """
153
+ )
154
+
155
+ if device == "cpu" or device == torch.device("cpu"):
156
+ model = model.float()
157
+
158
+ return model.to(device), vis_processors, txt_processors
159
+
160
+
161
+ class ModelZoo:
162
+ """
163
+ A utility class to create string representation of available model architectures and types.
164
+
165
+ >>> from minigpt4.models import model_zoo
166
+ >>> # list all available models
167
+ >>> print(model_zoo)
168
+ >>> # show total number of models
169
+ >>> print(len(model_zoo))
170
+ """
171
+
172
+ def __init__(self) -> None:
173
+ self.model_zoo = {
174
+ k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
175
+ for k, v in registry.mapping["model_name_mapping"].items()
176
+ }
177
+
178
+ def __str__(self) -> str:
179
+ return (
180
+ "=" * 50
181
+ + "\n"
182
+ + f"{'Architectures':<30} {'Types'}\n"
183
+ + "=" * 50
184
+ + "\n"
185
+ + "\n".join(
186
+ [
187
+ f"{name:<30} {', '.join(types)}"
188
+ for name, types in self.model_zoo.items()
189
+ ]
190
+ )
191
+ )
192
+
193
+ def __iter__(self):
194
+ return iter(self.model_zoo.items())
195
+
196
+ def __len__(self):
197
+ return sum([len(v) for v in self.model_zoo.values()])
198
+
199
+
200
+ model_zoo = ModelZoo()
minigpt4/models/base_model.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import os
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
15
+ from minigpt4.common.utils import get_abs_path, is_url
16
+ from omegaconf import OmegaConf
17
+
18
+
19
+ class BaseModel(nn.Module):
20
+ """Base class for models."""
21
+
22
+ def __init__(self):
23
+ super().__init__()
24
+
25
+ @property
26
+ def device(self):
27
+ return list(self.parameters())[0].device
28
+
29
+ def load_checkpoint(self, url_or_filename):
30
+ """
31
+ Load from a finetuned checkpoint.
32
+
33
+ This should expect no mismatch in the model keys and the checkpoint keys.
34
+ """
35
+
36
+ if is_url(url_or_filename):
37
+ cached_file = download_cached_file(
38
+ url_or_filename, check_hash=False, progress=True
39
+ )
40
+ checkpoint = torch.load(cached_file, map_location="cpu")
41
+ elif os.path.isfile(url_or_filename):
42
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
43
+ else:
44
+ raise RuntimeError("checkpoint url or path is invalid")
45
+
46
+ if "model" in checkpoint.keys():
47
+ state_dict = checkpoint["model"]
48
+ else:
49
+ state_dict = checkpoint
50
+
51
+ msg = self.load_state_dict(state_dict, strict=False)
52
+
53
+ logging.info("Missing keys {}".format(msg.missing_keys))
54
+ logging.info("load checkpoint from %s" % url_or_filename)
55
+
56
+ return msg
57
+
58
+ @classmethod
59
+ def from_pretrained(cls, model_type):
60
+ """
61
+ Build a pretrained model from default configuration file, specified by model_type.
62
+
63
+ Args:
64
+ - model_type (str): model type, specifying architecture and checkpoints.
65
+
66
+ Returns:
67
+ - model (nn.Module): pretrained or finetuned model, depending on the configuration.
68
+ """
69
+ model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
70
+ model = cls.from_config(model_cfg)
71
+
72
+ return model
73
+
74
+ @classmethod
75
+ def default_config_path(cls, model_type):
76
+ assert (
77
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
78
+ ), "Unknown model type {}".format(model_type)
79
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
80
+
81
+ def load_checkpoint_from_config(self, cfg, **kwargs):
82
+ """
83
+ Load checkpoint as specified in the config file.
84
+
85
+ If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
86
+ When loading the pretrained model, each task-specific architecture may define their
87
+ own load_from_pretrained() method.
88
+ """
89
+ load_finetuned = cfg.get("load_finetuned", True)
90
+ if load_finetuned:
91
+ finetune_path = cfg.get("finetuned", None)
92
+ assert (
93
+ finetune_path is not None
94
+ ), "Found load_finetuned is True, but finetune_path is None."
95
+ self.load_checkpoint(url_or_filename=finetune_path)
96
+ else:
97
+ # load pre-trained weights
98
+ pretrain_path = cfg.get("pretrained", None)
99
+ assert "Found load_finetuned is False, but pretrain_path is None."
100
+ self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
101
+
102
+ def before_evaluation(self, **kwargs):
103
+ pass
104
+
105
+ def show_n_params(self, return_str=True):
106
+ tot = 0
107
+ for p in self.parameters():
108
+ w = 1
109
+ for x in p.shape:
110
+ w *= x
111
+ tot += w
112
+ if return_str:
113
+ if tot >= 1e6:
114
+ return "{:.1f}M".format(tot / 1e6)
115
+ else:
116
+ return "{:.1f}K".format(tot / 1e3)
117
+ else:
118
+ return tot
119
+
120
+
121
+ class BaseEncoder(nn.Module):
122
+ """
123
+ Base class for primitive encoders, such as ViT, TimeSformer, etc.
124
+ """
125
+
126
+ def __init__(self):
127
+ super().__init__()
128
+
129
+ def forward_features(self, samples, **kwargs):
130
+ raise NotImplementedError
131
+
132
+ @property
133
+ def device(self):
134
+ return list(self.parameters())[0].device
135
+
136
+
137
+ class SharedQueueMixin:
138
+ @torch.no_grad()
139
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
140
+ # gather keys before updating queue
141
+ image_feats = concat_all_gather(image_feat)
142
+ text_feats = concat_all_gather(text_feat)
143
+
144
+ batch_size = image_feats.shape[0]
145
+
146
+ ptr = int(self.queue_ptr)
147
+ assert self.queue_size % batch_size == 0 # for simplicity
148
+
149
+ # replace the keys at ptr (dequeue and enqueue)
150
+ self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
151
+ self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
152
+
153
+ if idxs is not None:
154
+ idxs = concat_all_gather(idxs)
155
+ self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
156
+
157
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
158
+ self.queue_ptr[0] = ptr
159
+
160
+
161
+ class MomentumDistilationMixin:
162
+ @torch.no_grad()
163
+ def copy_params(self):
164
+ for model_pair in self.model_pairs:
165
+ for param, param_m in zip(
166
+ model_pair[0].parameters(), model_pair[1].parameters()
167
+ ):
168
+ param_m.data.copy_(param.data) # initialize
169
+ param_m.requires_grad = False # not update by gradient
170
+
171
+ @torch.no_grad()
172
+ def _momentum_update(self):
173
+ for model_pair in self.model_pairs:
174
+ for param, param_m in zip(
175
+ model_pair[0].parameters(), model_pair[1].parameters()
176
+ ):
177
+ param_m.data = param_m.data * self.momentum + param.data * (
178
+ 1.0 - self.momentum
179
+ )
180
+
181
+
182
+ class GatherLayer(torch.autograd.Function):
183
+ """
184
+ Gather tensors from all workers with support for backward propagation:
185
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
186
+ """
187
+
188
+ @staticmethod
189
+ def forward(ctx, x):
190
+ output = [
191
+ torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
192
+ ]
193
+ torch.distributed.all_gather(output, x)
194
+ return tuple(output)
195
+
196
+ @staticmethod
197
+ def backward(ctx, *grads):
198
+ all_gradients = torch.stack(grads)
199
+ torch.distributed.all_reduce(all_gradients)
200
+ return all_gradients[torch.distributed.get_rank()]
201
+
202
+
203
+ def all_gather_with_grad(tensors):
204
+ """
205
+ Performs all_gather operation on the provided tensors.
206
+ Graph remains connected for backward grad computation.
207
+ """
208
+ # Queue the gathered tensors
209
+ world_size = torch.distributed.get_world_size()
210
+ # There is no need for reduction in the single-proc case
211
+ if world_size == 1:
212
+ return tensors
213
+
214
+ # tensor_all = GatherLayer.apply(tensors)
215
+ tensor_all = GatherLayer.apply(tensors)
216
+
217
+ return torch.cat(tensor_all, dim=0)
218
+
219
+
220
+ @torch.no_grad()
221
+ def concat_all_gather(tensor):
222
+ """
223
+ Performs all_gather operation on the provided tensors.
224
+ *** Warning ***: torch.distributed.all_gather has no gradient.
225
+ """
226
+ # if use distributed training
227
+ if not is_dist_avail_and_initialized():
228
+ return tensor
229
+
230
+ tensors_gather = [
231
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
232
+ ]
233
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
234
+
235
+ output = torch.cat(tensors_gather, dim=0)
236
+ return output
237
+
238
+
239
+ def tile(x, dim, n_tile):
240
+ init_dim = x.size(dim)
241
+ repeat_idx = [1] * x.dim()
242
+ repeat_idx[dim] = n_tile
243
+ x = x.repeat(*(repeat_idx))
244
+ order_index = torch.LongTensor(
245
+ np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
246
+ )
247
+ return torch.index_select(x, dim, order_index.to(x.device))
minigpt4/models/blip2.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import contextlib
8
+ import logging
9
+ import os
10
+ import time
11
+ import datetime
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.distributed as dist
16
+ import torch.nn.functional as F
17
+
18
+ import minigpt4.common.dist_utils as dist_utils
19
+ from minigpt4.common.dist_utils import download_cached_file
20
+ from minigpt4.common.utils import is_url
21
+ from minigpt4.common.logger import MetricLogger
22
+ from minigpt4.models.base_model import BaseModel
23
+ from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
24
+ from minigpt4.models.eva_vit import create_eva_vit_g
25
+ from transformers import BertTokenizer
26
+
27
+
28
+ class Blip2Base(BaseModel):
29
+ @classmethod
30
+ def init_tokenizer(cls):
31
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
32
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
33
+ return tokenizer
34
+
35
+ def maybe_autocast(self, dtype=torch.float16):
36
+ # if on cpu, don't use autocast
37
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
38
+ enable_autocast = self.device != torch.device("cpu")
39
+
40
+ if enable_autocast:
41
+ return torch.cuda.amp.autocast(dtype=dtype)
42
+ else:
43
+ return contextlib.nullcontext()
44
+
45
+ @classmethod
46
+ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
47
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
48
+ encoder_config.encoder_width = vision_width
49
+ # insert cross-attention layer every other block
50
+ encoder_config.add_cross_attention = True
51
+ encoder_config.cross_attention_freq = cross_attention_freq
52
+ encoder_config.query_length = num_query_token
53
+ Qformer = BertLMHeadModel(config=encoder_config)
54
+ query_tokens = nn.Parameter(
55
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
56
+ )
57
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
58
+ return Qformer, query_tokens
59
+
60
+ @classmethod
61
+ def init_vision_encoder(
62
+ cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
63
+ ):
64
+ assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
65
+ visual_encoder = create_eva_vit_g(
66
+ img_size, drop_path_rate, use_grad_checkpoint, precision
67
+ )
68
+
69
+ ln_vision = LayerNorm(visual_encoder.num_features)
70
+ return visual_encoder, ln_vision
71
+
72
+ def load_from_pretrained(self, url_or_filename):
73
+ if is_url(url_or_filename):
74
+ cached_file = download_cached_file(
75
+ url_or_filename, check_hash=False, progress=True
76
+ )
77
+ checkpoint = torch.load(cached_file, map_location="cpu")
78
+ elif os.path.isfile(url_or_filename):
79
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
80
+ else:
81
+ raise RuntimeError("checkpoint url or path is invalid")
82
+
83
+ state_dict = checkpoint["model"]
84
+
85
+ msg = self.load_state_dict(state_dict, strict=False)
86
+
87
+ # logging.info("Missing keys {}".format(msg.missing_keys))
88
+ logging.info("load checkpoint from %s" % url_or_filename)
89
+
90
+ return msg
91
+
92
+
93
+ def disabled_train(self, mode=True):
94
+ """Overwrite model.train with this function to make sure train/eval mode
95
+ does not change anymore."""
96
+ return self
97
+
98
+
99
+ class LayerNorm(nn.LayerNorm):
100
+ """Subclass torch's LayerNorm to handle fp16."""
101
+
102
+ def forward(self, x: torch.Tensor):
103
+ orig_type = x.dtype
104
+ ret = super().forward(x.type(torch.float32))
105
+ return ret.type(orig_type)
106
+
107
+
108
+ def compute_sim_matrix(model, data_loader, **kwargs):
109
+ k_test = kwargs.pop("k_test")
110
+
111
+ metric_logger = MetricLogger(delimiter=" ")
112
+ header = "Evaluation:"
113
+
114
+ logging.info("Computing features for evaluation...")
115
+ start_time = time.time()
116
+
117
+ texts = data_loader.dataset.text
118
+ num_text = len(texts)
119
+ text_bs = 256
120
+ text_ids = []
121
+ text_embeds = []
122
+ text_atts = []
123
+ for i in range(0, num_text, text_bs):
124
+ text = texts[i : min(num_text, i + text_bs)]
125
+ text_input = model.tokenizer(
126
+ text,
127
+ padding="max_length",
128
+ truncation=True,
129
+ max_length=35,
130
+ return_tensors="pt",
131
+ ).to(model.device)
132
+ text_feat = model.forward_text(text_input)
133
+ text_embed = F.normalize(model.text_proj(text_feat))
134
+ text_embeds.append(text_embed)
135
+ text_ids.append(text_input.input_ids)
136
+ text_atts.append(text_input.attention_mask)
137
+
138
+ text_embeds = torch.cat(text_embeds, dim=0)
139
+ text_ids = torch.cat(text_ids, dim=0)
140
+ text_atts = torch.cat(text_atts, dim=0)
141
+
142
+ vit_feats = []
143
+ image_embeds = []
144
+ for samples in data_loader:
145
+ image = samples["image"]
146
+
147
+ image = image.to(model.device)
148
+ image_feat, vit_feat = model.forward_image(image)
149
+ image_embed = model.vision_proj(image_feat)
150
+ image_embed = F.normalize(image_embed, dim=-1)
151
+
152
+ vit_feats.append(vit_feat.cpu())
153
+ image_embeds.append(image_embed)
154
+
155
+ vit_feats = torch.cat(vit_feats, dim=0)
156
+ image_embeds = torch.cat(image_embeds, dim=0)
157
+
158
+ sims_matrix = []
159
+ for image_embed in image_embeds:
160
+ sim_q2t = image_embed @ text_embeds.t()
161
+ sim_i2t, _ = sim_q2t.max(0)
162
+ sims_matrix.append(sim_i2t)
163
+ sims_matrix = torch.stack(sims_matrix, dim=0)
164
+
165
+ score_matrix_i2t = torch.full(
166
+ (len(data_loader.dataset.image), len(texts)), -100.0
167
+ ).to(model.device)
168
+
169
+ num_tasks = dist_utils.get_world_size()
170
+ rank = dist_utils.get_rank()
171
+ step = sims_matrix.size(0) // num_tasks + 1
172
+ start = rank * step
173
+ end = min(sims_matrix.size(0), start + step)
174
+
175
+ for i, sims in enumerate(
176
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
177
+ ):
178
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
179
+ image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
180
+ score = model.compute_itm(
181
+ image_inputs=image_inputs,
182
+ text_ids=text_ids[topk_idx],
183
+ text_atts=text_atts[topk_idx],
184
+ ).float()
185
+ score_matrix_i2t[start + i, topk_idx] = score + topk_sim
186
+
187
+ sims_matrix = sims_matrix.t()
188
+ score_matrix_t2i = torch.full(
189
+ (len(texts), len(data_loader.dataset.image)), -100.0
190
+ ).to(model.device)
191
+
192
+ step = sims_matrix.size(0) // num_tasks + 1
193
+ start = rank * step
194
+ end = min(sims_matrix.size(0), start + step)
195
+
196
+ for i, sims in enumerate(
197
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
198
+ ):
199
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
200
+ image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
201
+ score = model.compute_itm(
202
+ image_inputs=image_inputs,
203
+ text_ids=text_ids[start + i].repeat(k_test, 1),
204
+ text_atts=text_atts[start + i].repeat(k_test, 1),
205
+ ).float()
206
+ score_matrix_t2i[start + i, topk_idx] = score + topk_sim
207
+
208
+ if dist_utils.is_dist_avail_and_initialized():
209
+ dist.barrier()
210
+ torch.distributed.all_reduce(
211
+ score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
212
+ )
213
+ torch.distributed.all_reduce(
214
+ score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
215
+ )
216
+
217
+ total_time = time.time() - start_time
218
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
219
+ logging.info("Evaluation time {}".format(total_time_str))
220
+
221
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
minigpt4/models/blip2_outputs.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ import torch
12
+ from transformers.modeling_outputs import (
13
+ ModelOutput,
14
+ BaseModelOutputWithPoolingAndCrossAttentions,
15
+ CausalLMOutputWithCrossAttentions,
16
+ )
17
+
18
+
19
+ @dataclass
20
+ class BlipSimilarity(ModelOutput):
21
+ sim_i2t: torch.FloatTensor = None
22
+ sim_t2i: torch.FloatTensor = None
23
+
24
+ sim_i2t_m: Optional[torch.FloatTensor] = None
25
+ sim_t2i_m: Optional[torch.FloatTensor] = None
26
+
27
+ sim_i2t_targets: Optional[torch.FloatTensor] = None
28
+ sim_t2i_targets: Optional[torch.FloatTensor] = None
29
+
30
+
31
+ @dataclass
32
+ class BlipIntermediateOutput(ModelOutput):
33
+ """
34
+ Data class for intermediate outputs of BLIP models.
35
+
36
+ image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
37
+ text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
38
+
39
+ image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
40
+ text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
41
+
42
+ encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
43
+ encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
44
+
45
+ decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
46
+ decoder_labels (torch.LongTensor): labels for the captioning loss.
47
+
48
+ itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
49
+ itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
50
+
51
+ """
52
+
53
+ # uni-modal features
54
+ image_embeds: torch.FloatTensor = None
55
+ text_embeds: Optional[torch.FloatTensor] = None
56
+
57
+ image_embeds_m: Optional[torch.FloatTensor] = None
58
+ text_embeds_m: Optional[torch.FloatTensor] = None
59
+
60
+ # intermediate outputs of multimodal encoder
61
+ encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
62
+ encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
63
+
64
+ itm_logits: Optional[torch.FloatTensor] = None
65
+ itm_labels: Optional[torch.LongTensor] = None
66
+
67
+ # intermediate outputs of multimodal decoder
68
+ decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
69
+ decoder_labels: Optional[torch.LongTensor] = None
70
+
71
+
72
+ @dataclass
73
+ class BlipOutput(ModelOutput):
74
+ # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
75
+ sims: Optional[BlipSimilarity] = None
76
+
77
+ intermediate_output: BlipIntermediateOutput = None
78
+
79
+ loss: Optional[torch.FloatTensor] = None
80
+
81
+ loss_itc: Optional[torch.FloatTensor] = None
82
+
83
+ loss_itm: Optional[torch.FloatTensor] = None
84
+
85
+ loss_lm: Optional[torch.FloatTensor] = None
86
+
87
+
88
+ @dataclass
89
+ class BlipOutputFeatures(ModelOutput):
90
+ """
91
+ Data class of features from BlipFeatureExtractor.
92
+
93
+ Args:
94
+ image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
95
+ image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
96
+ text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
97
+ text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
98
+
99
+ The first embedding or feature is for the [CLS] token.
100
+
101
+ Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
102
+ """
103
+
104
+ image_embeds: Optional[torch.FloatTensor] = None
105
+ image_embeds_proj: Optional[torch.FloatTensor] = None
106
+
107
+ text_embeds: Optional[torch.FloatTensor] = None
108
+ text_embeds_proj: Optional[torch.FloatTensor] = None
109
+
110
+ multimodal_embeds: Optional[torch.FloatTensor] = None
minigpt4/models/eva_vit.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Based on EVA, BEIT, timm and DeiT code bases
2
+ # https://github.com/baaivision/EVA
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
+ # https://github.com/microsoft/unilm/tree/master/beit
5
+ # https://github.com/facebookresearch/deit/
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ import math
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+ from timm.models.registry import register_model
17
+
18
+ from minigpt4.common.dist_utils import download_cached_file
19
+
20
+ def _cfg(url='', **kwargs):
21
+ return {
22
+ 'url': url,
23
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
24
+ 'crop_pct': .9, 'interpolation': 'bicubic',
25
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
26
+ **kwargs
27
+ }
28
+
29
+
30
+ class DropPath(nn.Module):
31
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
32
+ """
33
+ def __init__(self, drop_prob=None):
34
+ super(DropPath, self).__init__()
35
+ self.drop_prob = drop_prob
36
+
37
+ def forward(self, x):
38
+ return drop_path(x, self.drop_prob, self.training)
39
+
40
+ def extra_repr(self) -> str:
41
+ return 'p={}'.format(self.drop_prob)
42
+
43
+
44
+ class Mlp(nn.Module):
45
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
46
+ super().__init__()
47
+ out_features = out_features or in_features
48
+ hidden_features = hidden_features or in_features
49
+ self.fc1 = nn.Linear(in_features, hidden_features)
50
+ self.act = act_layer()
51
+ self.fc2 = nn.Linear(hidden_features, out_features)
52
+ self.drop = nn.Dropout(drop)
53
+
54
+ def forward(self, x):
55
+ x = self.fc1(x)
56
+ x = self.act(x)
57
+ # x = self.drop(x)
58
+ # commit this for the orignal BERT implement
59
+ x = self.fc2(x)
60
+ x = self.drop(x)
61
+ return x
62
+
63
+
64
+ class Attention(nn.Module):
65
+ def __init__(
66
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
67
+ proj_drop=0., window_size=None, attn_head_dim=None):
68
+ super().__init__()
69
+ self.num_heads = num_heads
70
+ head_dim = dim // num_heads
71
+ if attn_head_dim is not None:
72
+ head_dim = attn_head_dim
73
+ all_head_dim = head_dim * self.num_heads
74
+ self.scale = qk_scale or head_dim ** -0.5
75
+
76
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
77
+ if qkv_bias:
78
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
79
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
80
+ else:
81
+ self.q_bias = None
82
+ self.v_bias = None
83
+
84
+ if window_size:
85
+ self.window_size = window_size
86
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
87
+ self.relative_position_bias_table = nn.Parameter(
88
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
89
+ # cls to token & token 2 cls & cls to cls
90
+
91
+ # get pair-wise relative position index for each token inside the window
92
+ coords_h = torch.arange(window_size[0])
93
+ coords_w = torch.arange(window_size[1])
94
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
95
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
96
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
97
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
98
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
99
+ relative_coords[:, :, 1] += window_size[1] - 1
100
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
101
+ relative_position_index = \
102
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
103
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
104
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
105
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
106
+ relative_position_index[0, 0] = self.num_relative_distance - 1
107
+
108
+ self.register_buffer("relative_position_index", relative_position_index)
109
+ else:
110
+ self.window_size = None
111
+ self.relative_position_bias_table = None
112
+ self.relative_position_index = None
113
+
114
+ self.attn_drop = nn.Dropout(attn_drop)
115
+ self.proj = nn.Linear(all_head_dim, dim)
116
+ self.proj_drop = nn.Dropout(proj_drop)
117
+
118
+ def forward(self, x, rel_pos_bias=None):
119
+ B, N, C = x.shape
120
+ qkv_bias = None
121
+ if self.q_bias is not None:
122
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
123
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
125
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
126
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
127
+
128
+ q = q * self.scale
129
+ attn = (q @ k.transpose(-2, -1))
130
+
131
+ if self.relative_position_bias_table is not None:
132
+ relative_position_bias = \
133
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
134
+ self.window_size[0] * self.window_size[1] + 1,
135
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
136
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
137
+ attn = attn + relative_position_bias.unsqueeze(0)
138
+
139
+ if rel_pos_bias is not None:
140
+ attn = attn + rel_pos_bias
141
+
142
+ attn = attn.softmax(dim=-1)
143
+ attn = self.attn_drop(attn)
144
+
145
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
146
+ x = self.proj(x)
147
+ x = self.proj_drop(x)
148
+ return x
149
+
150
+
151
+ class Block(nn.Module):
152
+
153
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
154
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
155
+ window_size=None, attn_head_dim=None):
156
+ super().__init__()
157
+ self.norm1 = norm_layer(dim)
158
+ self.attn = Attention(
159
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
160
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
161
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
162
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163
+ self.norm2 = norm_layer(dim)
164
+ mlp_hidden_dim = int(dim * mlp_ratio)
165
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
166
+
167
+ if init_values is not None and init_values > 0:
168
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
169
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
170
+ else:
171
+ self.gamma_1, self.gamma_2 = None, None
172
+
173
+ def forward(self, x, rel_pos_bias=None):
174
+ if self.gamma_1 is None:
175
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
176
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
177
+ else:
178
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
179
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
180
+ return x
181
+
182
+
183
+ class PatchEmbed(nn.Module):
184
+ """ Image to Patch Embedding
185
+ """
186
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
187
+ super().__init__()
188
+ img_size = to_2tuple(img_size)
189
+ patch_size = to_2tuple(patch_size)
190
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
191
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
192
+ self.img_size = img_size
193
+ self.patch_size = patch_size
194
+ self.num_patches = num_patches
195
+
196
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
197
+
198
+ def forward(self, x, **kwargs):
199
+ B, C, H, W = x.shape
200
+ # FIXME look at relaxing size constraints
201
+ assert H == self.img_size[0] and W == self.img_size[1], \
202
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
203
+ x = self.proj(x).flatten(2).transpose(1, 2)
204
+ return x
205
+
206
+
207
+ class RelativePositionBias(nn.Module):
208
+
209
+ def __init__(self, window_size, num_heads):
210
+ super().__init__()
211
+ self.window_size = window_size
212
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
213
+ self.relative_position_bias_table = nn.Parameter(
214
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
215
+ # cls to token & token 2 cls & cls to cls
216
+
217
+ # get pair-wise relative position index for each token inside the window
218
+ coords_h = torch.arange(window_size[0])
219
+ coords_w = torch.arange(window_size[1])
220
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
221
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
222
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
223
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
224
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
225
+ relative_coords[:, :, 1] += window_size[1] - 1
226
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
227
+ relative_position_index = \
228
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
229
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
230
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
231
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
232
+ relative_position_index[0, 0] = self.num_relative_distance - 1
233
+
234
+ self.register_buffer("relative_position_index", relative_position_index)
235
+
236
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
237
+
238
+ def forward(self):
239
+ relative_position_bias = \
240
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
241
+ self.window_size[0] * self.window_size[1] + 1,
242
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
243
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
244
+
245
+
246
+ class VisionTransformer(nn.Module):
247
+ """ Vision Transformer with support for patch or hybrid CNN input stage
248
+ """
249
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
250
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
251
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
252
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
253
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
254
+ super().__init__()
255
+ self.image_size = img_size
256
+ self.num_classes = num_classes
257
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
258
+
259
+ self.patch_embed = PatchEmbed(
260
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
261
+ num_patches = self.patch_embed.num_patches
262
+
263
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
264
+ if use_abs_pos_emb:
265
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
266
+ else:
267
+ self.pos_embed = None
268
+ self.pos_drop = nn.Dropout(p=drop_rate)
269
+
270
+ if use_shared_rel_pos_bias:
271
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
272
+ else:
273
+ self.rel_pos_bias = None
274
+ self.use_checkpoint = use_checkpoint
275
+
276
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
277
+ self.use_rel_pos_bias = use_rel_pos_bias
278
+ self.blocks = nn.ModuleList([
279
+ Block(
280
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
281
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
282
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
283
+ for i in range(depth)])
284
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
285
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
286
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
287
+
288
+ if self.pos_embed is not None:
289
+ trunc_normal_(self.pos_embed, std=.02)
290
+ trunc_normal_(self.cls_token, std=.02)
291
+ # trunc_normal_(self.mask_token, std=.02)
292
+ # if isinstance(self.head, nn.Linear):
293
+ # trunc_normal_(self.head.weight, std=.02)
294
+ self.apply(self._init_weights)
295
+ self.fix_init_weight()
296
+ # if isinstance(self.head, nn.Linear):
297
+ # self.head.weight.data.mul_(init_scale)
298
+ # self.head.bias.data.mul_(init_scale)
299
+
300
+ def fix_init_weight(self):
301
+ def rescale(param, layer_id):
302
+ param.div_(math.sqrt(2.0 * layer_id))
303
+
304
+ for layer_id, layer in enumerate(self.blocks):
305
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
306
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
307
+
308
+ def _init_weights(self, m):
309
+ if isinstance(m, nn.Linear):
310
+ trunc_normal_(m.weight, std=.02)
311
+ if isinstance(m, nn.Linear) and m.bias is not None:
312
+ nn.init.constant_(m.bias, 0)
313
+ elif isinstance(m, nn.LayerNorm):
314
+ nn.init.constant_(m.bias, 0)
315
+ nn.init.constant_(m.weight, 1.0)
316
+
317
+ def get_classifier(self):
318
+ return self.head
319
+
320
+ def reset_classifier(self, num_classes, global_pool=''):
321
+ self.num_classes = num_classes
322
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
323
+
324
+ def forward_features(self, x):
325
+ x = self.patch_embed(x)
326
+ batch_size, seq_len, _ = x.size()
327
+
328
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
329
+ x = torch.cat((cls_tokens, x), dim=1)
330
+ if self.pos_embed is not None:
331
+ x = x + self.pos_embed
332
+ x = self.pos_drop(x)
333
+
334
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
335
+ for blk in self.blocks:
336
+ if self.use_checkpoint:
337
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
338
+ else:
339
+ x = blk(x, rel_pos_bias)
340
+ return x
341
+ # x = self.norm(x)
342
+
343
+ # if self.fc_norm is not None:
344
+ # t = x[:, 1:, :]
345
+ # return self.fc_norm(t.mean(1))
346
+ # else:
347
+ # return x[:, 0]
348
+
349
+ def forward(self, x):
350
+ x = self.forward_features(x)
351
+ # x = self.head(x)
352
+ return x
353
+
354
+ def get_intermediate_layers(self, x):
355
+ x = self.patch_embed(x)
356
+ batch_size, seq_len, _ = x.size()
357
+
358
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
359
+ x = torch.cat((cls_tokens, x), dim=1)
360
+ if self.pos_embed is not None:
361
+ x = x + self.pos_embed
362
+ x = self.pos_drop(x)
363
+
364
+ features = []
365
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
366
+ for blk in self.blocks:
367
+ x = blk(x, rel_pos_bias)
368
+ features.append(x)
369
+
370
+ return features
371
+
372
+
373
+ def interpolate_pos_embed(model, checkpoint_model):
374
+ if 'pos_embed' in checkpoint_model:
375
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
376
+ embedding_size = pos_embed_checkpoint.shape[-1]
377
+ num_patches = model.patch_embed.num_patches
378
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
379
+ # height (== width) for the checkpoint position embedding
380
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
381
+ # height (== width) for the new position embedding
382
+ new_size = int(num_patches ** 0.5)
383
+ # class_token and dist_token are kept unchanged
384
+ if orig_size != new_size:
385
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
386
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
387
+ # only the position tokens are interpolated
388
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
389
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
390
+ pos_tokens = torch.nn.functional.interpolate(
391
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
392
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
393
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
394
+ checkpoint_model['pos_embed'] = new_pos_embed
395
+
396
+
397
+ def convert_weights_to_fp16(model: nn.Module):
398
+ """Convert applicable model parameters to fp16"""
399
+
400
+ def _convert_weights_to_fp16(l):
401
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
402
+ l.weight.data = l.weight.data.half()
403
+ if l.bias is not None:
404
+ l.bias.data = l.bias.data.half()
405
+
406
+ # if isinstance(l, (nn.MultiheadAttention, Attention)):
407
+ # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
408
+ # tensor = getattr(l, attr)
409
+ # if tensor is not None:
410
+ # tensor.data = tensor.data.half()
411
+
412
+ model.apply(_convert_weights_to_fp16)
413
+
414
+
415
+ def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
416
+ model = VisionTransformer(
417
+ img_size=img_size,
418
+ patch_size=14,
419
+ use_mean_pooling=False,
420
+ embed_dim=1408,
421
+ depth=39,
422
+ num_heads=1408//88,
423
+ mlp_ratio=4.3637,
424
+ qkv_bias=True,
425
+ drop_path_rate=drop_path_rate,
426
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
427
+ use_checkpoint=use_checkpoint,
428
+ )
429
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
430
+ cached_file = download_cached_file(
431
+ url, check_hash=False, progress=True
432
+ )
433
+ state_dict = torch.load(cached_file, map_location="cpu")
434
+ interpolate_pos_embed(model,state_dict)
435
+
436
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
437
+ # print(incompatible_keys)
438
+
439
+ if precision == "fp16":
440
+ # model.to("cuda")
441
+ convert_weights_to_fp16(model)
442
+ return model
minigpt4/models/mini_gpt4.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import logging
8
+ import random
9
+ import os
10
+ import torch
11
+ from torch.cuda.amp import autocast as autocast
12
+ import torch.nn as nn
13
+
14
+ from minigpt4.common.registry import registry
15
+ from minigpt4.models.blip2 import Blip2Base, disabled_train
16
+ from minigpt4.models.modeling_llama import LlamaForCausalLM
17
+ from transformers import LlamaTokenizer
18
+
19
+
20
+ @registry.register_model("mini_gpt4")
21
+ class MiniGPT4(Blip2Base):
22
+ """
23
+ BLIP2 GPT-LLAMA model.
24
+ """
25
+
26
+ PRETRAINED_MODEL_CONFIG_DICT = {
27
+ "pretrain_vicuna": "configs/models/minigpt4.yaml",
28
+ }
29
+
30
+ def __init__(
31
+ self,
32
+ vit_model="eva_clip_g",
33
+ q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
34
+ img_size=224,
35
+ drop_path_rate=0,
36
+ use_grad_checkpoint=False,
37
+ vit_precision="fp16",
38
+ freeze_vit=True,
39
+ freeze_qformer=True,
40
+ num_query_token=32,
41
+ llama_model="",
42
+ llama_cache_dir='',
43
+ prompt_path="",
44
+ prompt_template="",
45
+ max_txt_len=32,
46
+ end_sym='\n',
47
+ ):
48
+ super().__init__()
49
+
50
+ self.tokenizer = self.init_tokenizer()
51
+
52
+ print('Loading VIT')
53
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
54
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
55
+ )
56
+ if freeze_vit:
57
+ for name, param in self.visual_encoder.named_parameters():
58
+ param.requires_grad = False
59
+ self.visual_encoder = self.visual_encoder.eval()
60
+ self.visual_encoder.train = disabled_train
61
+ for name, param in self.ln_vision.named_parameters():
62
+ param.requires_grad = False
63
+ self.ln_vision = self.ln_vision.eval()
64
+ self.ln_vision.train = disabled_train
65
+ logging.info("freeze vision encoder")
66
+ print('Loading VIT Done')
67
+
68
+ print('Loading Q-Former')
69
+ self.Qformer, self.query_tokens = self.init_Qformer(
70
+ num_query_token, self.visual_encoder.num_features
71
+ )
72
+ self.Qformer.cls = None
73
+ self.Qformer.bert.embeddings.word_embeddings = None
74
+ self.Qformer.bert.embeddings.position_embeddings = None
75
+ for layer in self.Qformer.bert.encoder.layer:
76
+ layer.output = None
77
+ layer.intermediate = None
78
+ self.load_from_pretrained(url_or_filename=q_former_model)
79
+
80
+ if freeze_qformer:
81
+ for name, param in self.Qformer.named_parameters():
82
+ param.requires_grad = False
83
+ self.Qformer = self.Qformer.eval()
84
+ self.Qformer.train = disabled_train
85
+ self.query_tokens.requires_grad = False
86
+ logging.info("freeze Qformer")
87
+ print('Loading Q-Former Done')
88
+
89
+ print('Loading LLAMA')
90
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained('Vision-CAIR/vicuna', use_fast=False, use_auth_token=os.environ["API_TOKEN"])
91
+ self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
92
+
93
+ if llama_cache_dir:
94
+ self.llama_model = LlamaForCausalLM.from_pretrained(
95
+ 'Vision-CAIR/vicuna', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=os.environ["API_TOKEN"]
96
+ )
97
+ else:
98
+ self.llama_model = LlamaForCausalLM.from_pretrained(
99
+ 'Vision-CAIR/vicuna', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=os.environ["API_TOKEN"]
100
+ )
101
+ for name, param in self.llama_model.named_parameters():
102
+ param.requires_grad = False
103
+ print('Loading LLAMA Done')
104
+
105
+ self.llama_proj = nn.Linear(
106
+ self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
107
+ )
108
+ self.max_txt_len = max_txt_len
109
+ self.end_sym = end_sym
110
+
111
+ if prompt_path:
112
+ with open(prompt_path, 'r') as f:
113
+ raw_prompts = f.read().splitlines()
114
+ filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
115
+ self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
116
+ print('Load {} training prompts'.format(len(self.prompt_list)))
117
+ print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
118
+ else:
119
+ self.prompt_list = []
120
+
121
+ def vit_to_cpu(self):
122
+ self.ln_vision.to("cpu")
123
+ self.ln_vision.float()
124
+ self.visual_encoder.to("cpu")
125
+ self.visual_encoder.float()
126
+
127
+ def encode_img(self, image):
128
+ device = image.device
129
+ self.vit_to_cpu()
130
+ image = image.to("cpu")
131
+ with self.maybe_autocast():
132
+ image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
133
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
134
+
135
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
136
+ query_output = self.Qformer.bert(
137
+ query_embeds=query_tokens,
138
+ encoder_hidden_states=image_embeds,
139
+ encoder_attention_mask=image_atts,
140
+ return_dict=True,
141
+ )
142
+
143
+ inputs_llama = self.llama_proj(query_output.last_hidden_state)
144
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
145
+ return inputs_llama, atts_llama
146
+
147
+ def prompt_wrap(self, img_embeds, atts_img, prompt):
148
+ if prompt:
149
+ batch_size = img_embeds.shape[0]
150
+ p_before, p_after = prompt.split('<ImageHere>')
151
+ p_before_tokens = self.llama_tokenizer(
152
+ p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
153
+ p_after_tokens = self.llama_tokenizer(
154
+ p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
155
+ p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
156
+ p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
157
+ wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
158
+ wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
159
+ return wrapped_img_embeds, wrapped_atts_img
160
+ else:
161
+ return img_embeds, atts_img
162
+
163
+ def forward(self, samples):
164
+ image = samples["image"]
165
+ img_embeds, atts_img = self.encode_img(image)
166
+ if hasattr(samples, 'question_split'): # VQA dataset
167
+ print('VQA Batch')
168
+ vqa_prompt = '###Human: <Img><ImageHere></Img> '
169
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
170
+ elif self.prompt_list:
171
+ prompt = random.choice(self.prompt_list)
172
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
173
+
174
+ self.llama_tokenizer.padding_side = "right"
175
+
176
+ text = [t + self.end_sym for t in samples["text_input"]]
177
+
178
+ to_regress_tokens = self.llama_tokenizer(
179
+ text,
180
+ return_tensors="pt",
181
+ padding="longest",
182
+ truncation=True,
183
+ max_length=self.max_txt_len,
184
+ add_special_tokens=False
185
+ ).to(image.device)
186
+
187
+ targets = to_regress_tokens.input_ids.masked_fill(
188
+ to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
189
+ )
190
+
191
+ empty_targets = (
192
+ torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
193
+ dtype=torch.long).to(image.device).fill_(-100) # plus one for bos
194
+ )
195
+ targets = torch.cat([empty_targets, targets], dim=1)
196
+
197
+ batch_size = img_embeds.shape[0]
198
+ bos = torch.ones([batch_size, 1],
199
+ dtype=to_regress_tokens.input_ids.dtype,
200
+ device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
201
+ bos_embeds = self.llama_model.model.embed_tokens(bos)
202
+ atts_bos = atts_img[:, :1]
203
+
204
+ to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
205
+ inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
206
+ attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
207
+
208
+ with self.maybe_autocast():
209
+ outputs = self.llama_model(
210
+ inputs_embeds=inputs_embeds,
211
+ attention_mask=attention_mask,
212
+ return_dict=True,
213
+ labels=targets,
214
+ )
215
+ loss = outputs.loss
216
+
217
+ return {"loss": loss}
218
+
219
+ @classmethod
220
+ def from_config(cls, cfg):
221
+ vit_model = cfg.get("vit_model", "eva_clip_g")
222
+ q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
223
+ img_size = cfg.get("image_size")
224
+ num_query_token = cfg.get("num_query_token")
225
+ llama_model = cfg.get("llama_model")
226
+
227
+ drop_path_rate = cfg.get("drop_path_rate", 0)
228
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
229
+ vit_precision = cfg.get("vit_precision", "fp16")
230
+ freeze_vit = cfg.get("freeze_vit", True)
231
+ freeze_qformer = cfg.get("freeze_qformer", True)
232
+ llama_cache_dir = cfg.get("llama_cache_dir", "")
233
+
234
+ prompt_path = cfg.get("prompt_path", "")
235
+ prompt_template = cfg.get("prompt_template", "")
236
+ max_txt_len = cfg.get("max_txt_len", 32)
237
+ end_sym = cfg.get("end_sym", '\n')
238
+
239
+ model = cls(
240
+ vit_model=vit_model,
241
+ q_former_model=q_former_model,
242
+ img_size=img_size,
243
+ drop_path_rate=drop_path_rate,
244
+ use_grad_checkpoint=use_grad_checkpoint,
245
+ vit_precision=vit_precision,
246
+ freeze_vit=freeze_vit,
247
+ freeze_qformer=freeze_qformer,
248
+ llama_cache_dir=llama_cache_dir,
249
+ num_query_token=num_query_token,
250
+ llama_model=llama_model,
251
+ prompt_path=prompt_path,
252
+ prompt_template=prompt_template,
253
+ max_txt_len=max_txt_len,
254
+ end_sym=end_sym
255
+ )
256
+
257
+ ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
258
+ if ckpt_path:
259
+ print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
260
+ ckpt = torch.load(ckpt_path, map_location="cpu")
261
+ msg = model.load_state_dict(ckpt['model'], strict=False)
262
+
263
+ return model
minigpt4/models/modeling_llama.py ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33
+ from transformers.models.llama.configuration_llama import LlamaConfig
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ _CONFIG_FOR_DOC = "LlamaConfig"
39
+
40
+
41
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
42
+ def _make_causal_mask(
43
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
44
+ ):
45
+ """
46
+ Make causal mask used for bi-directional self-attention.
47
+ """
48
+ bsz, tgt_len = input_ids_shape
49
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
50
+ mask_cond = torch.arange(mask.size(-1), device=device)
51
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
52
+ mask = mask.to(dtype)
53
+
54
+ if past_key_values_length > 0:
55
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
56
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
57
+
58
+
59
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
60
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
61
+ """
62
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
63
+ """
64
+ bsz, src_len = mask.size()
65
+ tgt_len = tgt_len if tgt_len is not None else src_len
66
+
67
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68
+
69
+ inverted_mask = 1.0 - expanded_mask
70
+
71
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
72
+
73
+
74
+ class LlamaRMSNorm(nn.Module):
75
+ def __init__(self, hidden_size, eps=1e-6):
76
+ """
77
+ LlamaRMSNorm is equivalent to T5LayerNorm
78
+ """
79
+ super().__init__()
80
+ self.weight = nn.Parameter(torch.ones(hidden_size))
81
+ self.variance_epsilon = eps
82
+
83
+ def forward(self, hidden_states):
84
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
85
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
86
+
87
+ # convert into half-precision if necessary
88
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
89
+ hidden_states = hidden_states.to(self.weight.dtype)
90
+
91
+ return self.weight * hidden_states
92
+
93
+
94
+ class LlamaRotaryEmbedding(torch.nn.Module):
95
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
96
+ super().__init__()
97
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
98
+ self.register_buffer("inv_freq", inv_freq)
99
+
100
+ # Build here to make `torch.jit.trace` work.
101
+ self.max_seq_len_cached = max_position_embeddings
102
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
103
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
104
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
105
+ emb = torch.cat((freqs, freqs), dim=-1)
106
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
107
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
108
+
109
+ def forward(self, x, seq_len=None):
110
+ # x: [bs, num_attention_heads, seq_len, head_size]
111
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
112
+ if seq_len > self.max_seq_len_cached:
113
+ self.max_seq_len_cached = seq_len
114
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
115
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
116
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
117
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
118
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
119
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
120
+ return (
121
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
122
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
123
+ )
124
+
125
+
126
+ def rotate_half(x):
127
+ """Rotates half the hidden dims of the input."""
128
+ x1 = x[..., : x.shape[-1] // 2]
129
+ x2 = x[..., x.shape[-1] // 2 :]
130
+ return torch.cat((-x2, x1), dim=-1)
131
+
132
+
133
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
134
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
135
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
136
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
137
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
138
+ q_embed = (q * cos) + (rotate_half(q) * sin)
139
+ k_embed = (k * cos) + (rotate_half(k) * sin)
140
+ return q_embed, k_embed
141
+
142
+
143
+ class LlamaMLP(nn.Module):
144
+ def __init__(
145
+ self,
146
+ hidden_size: int,
147
+ intermediate_size: int,
148
+ hidden_act: str,
149
+ ):
150
+ super().__init__()
151
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
152
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
153
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
154
+ self.act_fn = ACT2FN[hidden_act]
155
+
156
+ def forward(self, x):
157
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
158
+
159
+
160
+ class LlamaAttention(nn.Module):
161
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
162
+
163
+ def __init__(self, config: LlamaConfig):
164
+ super().__init__()
165
+ self.config = config
166
+ self.hidden_size = config.hidden_size
167
+ self.num_heads = config.num_attention_heads
168
+ self.head_dim = self.hidden_size // self.num_heads
169
+ self.max_position_embeddings = config.max_position_embeddings
170
+
171
+ if (self.head_dim * self.num_heads) != self.hidden_size:
172
+ raise ValueError(
173
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
174
+ f" and `num_heads`: {self.num_heads})."
175
+ )
176
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
177
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
178
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
179
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
180
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
181
+
182
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
183
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
184
+
185
+ def forward(
186
+ self,
187
+ hidden_states: torch.Tensor,
188
+ attention_mask: Optional[torch.Tensor] = None,
189
+ position_ids: Optional[torch.LongTensor] = None,
190
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
191
+ output_attentions: bool = False,
192
+ use_cache: bool = False,
193
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
194
+ bsz, q_len, _ = hidden_states.size()
195
+
196
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
197
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
198
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
199
+
200
+ kv_seq_len = key_states.shape[-2]
201
+ if past_key_value is not None:
202
+ kv_seq_len += past_key_value[0].shape[-2]
203
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
204
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
205
+ # [bsz, nh, t, hd]
206
+
207
+ if past_key_value is not None:
208
+ # reuse k, v, self_attention
209
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
210
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
211
+
212
+ past_key_value = (key_states, value_states) if use_cache else None
213
+
214
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
215
+
216
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
217
+ raise ValueError(
218
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
219
+ f" {attn_weights.size()}"
220
+ )
221
+
222
+ if attention_mask is not None:
223
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
224
+ raise ValueError(
225
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
226
+ )
227
+ attn_weights = attn_weights + attention_mask
228
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
229
+
230
+ # upcast attention to fp32
231
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
232
+ attn_output = torch.matmul(attn_weights, value_states)
233
+
234
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
235
+ raise ValueError(
236
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
237
+ f" {attn_output.size()}"
238
+ )
239
+
240
+ attn_output = attn_output.transpose(1, 2)
241
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
242
+
243
+ attn_output = self.o_proj(attn_output)
244
+
245
+ if not output_attentions:
246
+ attn_weights = None
247
+
248
+ return attn_output, attn_weights, past_key_value
249
+
250
+
251
+ class LlamaDecoderLayer(nn.Module):
252
+ def __init__(self, config: LlamaConfig):
253
+ super().__init__()
254
+ self.hidden_size = config.hidden_size
255
+ self.self_attn = LlamaAttention(config=config)
256
+ self.mlp = LlamaMLP(
257
+ hidden_size=self.hidden_size,
258
+ intermediate_size=config.intermediate_size,
259
+ hidden_act=config.hidden_act,
260
+ )
261
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
262
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
263
+
264
+ def forward(
265
+ self,
266
+ hidden_states: torch.Tensor,
267
+ attention_mask: Optional[torch.Tensor] = None,
268
+ position_ids: Optional[torch.LongTensor] = None,
269
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
270
+ output_attentions: Optional[bool] = False,
271
+ use_cache: Optional[bool] = False,
272
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
273
+ """
274
+ Args:
275
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
276
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
277
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
278
+ output_attentions (`bool`, *optional*):
279
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
280
+ returned tensors for more detail.
281
+ use_cache (`bool`, *optional*):
282
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
283
+ (see `past_key_values`).
284
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
285
+ """
286
+
287
+ residual = hidden_states
288
+
289
+ hidden_states = self.input_layernorm(hidden_states)
290
+
291
+ # Self Attention
292
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
293
+ hidden_states=hidden_states,
294
+ attention_mask=attention_mask,
295
+ position_ids=position_ids,
296
+ past_key_value=past_key_value,
297
+ output_attentions=output_attentions,
298
+ use_cache=use_cache,
299
+ )
300
+ hidden_states = residual + hidden_states
301
+
302
+ # Fully Connected
303
+ residual = hidden_states
304
+ hidden_states = self.post_attention_layernorm(hidden_states)
305
+ hidden_states = self.mlp(hidden_states)
306
+ hidden_states = residual + hidden_states
307
+
308
+ outputs = (hidden_states,)
309
+
310
+ if output_attentions:
311
+ outputs += (self_attn_weights,)
312
+
313
+ if use_cache:
314
+ outputs += (present_key_value,)
315
+
316
+ return outputs
317
+
318
+
319
+ LLAMA_START_DOCSTRING = r"""
320
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
321
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
322
+ etc.)
323
+
324
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
325
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
326
+ and behavior.
327
+
328
+ Parameters:
329
+ config ([`LlamaConfig`]):
330
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
331
+ load the weights associated with the model, only the configuration. Check out the
332
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
333
+ """
334
+
335
+
336
+ @add_start_docstrings(
337
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
338
+ LLAMA_START_DOCSTRING,
339
+ )
340
+ class LlamaPreTrainedModel(PreTrainedModel):
341
+ config_class = LlamaConfig
342
+ base_model_prefix = "model"
343
+ supports_gradient_checkpointing = True
344
+ _no_split_modules = ["LlamaDecoderLayer"]
345
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
346
+
347
+ def _init_weights(self, module):
348
+ std = self.config.initializer_range
349
+ if isinstance(module, nn.Linear):
350
+ module.weight.data.normal_(mean=0.0, std=std)
351
+ if module.bias is not None:
352
+ module.bias.data.zero_()
353
+ elif isinstance(module, nn.Embedding):
354
+ module.weight.data.normal_(mean=0.0, std=std)
355
+ if module.padding_idx is not None:
356
+ module.weight.data[module.padding_idx].zero_()
357
+
358
+ def _set_gradient_checkpointing(self, module, value=False):
359
+ if isinstance(module, LlamaModel):
360
+ module.gradient_checkpointing = value
361
+
362
+
363
+ LLAMA_INPUTS_DOCSTRING = r"""
364
+ Args:
365
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
366
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
367
+ it.
368
+
369
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
370
+ [`PreTrainedTokenizer.__call__`] for details.
371
+
372
+ [What are input IDs?](../glossary#input-ids)
373
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
374
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
375
+
376
+ - 1 for tokens that are **not masked**,
377
+ - 0 for tokens that are **masked**.
378
+
379
+ [What are attention masks?](../glossary#attention-mask)
380
+
381
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
382
+ [`PreTrainedTokenizer.__call__`] for details.
383
+
384
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
385
+ `past_key_values`).
386
+
387
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
388
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
389
+ information on the default strategy.
390
+
391
+ - 1 indicates the head is **not masked**,
392
+ - 0 indicates the head is **masked**.
393
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
394
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
395
+ config.n_positions - 1]`.
396
+
397
+ [What are position IDs?](../glossary#position-ids)
398
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
399
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
400
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
401
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
402
+
403
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
404
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
405
+
406
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
407
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
408
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
409
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
410
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
411
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
412
+ model's internal embedding lookup matrix.
413
+ use_cache (`bool`, *optional*):
414
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
415
+ `past_key_values`).
416
+ output_attentions (`bool`, *optional*):
417
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
418
+ tensors for more detail.
419
+ output_hidden_states (`bool`, *optional*):
420
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
421
+ more detail.
422
+ return_dict (`bool`, *optional*):
423
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
424
+ """
425
+
426
+
427
+ @add_start_docstrings(
428
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
429
+ LLAMA_START_DOCSTRING,
430
+ )
431
+ class LlamaModel(LlamaPreTrainedModel):
432
+ """
433
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
434
+
435
+ Args:
436
+ config: LlamaConfig
437
+ """
438
+
439
+ def __init__(self, config: LlamaConfig):
440
+ super().__init__(config)
441
+ self.padding_idx = config.pad_token_id
442
+ self.vocab_size = config.vocab_size
443
+
444
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
445
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
446
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
447
+
448
+ self.gradient_checkpointing = False
449
+ # Initialize weights and apply final processing
450
+ self.post_init()
451
+
452
+ def get_input_embeddings(self):
453
+ return self.embed_tokens
454
+
455
+ def set_input_embeddings(self, value):
456
+ self.embed_tokens = value
457
+
458
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
459
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
460
+ # create causal mask
461
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
462
+ combined_attention_mask = None
463
+ if input_shape[-1] > 1:
464
+ combined_attention_mask = _make_causal_mask(
465
+ input_shape,
466
+ inputs_embeds.dtype,
467
+ device=inputs_embeds.device,
468
+ past_key_values_length=past_key_values_length,
469
+ )
470
+
471
+ if attention_mask is not None:
472
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
473
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
474
+ inputs_embeds.device
475
+ )
476
+ combined_attention_mask = (
477
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
478
+ )
479
+
480
+ return combined_attention_mask
481
+
482
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
483
+ def forward(
484
+ self,
485
+ input_ids: torch.LongTensor = None,
486
+ attention_mask: Optional[torch.Tensor] = None,
487
+ position_ids: Optional[torch.LongTensor] = None,
488
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
489
+ inputs_embeds: Optional[torch.FloatTensor] = None,
490
+ query_embeds: Optional[torch.FloatTensor] = None,
491
+ use_cache: Optional[bool] = None,
492
+ output_attentions: Optional[bool] = None,
493
+ output_hidden_states: Optional[bool] = None,
494
+ return_dict: Optional[bool] = None,
495
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
496
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
497
+ output_hidden_states = (
498
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
499
+ )
500
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
501
+
502
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
503
+
504
+ # retrieve input_ids and inputs_embeds
505
+ if input_ids is not None and inputs_embeds is not None:
506
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
507
+ elif input_ids is not None:
508
+ batch_size, seq_length = input_ids.shape
509
+ elif inputs_embeds is not None:
510
+ batch_size, seq_length, _ = inputs_embeds.shape
511
+ else:
512
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
513
+
514
+ if inputs_embeds is None:
515
+ inputs_embeds = self.embed_tokens(input_ids)
516
+ if query_embeds is not None:
517
+ inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
518
+ batch_size, seq_length, _ = inputs_embeds.shape
519
+
520
+ seq_length_with_past = seq_length
521
+ past_key_values_length = 0
522
+
523
+ if past_key_values is not None:
524
+ past_key_values_length = past_key_values[0][0].shape[2]
525
+ seq_length_with_past = seq_length_with_past + past_key_values_length
526
+
527
+ if position_ids is None:
528
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
529
+ position_ids = torch.arange(
530
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
531
+ )
532
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
533
+ else:
534
+ position_ids = position_ids.view(-1, seq_length).long()
535
+
536
+ # embed positions
537
+ if attention_mask is None:
538
+ attention_mask = torch.ones(
539
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
540
+ )
541
+ attention_mask = self._prepare_decoder_attention_mask(
542
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
543
+ )
544
+
545
+ hidden_states = inputs_embeds
546
+
547
+ if self.gradient_checkpointing and self.training:
548
+ if use_cache:
549
+ logger.warning_once(
550
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
551
+ )
552
+ use_cache = False
553
+
554
+ # decoder layers
555
+ all_hidden_states = () if output_hidden_states else None
556
+ all_self_attns = () if output_attentions else None
557
+ next_decoder_cache = () if use_cache else None
558
+
559
+ for idx, decoder_layer in enumerate(self.layers):
560
+ if output_hidden_states:
561
+ all_hidden_states += (hidden_states,)
562
+
563
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
564
+
565
+ if self.gradient_checkpointing and self.training:
566
+
567
+ def create_custom_forward(module):
568
+ def custom_forward(*inputs):
569
+ # None for past_key_value
570
+ return module(*inputs, output_attentions, None)
571
+
572
+ return custom_forward
573
+
574
+ layer_outputs = torch.utils.checkpoint.checkpoint(
575
+ create_custom_forward(decoder_layer),
576
+ hidden_states,
577
+ attention_mask,
578
+ position_ids,
579
+ None,
580
+ )
581
+ else:
582
+ layer_outputs = decoder_layer(
583
+ hidden_states,
584
+ attention_mask=attention_mask,
585
+ position_ids=position_ids,
586
+ past_key_value=past_key_value,
587
+ output_attentions=output_attentions,
588
+ use_cache=use_cache,
589
+ )
590
+
591
+ hidden_states = layer_outputs[0]
592
+
593
+ if use_cache:
594
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
595
+
596
+ if output_attentions:
597
+ all_self_attns += (layer_outputs[1],)
598
+
599
+ hidden_states = self.norm(hidden_states)
600
+
601
+ # add hidden states from the last decoder layer
602
+ if output_hidden_states:
603
+ all_hidden_states += (hidden_states,)
604
+
605
+ next_cache = next_decoder_cache if use_cache else None
606
+ if not return_dict:
607
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
608
+ return BaseModelOutputWithPast(
609
+ last_hidden_state=hidden_states,
610
+ past_key_values=next_cache,
611
+ hidden_states=all_hidden_states,
612
+ attentions=all_self_attns,
613
+ )
614
+
615
+
616
+ class LlamaForCausalLM(LlamaPreTrainedModel):
617
+ def __init__(self, config):
618
+ super().__init__(config)
619
+ self.model = LlamaModel(config)
620
+
621
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
622
+
623
+ # Initialize weights and apply final processing
624
+ self.post_init()
625
+
626
+ def get_input_embeddings(self):
627
+ return self.model.embed_tokens
628
+
629
+ def set_input_embeddings(self, value):
630
+ self.model.embed_tokens = value
631
+
632
+ def get_output_embeddings(self):
633
+ return self.lm_head
634
+
635
+ def set_output_embeddings(self, new_embeddings):
636
+ self.lm_head = new_embeddings
637
+
638
+ def set_decoder(self, decoder):
639
+ self.model = decoder
640
+
641
+ def get_decoder(self):
642
+ return self.model
643
+
644
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
645
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
646
+ def forward(
647
+ self,
648
+ input_ids: torch.LongTensor = None,
649
+ attention_mask: Optional[torch.Tensor] = None,
650
+ position_ids: Optional[torch.LongTensor] = None,
651
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
652
+ inputs_embeds: Optional[torch.FloatTensor] = None,
653
+ query_embeds: Optional[torch.FloatTensor] = None,
654
+ labels: Optional[torch.LongTensor] = None,
655
+ use_cache: Optional[bool] = None,
656
+ output_attentions: Optional[bool] = None,
657
+ output_hidden_states: Optional[bool] = None,
658
+ return_dict: Optional[bool] = None,
659
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
660
+ r"""
661
+ Args:
662
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
663
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
664
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
665
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
666
+
667
+ Returns:
668
+
669
+ Example:
670
+
671
+ ```python
672
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
673
+
674
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
675
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
676
+
677
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
678
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
679
+
680
+ >>> # Generate
681
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
682
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
683
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
684
+ ```"""
685
+
686
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
687
+ output_hidden_states = (
688
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
689
+ )
690
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
691
+
692
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
693
+ outputs = self.model(
694
+ input_ids=input_ids,
695
+ attention_mask=attention_mask,
696
+ position_ids=position_ids,
697
+ past_key_values=past_key_values,
698
+ inputs_embeds=inputs_embeds,
699
+ query_embeds=query_embeds,
700
+ use_cache=use_cache,
701
+ output_attentions=output_attentions,
702
+ output_hidden_states=output_hidden_states,
703
+ return_dict=return_dict,
704
+ )
705
+
706
+ hidden_states = outputs[0]
707
+ logits = self.lm_head(hidden_states)
708
+
709
+ loss = None
710
+ if labels is not None:
711
+ # Shift so that tokens < n predict n
712
+ shift_logits = logits[..., :-1, :].contiguous()
713
+ shift_labels = labels[..., 1:].contiguous()
714
+ # Flatten the tokens
715
+ loss_fct = CrossEntropyLoss()
716
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
717
+ shift_labels = shift_labels.view(-1)
718
+ # Enable model parallelism
719
+ shift_labels = shift_labels.to(shift_logits.device)
720
+ loss = loss_fct(shift_logits, shift_labels)
721
+
722
+ if not return_dict:
723
+ output = (logits,) + outputs[1:]
724
+ return (loss,) + output if loss is not None else output
725
+
726
+ return CausalLMOutputWithPast(
727
+ loss=loss,
728
+ logits=logits,
729
+ past_key_values=outputs.past_key_values,
730
+ hidden_states=outputs.hidden_states,
731
+ attentions=outputs.attentions,
732
+ )
733
+
734
+ def prepare_inputs_for_generation(
735
+ self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
736
+ ):
737
+ if past_key_values:
738
+ input_ids = input_ids[:, -1:]
739
+
740
+ position_ids = kwargs.get("position_ids", None)
741
+ if attention_mask is not None and position_ids is None:
742
+ # create position_ids on the fly for batch generation
743
+ position_ids = attention_mask.long().cumsum(-1) - 1
744
+ position_ids.masked_fill_(attention_mask == 0, 1)
745
+ if past_key_values:
746
+ position_ids = position_ids[:, -1].unsqueeze(-1)
747
+ query_embeds = None
748
+
749
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
750
+ if inputs_embeds is not None and past_key_values is None:
751
+ model_inputs = {"inputs_embeds": inputs_embeds}
752
+ else:
753
+ model_inputs = {"input_ids": input_ids}
754
+
755
+ model_inputs.update(
756
+ {
757
+ "position_ids": position_ids,
758
+ "query_embeds": query_embeds,
759
+ "past_key_values": past_key_values,
760
+ "use_cache": kwargs.get("use_cache"),
761
+ "attention_mask": attention_mask,
762
+ }
763
+ )
764
+ return model_inputs
765
+
766
+ @staticmethod
767
+ def _reorder_cache(past_key_values, beam_idx):
768
+ reordered_past = ()
769
+ for layer_past in past_key_values:
770
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
771
+ return reordered_past
772
+
minigpt4/processors/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from minigpt4.processors.base_processor import BaseProcessor
9
+ from minigpt4.processors.blip_processors import (
10
+ Blip2ImageTrainProcessor,
11
+ Blip2ImageEvalProcessor,
12
+ BlipCaptionProcessor,
13
+ )
14
+
15
+ from minigpt4.common.registry import registry
16
+
17
+ __all__ = [
18
+ "BaseProcessor",
19
+ "Blip2ImageTrainProcessor",
20
+ "Blip2ImageEvalProcessor",
21
+ "BlipCaptionProcessor",
22
+ ]
23
+
24
+
25
+ def load_processor(name, cfg=None):
26
+ """
27
+ Example
28
+
29
+ >>> processor = load_processor("alpro_video_train", cfg=None)
30
+ """
31
+ processor = registry.get_processor_class(name).from_config(cfg)
32
+
33
+ return processor
minigpt4/processors/base_processor.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from omegaconf import OmegaConf
9
+
10
+
11
+ class BaseProcessor:
12
+ def __init__(self):
13
+ self.transform = lambda x: x
14
+ return
15
+
16
+ def __call__(self, item):
17
+ return self.transform(item)
18
+
19
+ @classmethod
20
+ def from_config(cls, cfg=None):
21
+ return cls()
22
+
23
+ def build(self, **kwargs):
24
+ cfg = OmegaConf.create(kwargs)
25
+
26
+ return self.from_config(cfg)
minigpt4/processors/blip_processors.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import re
9
+
10
+ from minigpt4.common.registry import registry
11
+ from minigpt4.processors.base_processor import BaseProcessor
12
+ from minigpt4.processors.randaugment import RandomAugment
13
+ from omegaconf import OmegaConf
14
+ from torchvision import transforms
15
+ from torchvision.transforms.functional import InterpolationMode
16
+
17
+
18
+ class BlipImageBaseProcessor(BaseProcessor):
19
+ def __init__(self, mean=None, std=None):
20
+ if mean is None:
21
+ mean = (0.48145466, 0.4578275, 0.40821073)
22
+ if std is None:
23
+ std = (0.26862954, 0.26130258, 0.27577711)
24
+
25
+ self.normalize = transforms.Normalize(mean, std)
26
+
27
+
28
+ @registry.register_processor("blip_caption")
29
+ class BlipCaptionProcessor(BaseProcessor):
30
+ def __init__(self, prompt="", max_words=50):
31
+ self.prompt = prompt
32
+ self.max_words = max_words
33
+
34
+ def __call__(self, caption):
35
+ caption = self.prompt + self.pre_caption(caption)
36
+
37
+ return caption
38
+
39
+ @classmethod
40
+ def from_config(cls, cfg=None):
41
+ if cfg is None:
42
+ cfg = OmegaConf.create()
43
+
44
+ prompt = cfg.get("prompt", "")
45
+ max_words = cfg.get("max_words", 50)
46
+
47
+ return cls(prompt=prompt, max_words=max_words)
48
+
49
+ def pre_caption(self, caption):
50
+ caption = re.sub(
51
+ r"([.!\"()*#:;~])",
52
+ " ",
53
+ caption.lower(),
54
+ )
55
+ caption = re.sub(
56
+ r"\s{2,}",
57
+ " ",
58
+ caption,
59
+ )
60
+ caption = caption.rstrip("\n")
61
+ caption = caption.strip(" ")
62
+
63
+ # truncate caption
64
+ caption_words = caption.split(" ")
65
+ if len(caption_words) > self.max_words:
66
+ caption = " ".join(caption_words[: self.max_words])
67
+
68
+ return caption
69
+
70
+
71
+ @registry.register_processor("blip2_image_train")
72
+ class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
73
+ def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
74
+ super().__init__(mean=mean, std=std)
75
+
76
+ self.transform = transforms.Compose(
77
+ [
78
+ transforms.RandomResizedCrop(
79
+ image_size,
80
+ scale=(min_scale, max_scale),
81
+ interpolation=InterpolationMode.BICUBIC,
82
+ ),
83
+ transforms.ToTensor(),
84
+ self.normalize,
85
+ ]
86
+ )
87
+
88
+ def __call__(self, item):
89
+ return self.transform(item)
90
+
91
+ @classmethod
92
+ def from_config(cls, cfg=None):
93
+ if cfg is None:
94
+ cfg = OmegaConf.create()
95
+
96
+ image_size = cfg.get("image_size", 224)
97
+
98
+ mean = cfg.get("mean", None)
99
+ std = cfg.get("std", None)
100
+
101
+ min_scale = cfg.get("min_scale", 0.5)
102
+ max_scale = cfg.get("max_scale", 1.0)
103
+
104
+ return cls(
105
+ image_size=image_size,
106
+ mean=mean,
107
+ std=std,
108
+ min_scale=min_scale,
109
+ max_scale=max_scale,
110
+ )
111
+
112
+
113
+ @registry.register_processor("blip2_image_eval")
114
+ class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
115
+ def __init__(self, image_size=224, mean=None, std=None):
116
+ super().__init__(mean=mean, std=std)
117
+
118
+ self.transform = transforms.Compose(
119
+ [
120
+ transforms.Resize(
121
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
122
+ ),
123
+ transforms.ToTensor(),
124
+ self.normalize,
125
+ ]
126
+ )
127
+
128
+ def __call__(self, item):
129
+ return self.transform(item)
130
+
131
+ @classmethod
132
+ def from_config(cls, cfg=None):
133
+ if cfg is None:
134
+ cfg = OmegaConf.create()
135
+
136
+ image_size = cfg.get("image_size", 224)
137
+
138
+ mean = cfg.get("mean", None)
139
+ std = cfg.get("std", None)
140
+
141
+ return cls(image_size=image_size, mean=mean, std=std)
minigpt4/processors/randaugment.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ import torch
12
+
13
+
14
+ ## aug functions
15
+ def identity_func(img):
16
+ return img
17
+
18
+
19
+ def autocontrast_func(img, cutoff=0):
20
+ """
21
+ same output as PIL.ImageOps.autocontrast
22
+ """
23
+ n_bins = 256
24
+
25
+ def tune_channel(ch):
26
+ n = ch.size
27
+ cut = cutoff * n // 100
28
+ if cut == 0:
29
+ high, low = ch.max(), ch.min()
30
+ else:
31
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
32
+ low = np.argwhere(np.cumsum(hist) > cut)
33
+ low = 0 if low.shape[0] == 0 else low[0]
34
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
35
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
36
+ if high <= low:
37
+ table = np.arange(n_bins)
38
+ else:
39
+ scale = (n_bins - 1) / (high - low)
40
+ offset = -low * scale
41
+ table = np.arange(n_bins) * scale + offset
42
+ table[table < 0] = 0
43
+ table[table > n_bins - 1] = n_bins - 1
44
+ table = table.clip(0, 255).astype(np.uint8)
45
+ return table[ch]
46
+
47
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
48
+ out = cv2.merge(channels)
49
+ return out
50
+
51
+
52
+ def equalize_func(img):
53
+ """
54
+ same output as PIL.ImageOps.equalize
55
+ PIL's implementation is different from cv2.equalize
56
+ """
57
+ n_bins = 256
58
+
59
+ def tune_channel(ch):
60
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
61
+ non_zero_hist = hist[hist != 0].reshape(-1)
62
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
63
+ if step == 0:
64
+ return ch
65
+ n = np.empty_like(hist)
66
+ n[0] = step // 2
67
+ n[1:] = hist[:-1]
68
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
69
+ return table[ch]
70
+
71
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
72
+ out = cv2.merge(channels)
73
+ return out
74
+
75
+
76
+ def rotate_func(img, degree, fill=(0, 0, 0)):
77
+ """
78
+ like PIL, rotate by degree, not radians
79
+ """
80
+ H, W = img.shape[0], img.shape[1]
81
+ center = W / 2, H / 2
82
+ M = cv2.getRotationMatrix2D(center, degree, 1)
83
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
84
+ return out
85
+
86
+
87
+ def solarize_func(img, thresh=128):
88
+ """
89
+ same output as PIL.ImageOps.posterize
90
+ """
91
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
92
+ table = table.clip(0, 255).astype(np.uint8)
93
+ out = table[img]
94
+ return out
95
+
96
+
97
+ def color_func(img, factor):
98
+ """
99
+ same output as PIL.ImageEnhance.Color
100
+ """
101
+ ## implementation according to PIL definition, quite slow
102
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
103
+ # out = blend(degenerate, img, factor)
104
+ # M = (
105
+ # np.eye(3) * factor
106
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
107
+ # )[np.newaxis, np.newaxis, :]
108
+ M = np.float32(
109
+ [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
110
+ ) * factor + np.float32([[0.114], [0.587], [0.299]])
111
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
112
+ return out
113
+
114
+
115
+ def contrast_func(img, factor):
116
+ """
117
+ same output as PIL.ImageEnhance.Contrast
118
+ """
119
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
120
+ table = (
121
+ np.array([(el - mean) * factor + mean for el in range(256)])
122
+ .clip(0, 255)
123
+ .astype(np.uint8)
124
+ )
125
+ out = table[img]
126
+ return out
127
+
128
+
129
+ def brightness_func(img, factor):
130
+ """
131
+ same output as PIL.ImageEnhance.Contrast
132
+ """
133
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
134
+ out = table[img]
135
+ return out
136
+
137
+
138
+ def sharpness_func(img, factor):
139
+ """
140
+ The differences the this result and PIL are all on the 4 boundaries, the center
141
+ areas are same
142
+ """
143
+ kernel = np.ones((3, 3), dtype=np.float32)
144
+ kernel[1][1] = 5
145
+ kernel /= 13
146
+ degenerate = cv2.filter2D(img, -1, kernel)
147
+ if factor == 0.0:
148
+ out = degenerate
149
+ elif factor == 1.0:
150
+ out = img
151
+ else:
152
+ out = img.astype(np.float32)
153
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
154
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
155
+ out = out.astype(np.uint8)
156
+ return out
157
+
158
+
159
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
160
+ H, W = img.shape[0], img.shape[1]
161
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
162
+ out = cv2.warpAffine(
163
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
164
+ ).astype(np.uint8)
165
+ return out
166
+
167
+
168
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
169
+ """
170
+ same output as PIL.Image.transform
171
+ """
172
+ H, W = img.shape[0], img.shape[1]
173
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
174
+ out = cv2.warpAffine(
175
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
176
+ ).astype(np.uint8)
177
+ return out
178
+
179
+
180
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
181
+ """
182
+ same output as PIL.Image.transform
183
+ """
184
+ H, W = img.shape[0], img.shape[1]
185
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
186
+ out = cv2.warpAffine(
187
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
188
+ ).astype(np.uint8)
189
+ return out
190
+
191
+
192
+ def posterize_func(img, bits):
193
+ """
194
+ same output as PIL.ImageOps.posterize
195
+ """
196
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
197
+ return out
198
+
199
+
200
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
201
+ H, W = img.shape[0], img.shape[1]
202
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
203
+ out = cv2.warpAffine(
204
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
205
+ ).astype(np.uint8)
206
+ return out
207
+
208
+
209
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
210
+ replace = np.array(replace, dtype=np.uint8)
211
+ H, W = img.shape[0], img.shape[1]
212
+ rh, rw = np.random.random(2)
213
+ pad_size = pad_size // 2
214
+ ch, cw = int(rh * H), int(rw * W)
215
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
216
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
217
+ out = img.copy()
218
+ out[x1:x2, y1:y2, :] = replace
219
+ return out
220
+
221
+
222
+ ### level to args
223
+ def enhance_level_to_args(MAX_LEVEL):
224
+ def level_to_args(level):
225
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
226
+
227
+ return level_to_args
228
+
229
+
230
+ def shear_level_to_args(MAX_LEVEL, replace_value):
231
+ def level_to_args(level):
232
+ level = (level / MAX_LEVEL) * 0.3
233
+ if np.random.random() > 0.5:
234
+ level = -level
235
+ return (level, replace_value)
236
+
237
+ return level_to_args
238
+
239
+
240
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
241
+ def level_to_args(level):
242
+ level = (level / MAX_LEVEL) * float(translate_const)
243
+ if np.random.random() > 0.5:
244
+ level = -level
245
+ return (level, replace_value)
246
+
247
+ return level_to_args
248
+
249
+
250
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
251
+ def level_to_args(level):
252
+ level = int((level / MAX_LEVEL) * cutout_const)
253
+ return (level, replace_value)
254
+
255
+ return level_to_args
256
+
257
+
258
+ def solarize_level_to_args(MAX_LEVEL):
259
+ def level_to_args(level):
260
+ level = int((level / MAX_LEVEL) * 256)
261
+ return (level,)
262
+
263
+ return level_to_args
264
+
265
+
266
+ def none_level_to_args(level):
267
+ return ()
268
+
269
+
270
+ def posterize_level_to_args(MAX_LEVEL):
271
+ def level_to_args(level):
272
+ level = int((level / MAX_LEVEL) * 4)
273
+ return (level,)
274
+
275
+ return level_to_args
276
+
277
+
278
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
279
+ def level_to_args(level):
280
+ level = (level / MAX_LEVEL) * 30
281
+ if np.random.random() < 0.5:
282
+ level = -level
283
+ return (level, replace_value)
284
+
285
+ return level_to_args
286
+
287
+
288
+ func_dict = {
289
+ "Identity": identity_func,
290
+ "AutoContrast": autocontrast_func,
291
+ "Equalize": equalize_func,
292
+ "Rotate": rotate_func,
293
+ "Solarize": solarize_func,
294
+ "Color": color_func,
295
+ "Contrast": contrast_func,
296
+ "Brightness": brightness_func,
297
+ "Sharpness": sharpness_func,
298
+ "ShearX": shear_x_func,
299
+ "TranslateX": translate_x_func,
300
+ "TranslateY": translate_y_func,
301
+ "Posterize": posterize_func,
302
+ "ShearY": shear_y_func,
303
+ }
304
+
305
+ translate_const = 10
306
+ MAX_LEVEL = 10
307
+ replace_value = (128, 128, 128)
308
+ arg_dict = {
309
+ "Identity": none_level_to_args,
310
+ "AutoContrast": none_level_to_args,
311
+ "Equalize": none_level_to_args,
312
+ "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
313
+ "Solarize": solarize_level_to_args(MAX_LEVEL),
314
+ "Color": enhance_level_to_args(MAX_LEVEL),
315
+ "Contrast": enhance_level_to_args(MAX_LEVEL),
316
+ "Brightness": enhance_level_to_args(MAX_LEVEL),
317
+ "Sharpness": enhance_level_to_args(MAX_LEVEL),
318
+ "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
319
+ "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
320
+ "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
321
+ "Posterize": posterize_level_to_args(MAX_LEVEL),
322
+ "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
323
+ }
324
+
325
+
326
+ class RandomAugment(object):
327
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
328
+ self.N = N
329
+ self.M = M
330
+ self.isPIL = isPIL
331
+ if augs:
332
+ self.augs = augs
333
+ else:
334
+ self.augs = list(arg_dict.keys())
335
+
336
+ def get_random_ops(self):
337
+ sampled_ops = np.random.choice(self.augs, self.N)
338
+ return [(op, 0.5, self.M) for op in sampled_ops]
339
+
340
+ def __call__(self, img):
341
+ if self.isPIL:
342
+ img = np.array(img)
343
+ ops = self.get_random_ops()
344
+ for name, prob, level in ops:
345
+ if np.random.random() > prob:
346
+ continue
347
+ args = arg_dict[name](level)
348
+ img = func_dict[name](img, *args)
349
+ return img
350
+
351
+
352
+ class VideoRandomAugment(object):
353
+ def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
354
+ self.N = N
355
+ self.M = M
356
+ self.p = p
357
+ self.tensor_in_tensor_out = tensor_in_tensor_out
358
+ if augs:
359
+ self.augs = augs
360
+ else:
361
+ self.augs = list(arg_dict.keys())
362
+
363
+ def get_random_ops(self):
364
+ sampled_ops = np.random.choice(self.augs, self.N, replace=False)
365
+ return [(op, self.M) for op in sampled_ops]
366
+
367
+ def __call__(self, frames):
368
+ assert (
369
+ frames.shape[-1] == 3
370
+ ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
371
+
372
+ if self.tensor_in_tensor_out:
373
+ frames = frames.numpy().astype(np.uint8)
374
+
375
+ num_frames = frames.shape[0]
376
+
377
+ ops = num_frames * [self.get_random_ops()]
378
+ apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
379
+
380
+ frames = torch.stack(
381
+ list(map(self._aug, frames, ops, apply_or_not)), dim=0
382
+ ).float()
383
+
384
+ return frames
385
+
386
+ def _aug(self, img, ops, apply_or_not):
387
+ for i, (name, level) in enumerate(ops):
388
+ if not apply_or_not[i]:
389
+ continue
390
+ args = arg_dict[name](level)
391
+ img = func_dict[name](img, *args)
392
+ return torch.from_numpy(img)
393
+
394
+
395
+ if __name__ == "__main__":
396
+ a = RandomAugment()
397
+ img = np.random.randn(32, 32, 3)
398
+ a(img)
minigpt4/runners/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from minigpt4.runners.runner_base import RunnerBase
9
+
10
+ __all__ = ["RunnerBase"]
minigpt4/runners/runner_base.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import json
10
+ import logging
11
+ import os
12
+ import time
13
+ from pathlib import Path
14
+
15
+ import torch
16
+ import torch.distributed as dist
17
+ import webdataset as wds
18
+ from minigpt4.common.dist_utils import (
19
+ download_cached_file,
20
+ get_rank,
21
+ get_world_size,
22
+ is_main_process,
23
+ main_process,
24
+ )
25
+ from minigpt4.common.registry import registry
26
+ from minigpt4.common.utils import is_url
27
+ from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
28
+ from minigpt4.datasets.datasets.dataloader_utils import (
29
+ IterLoader,
30
+ MultiIterLoader,
31
+ PrefetchLoader,
32
+ )
33
+ from torch.nn.parallel import DistributedDataParallel as DDP
34
+ from torch.utils.data import DataLoader, DistributedSampler
35
+
36
+
37
+ @registry.register_runner("runner_base")
38
+ class RunnerBase:
39
+ """
40
+ A runner class to train and evaluate a model given a task and datasets.
41
+
42
+ The runner uses pytorch distributed data parallel by default. Future release
43
+ will support other distributed frameworks.
44
+ """
45
+
46
+ def __init__(self, cfg, task, model, datasets, job_id):
47
+ self.config = cfg
48
+ self.job_id = job_id
49
+
50
+ self.task = task
51
+ self.datasets = datasets
52
+
53
+ self._model = model
54
+
55
+ self._wrapped_model = None
56
+ self._device = None
57
+ self._optimizer = None
58
+ self._scaler = None
59
+ self._dataloaders = None
60
+ self._lr_sched = None
61
+
62
+ self.start_epoch = 0
63
+
64
+ # self.setup_seeds()
65
+ self.setup_output_dir()
66
+
67
+ @property
68
+ def device(self):
69
+ if self._device is None:
70
+ self._device = torch.device(self.config.run_cfg.device)
71
+
72
+ return self._device
73
+
74
+ @property
75
+ def use_distributed(self):
76
+ return self.config.run_cfg.distributed
77
+
78
+ @property
79
+ def model(self):
80
+ """
81
+ A property to get the DDP-wrapped model on the device.
82
+ """
83
+ # move model to device
84
+ if self._model.device != self.device:
85
+ self._model = self._model.to(self.device)
86
+
87
+ # distributed training wrapper
88
+ if self.use_distributed:
89
+ if self._wrapped_model is None:
90
+ self._wrapped_model = DDP(
91
+ self._model, device_ids=[self.config.run_cfg.gpu]
92
+ )
93
+ else:
94
+ self._wrapped_model = self._model
95
+
96
+ return self._wrapped_model
97
+
98
+ @property
99
+ def optimizer(self):
100
+ # TODO make optimizer class and configurations
101
+ if self._optimizer is None:
102
+ num_parameters = 0
103
+ p_wd, p_non_wd = [], []
104
+ for n, p in self.model.named_parameters():
105
+ if not p.requires_grad:
106
+ continue # frozen weights
107
+ print(n)
108
+ if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
109
+ p_non_wd.append(p)
110
+ else:
111
+ p_wd.append(p)
112
+ num_parameters += p.data.nelement()
113
+ logging.info("number of trainable parameters: %d" % num_parameters)
114
+ optim_params = [
115
+ {
116
+ "params": p_wd,
117
+ "weight_decay": float(self.config.run_cfg.weight_decay),
118
+ },
119
+ {"params": p_non_wd, "weight_decay": 0},
120
+ ]
121
+ beta2 = self.config.run_cfg.get("beta2", 0.999)
122
+ self._optimizer = torch.optim.AdamW(
123
+ optim_params,
124
+ lr=float(self.config.run_cfg.init_lr),
125
+ weight_decay=float(self.config.run_cfg.weight_decay),
126
+ betas=(0.9, beta2),
127
+ )
128
+
129
+ return self._optimizer
130
+
131
+ @property
132
+ def scaler(self):
133
+ amp = self.config.run_cfg.get("amp", False)
134
+
135
+ if amp:
136
+ if self._scaler is None:
137
+ self._scaler = torch.cuda.amp.GradScaler()
138
+
139
+ return self._scaler
140
+
141
+ @property
142
+ def lr_scheduler(self):
143
+ """
144
+ A property to get and create learning rate scheduler by split just in need.
145
+ """
146
+ if self._lr_sched is None:
147
+ lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)
148
+
149
+ # max_epoch = self.config.run_cfg.max_epoch
150
+ max_epoch = self.max_epoch
151
+ # min_lr = self.config.run_cfg.min_lr
152
+ min_lr = self.min_lr
153
+ # init_lr = self.config.run_cfg.init_lr
154
+ init_lr = self.init_lr
155
+
156
+ # optional parameters
157
+ decay_rate = self.config.run_cfg.get("lr_decay_rate", None)
158
+ warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
159
+ warmup_steps = self.config.run_cfg.get("warmup_steps", 0)
160
+ iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None)
161
+
162
+ if iters_per_epoch is None:
163
+ try:
164
+ iters_per_epoch = len(self.dataloaders['train'])
165
+ except (AttributeError, TypeError):
166
+ iters_per_epoch = 10000
167
+
168
+ self._lr_sched = lr_sched_cls(
169
+ optimizer=self.optimizer,
170
+ max_epoch=max_epoch,
171
+ iters_per_epoch=iters_per_epoch,
172
+ min_lr=min_lr,
173
+ init_lr=init_lr,
174
+ decay_rate=decay_rate,
175
+ warmup_start_lr=warmup_start_lr,
176
+ warmup_steps=warmup_steps,
177
+ )
178
+
179
+ return self._lr_sched
180
+
181
+ @property
182
+ def dataloaders(self) -> dict:
183
+ """
184
+ A property to get and create dataloaders by split just in need.
185
+
186
+ If no train_dataset_ratio is provided, concatenate map-style datasets and
187
+ chain wds.DataPipe datasets separately. Training set becomes a tuple
188
+ (ConcatDataset, ChainDataset), both are optional but at least one of them is
189
+ required. The resultant ConcatDataset and ChainDataset will be sampled evenly.
190
+
191
+ If train_dataset_ratio is provided, create a MultiIterLoader to sample
192
+ each dataset by ratios during training.
193
+
194
+ Currently do not support multiple datasets for validation and test.
195
+
196
+ Returns:
197
+ dict: {split_name: (tuples of) dataloader}
198
+ """
199
+ if self._dataloaders is None:
200
+
201
+ # concatenate map-style datasets and chain wds.DataPipe datasets separately
202
+ # training set becomes a tuple (ConcatDataset, ChainDataset), both are
203
+ # optional but at least one of them is required. The resultant ConcatDataset
204
+ # and ChainDataset will be sampled evenly.
205
+ logging.info(
206
+ "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
207
+ )
208
+
209
+ datasets = reorg_datasets_by_split(self.datasets)
210
+ self.datasets = datasets
211
+ # self.datasets = concat_datasets(datasets)
212
+
213
+ # print dataset statistics after concatenation/chaining
214
+ for split_name in self.datasets:
215
+ if isinstance(self.datasets[split_name], tuple) or isinstance(
216
+ self.datasets[split_name], list
217
+ ):
218
+ # mixed wds.DataPipeline and torch.utils.data.Dataset
219
+ num_records = sum(
220
+ [
221
+ len(d)
222
+ if not type(d) in [wds.DataPipeline, ChainDataset]
223
+ else 0
224
+ for d in self.datasets[split_name]
225
+ ]
226
+ )
227
+
228
+ else:
229
+ if hasattr(self.datasets[split_name], "__len__"):
230
+ # a single map-style dataset
231
+ num_records = len(self.datasets[split_name])
232
+ else:
233
+ # a single wds.DataPipeline
234
+ num_records = -1
235
+ logging.info(
236
+ "Only a single wds.DataPipeline dataset, no __len__ attribute."
237
+ )
238
+
239
+ if num_records >= 0:
240
+ logging.info(
241
+ "Loaded {} records for {} split from the dataset.".format(
242
+ num_records, split_name
243
+ )
244
+ )
245
+
246
+ # create dataloaders
247
+ split_names = sorted(self.datasets.keys())
248
+
249
+ datasets = [self.datasets[split] for split in split_names]
250
+ is_trains = [split in self.train_splits for split in split_names]
251
+
252
+ batch_sizes = [
253
+ self.config.run_cfg.batch_size_train
254
+ if split == "train"
255
+ else self.config.run_cfg.batch_size_eval
256
+ for split in split_names
257
+ ]
258
+
259
+ collate_fns = []
260
+ for dataset in datasets:
261
+ if isinstance(dataset, tuple) or isinstance(dataset, list):
262
+ collate_fns.append([getattr(d, "collater", None) for d in dataset])
263
+ else:
264
+ collate_fns.append(getattr(dataset, "collater", None))
265
+
266
+ dataloaders = self.create_loaders(
267
+ datasets=datasets,
268
+ num_workers=self.config.run_cfg.num_workers,
269
+ batch_sizes=batch_sizes,
270
+ is_trains=is_trains,
271
+ collate_fns=collate_fns,
272
+ )
273
+
274
+ self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
275
+
276
+ return self._dataloaders
277
+
278
+ @property
279
+ def cuda_enabled(self):
280
+ return self.device.type == "cuda"
281
+
282
+ @property
283
+ def max_epoch(self):
284
+ return int(self.config.run_cfg.max_epoch)
285
+
286
+ @property
287
+ def log_freq(self):
288
+ log_freq = self.config.run_cfg.get("log_freq", 50)
289
+ return int(log_freq)
290
+
291
+ @property
292
+ def init_lr(self):
293
+ return float(self.config.run_cfg.init_lr)
294
+
295
+ @property
296
+ def min_lr(self):
297
+ return float(self.config.run_cfg.min_lr)
298
+
299
+ @property
300
+ def accum_grad_iters(self):
301
+ return int(self.config.run_cfg.get("accum_grad_iters", 1))
302
+
303
+ @property
304
+ def valid_splits(self):
305
+ valid_splits = self.config.run_cfg.get("valid_splits", [])
306
+
307
+ if len(valid_splits) == 0:
308
+ logging.info("No validation splits found.")
309
+
310
+ return valid_splits
311
+
312
+ @property
313
+ def test_splits(self):
314
+ test_splits = self.config.run_cfg.get("test_splits", [])
315
+
316
+ return test_splits
317
+
318
+ @property
319
+ def train_splits(self):
320
+ train_splits = self.config.run_cfg.get("train_splits", [])
321
+
322
+ if len(train_splits) == 0:
323
+ logging.info("Empty train splits.")
324
+
325
+ return train_splits
326
+
327
+ @property
328
+ def evaluate_only(self):
329
+ """
330
+ Set to True to skip training.
331
+ """
332
+ return self.config.run_cfg.evaluate
333
+
334
+ @property
335
+ def use_dist_eval_sampler(self):
336
+ return self.config.run_cfg.get("use_dist_eval_sampler", True)
337
+
338
+ @property
339
+ def resume_ckpt_path(self):
340
+ return self.config.run_cfg.get("resume_ckpt_path", None)
341
+
342
+ @property
343
+ def train_loader(self):
344
+ train_dataloader = self.dataloaders["train"]
345
+
346
+ return train_dataloader
347
+
348
+ def setup_output_dir(self):
349
+ lib_root = Path(registry.get_path("library_root"))
350
+
351
+ output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
352
+ result_dir = output_dir / "result"
353
+
354
+ output_dir.mkdir(parents=True, exist_ok=True)
355
+ result_dir.mkdir(parents=True, exist_ok=True)
356
+
357
+ registry.register_path("result_dir", str(result_dir))
358
+ registry.register_path("output_dir", str(output_dir))
359
+
360
+ self.result_dir = result_dir
361
+ self.output_dir = output_dir
362
+
363
+ def train(self):
364
+ start_time = time.time()
365
+ best_agg_metric = 0
366
+ best_epoch = 0
367
+
368
+ self.log_config()
369
+
370
+ # resume from checkpoint if specified
371
+ if not self.evaluate_only and self.resume_ckpt_path is not None:
372
+ self._load_checkpoint(self.resume_ckpt_path)
373
+
374
+ for cur_epoch in range(self.start_epoch, self.max_epoch):
375
+ # training phase
376
+ if not self.evaluate_only:
377
+ logging.info("Start training")
378
+ train_stats = self.train_epoch(cur_epoch)
379
+ self.log_stats(split_name="train", stats=train_stats)
380
+
381
+ # evaluation phase
382
+ if len(self.valid_splits) > 0:
383
+ for split_name in self.valid_splits:
384
+ logging.info("Evaluating on {}.".format(split_name))
385
+
386
+ val_log = self.eval_epoch(
387
+ split_name=split_name, cur_epoch=cur_epoch
388
+ )
389
+ if val_log is not None:
390
+ if is_main_process():
391
+ assert (
392
+ "agg_metrics" in val_log
393
+ ), "No agg_metrics found in validation log."
394
+
395
+ agg_metrics = val_log["agg_metrics"]
396
+ if agg_metrics > best_agg_metric and split_name == "val":
397
+ best_epoch, best_agg_metric = cur_epoch, agg_metrics
398
+
399
+ self._save_checkpoint(cur_epoch, is_best=True)
400
+
401
+ val_log.update({"best_epoch": best_epoch})
402
+ self.log_stats(val_log, split_name)
403
+
404
+ else:
405
+ # if no validation split is provided, we just save the checkpoint at the end of each epoch.
406
+ if not self.evaluate_only:
407
+ self._save_checkpoint(cur_epoch, is_best=False)
408
+
409
+ if self.evaluate_only:
410
+ break
411
+
412
+ if self.config.run_cfg.distributed:
413
+ dist.barrier()
414
+
415
+ # testing phase
416
+ test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
417
+ self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)
418
+
419
+ total_time = time.time() - start_time
420
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
421
+ logging.info("Training time {}".format(total_time_str))
422
+
423
+ def evaluate(self, cur_epoch="best", skip_reload=False):
424
+ test_logs = dict()
425
+
426
+ if len(self.test_splits) > 0:
427
+ for split_name in self.test_splits:
428
+ test_logs[split_name] = self.eval_epoch(
429
+ split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
430
+ )
431
+
432
+ return test_logs
433
+
434
+ def train_epoch(self, epoch):
435
+ # train
436
+ self.model.train()
437
+
438
+ return self.task.train_epoch(
439
+ epoch=epoch,
440
+ model=self.model,
441
+ data_loader=self.train_loader,
442
+ optimizer=self.optimizer,
443
+ scaler=self.scaler,
444
+ lr_scheduler=self.lr_scheduler,
445
+ cuda_enabled=self.cuda_enabled,
446
+ log_freq=self.log_freq,
447
+ accum_grad_iters=self.accum_grad_iters,
448
+ )
449
+
450
+ @torch.no_grad()
451
+ def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
452
+ """
453
+ Evaluate the model on a given split.
454
+
455
+ Args:
456
+ split_name (str): name of the split to evaluate on.
457
+ cur_epoch (int): current epoch.
458
+ skip_reload_best (bool): whether to skip reloading the best checkpoint.
459
+ During training, we will reload the best checkpoint for validation.
460
+ During testing, we will use provided weights and skip reloading the best checkpoint .
461
+ """
462
+ data_loader = self.dataloaders.get(split_name, None)
463
+ assert data_loader, "data_loader for split {} is None.".format(split_name)
464
+
465
+ # TODO In validation, you need to compute loss as well as metrics
466
+ # TODO consider moving to model.before_evaluation()
467
+ model = self.unwrap_dist_model(self.model)
468
+ if not skip_reload and cur_epoch == "best":
469
+ model = self._reload_best_model(model)
470
+ model.eval()
471
+
472
+ self.task.before_evaluation(
473
+ model=model,
474
+ dataset=self.datasets[split_name],
475
+ )
476
+ results = self.task.evaluation(model, data_loader)
477
+
478
+ if results is not None:
479
+ return self.task.after_evaluation(
480
+ val_result=results,
481
+ split_name=split_name,
482
+ epoch=cur_epoch,
483
+ )
484
+
485
+ def unwrap_dist_model(self, model):
486
+ if self.use_distributed:
487
+ return model.module
488
+ else:
489
+ return model
490
+
491
+ def create_loaders(
492
+ self,
493
+ datasets,
494
+ num_workers,
495
+ batch_sizes,
496
+ is_trains,
497
+ collate_fns,
498
+ dataset_ratios=None,
499
+ ):
500
+ """
501
+ Create dataloaders for training and validation.
502
+ """
503
+
504
+ def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
505
+ # create a single dataloader for each split
506
+ if isinstance(dataset, ChainDataset) or isinstance(
507
+ dataset, wds.DataPipeline
508
+ ):
509
+ # wds.WebdDataset instance are chained together
510
+ # webdataset.DataPipeline has its own sampler and collate_fn
511
+ loader = iter(
512
+ DataLoader(
513
+ dataset,
514
+ batch_size=bsz,
515
+ num_workers=num_workers,
516
+ pin_memory=True,
517
+ )
518
+ )
519
+ else:
520
+ # map-style dataset are concatenated together
521
+ # setup distributed sampler
522
+ if self.use_distributed:
523
+ sampler = DistributedSampler(
524
+ dataset,
525
+ shuffle=is_train,
526
+ num_replicas=get_world_size(),
527
+ rank=get_rank(),
528
+ )
529
+ if not self.use_dist_eval_sampler:
530
+ # e.g. retrieval evaluation
531
+ sampler = sampler if is_train else None
532
+ else:
533
+ sampler = None
534
+
535
+ loader = DataLoader(
536
+ dataset,
537
+ batch_size=bsz,
538
+ num_workers=num_workers,
539
+ pin_memory=True,
540
+ sampler=sampler,
541
+ shuffle=sampler is None and is_train,
542
+ collate_fn=collate_fn,
543
+ drop_last=True if is_train else False,
544
+ )
545
+ loader = PrefetchLoader(loader)
546
+
547
+ if is_train:
548
+ loader = IterLoader(loader, use_distributed=self.use_distributed)
549
+
550
+ return loader
551
+
552
+ loaders = []
553
+
554
+ for dataset, bsz, is_train, collate_fn in zip(
555
+ datasets, batch_sizes, is_trains, collate_fns
556
+ ):
557
+ if isinstance(dataset, list) or isinstance(dataset, tuple):
558
+ if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None:
559
+ dataset_ratios = [d.sample_ratio for d in dataset]
560
+ loader = MultiIterLoader(
561
+ loaders=[
562
+ _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
563
+ for i, d in enumerate(dataset)
564
+ ],
565
+ ratios=dataset_ratios,
566
+ )
567
+ else:
568
+ loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)
569
+
570
+ loaders.append(loader)
571
+
572
+ return loaders
573
+
574
+ @main_process
575
+ def _save_checkpoint(self, cur_epoch, is_best=False):
576
+ """
577
+ Save the checkpoint at the current epoch.
578
+ """
579
+ model_no_ddp = self.unwrap_dist_model(self.model)
580
+ param_grad_dic = {
581
+ k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
582
+ }
583
+ state_dict = model_no_ddp.state_dict()
584
+ for k in list(state_dict.keys()):
585
+ if k in param_grad_dic.keys() and not param_grad_dic[k]:
586
+ # delete parameters that do not require gradient
587
+ del state_dict[k]
588
+ save_obj = {
589
+ "model": state_dict,
590
+ "optimizer": self.optimizer.state_dict(),
591
+ "config": self.config.to_dict(),
592
+ "scaler": self.scaler.state_dict() if self.scaler else None,
593
+ "epoch": cur_epoch,
594
+ }
595
+ save_to = os.path.join(
596
+ self.output_dir,
597
+ "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
598
+ )
599
+ logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
600
+ torch.save(save_obj, save_to)
601
+
602
+ def _reload_best_model(self, model):
603
+ """
604
+ Load the best checkpoint for evaluation.
605
+ """
606
+ checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")
607
+
608
+ logging.info("Loading checkpoint from {}.".format(checkpoint_path))
609
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
610
+ try:
611
+ model.load_state_dict(checkpoint["model"])
612
+ except RuntimeError as e:
613
+ logging.warning(
614
+ """
615
+ Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
616
+ Trying to load the model with strict=False.
617
+ """
618
+ )
619
+ model.load_state_dict(checkpoint["model"], strict=False)
620
+ return model
621
+
622
+ def _load_checkpoint(self, url_or_filename):
623
+ """
624
+ Resume from a checkpoint.
625
+ """
626
+ if is_url(url_or_filename):
627
+ cached_file = download_cached_file(
628
+ url_or_filename, check_hash=False, progress=True
629
+ )
630
+ checkpoint = torch.load(cached_file, map_location=self.device, strict=False)
631
+ elif os.path.isfile(url_or_filename):
632
+ checkpoint = torch.load(url_or_filename, map_location=self.device, strict=False)
633
+ else:
634
+ raise RuntimeError("checkpoint url or path is invalid")
635
+
636
+ state_dict = checkpoint["model"]
637
+ self.unwrap_dist_model(self.model).load_state_dict(state_dict)
638
+
639
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
640
+ if self.scaler and "scaler" in checkpoint:
641
+ self.scaler.load_state_dict(checkpoint["scaler"])
642
+
643
+ self.start_epoch = checkpoint["epoch"] + 1
644
+ logging.info("Resume checkpoint from {}".format(url_or_filename))
645
+
646
+ @main_process
647
+ def log_stats(self, stats, split_name):
648
+ if isinstance(stats, dict):
649
+ log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
650
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
651
+ f.write(json.dumps(log_stats) + "\n")
652
+ elif isinstance(stats, list):
653
+ pass
654
+
655
+ @main_process
656
+ def log_config(self):
657
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
658
+ f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")