Spaces:
Runtime error
Runtime error
gordonhubackup
commited on
Commit
·
e62d81d
1
Parent(s):
2ba1ee8
upload
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE.txt +14 -0
- MANIFEST.in +2 -0
- README.md +1 -1
- app.py +151 -0
- bliva/__init__.py +29 -0
- bliva/common/config.py +469 -0
- bliva/common/dist_utils.py +137 -0
- bliva/common/gradcam.py +24 -0
- bliva/common/logger.py +193 -0
- bliva/common/optims.py +117 -0
- bliva/common/registry.py +268 -0
- bliva/common/utils.py +424 -0
- bliva/configs/default.yaml +7 -0
- bliva/configs/models/bliva_flant5xxl.yaml +39 -0
- bliva/configs/models/bliva_vicuna7b.yaml +39 -0
- bliva/conversation/__init__.py +0 -0
- bliva/conversation/conversation.py +180 -0
- bliva/models/Qformer.py +1216 -0
- bliva/models/__init__.py +208 -0
- bliva/models/base_model.py +251 -0
- bliva/models/blip2.py +319 -0
- bliva/models/bliva_flant5xxl.py +803 -0
- bliva/models/bliva_vicuna7b.py +783 -0
- bliva/models/clip_vit.py +272 -0
- bliva/models/eva_vit.py +442 -0
- bliva/models/modeling_llama.py +888 -0
- bliva/models/modeling_t5.py +2063 -0
- bliva/models/vit.py +527 -0
- bliva/processors/__init__.py +38 -0
- bliva/processors/base_processor.py +26 -0
- bliva/processors/blip_processors.py +239 -0
- bliva/processors/clip_processors.py +92 -0
- bliva/processors/randaugment.py +398 -0
- bliva_vicuna7b.pth +3 -0
- evaluate.py +93 -0
- hf_vicuna_7b/config.json +23 -0
- hf_vicuna_7b/generation_config.json +7 -0
- hf_vicuna_7b/pytorch_model-00001-of-00002.bin +3 -0
- hf_vicuna_7b/pytorch_model-00002-of-00002.bin +3 -0
- hf_vicuna_7b/pytorch_model.bin.index.json +330 -0
- hf_vicuna_7b/special_tokens_map.json +23 -0
- hf_vicuna_7b/tokenizer.model +3 -0
- hf_vicuna_7b/tokenizer_config.json +33 -0
- images/example.jpg +0 -0
- images/img1.jpg +0 -0
- images/img2.jpg +0 -0
- images/img3.jpg +0 -0
- images/img4.jpg +0 -0
- images/img5.jpg +0 -0
- images/img6.jpg +0 -0
LICENSE.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Salesforce, Inc.
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
11 |
+
|
12 |
+
3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
MANIFEST.in
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
|
2 |
+
include requirements.txt
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: BLIVA
|
3 |
-
emoji:
|
4 |
colorFrom: purple
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
title: BLIVA
|
3 |
+
emoji: 🚀
|
4 |
colorFrom: purple
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
app.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.backends.cudnn as cudnn
|
8 |
+
import gradio as gr
|
9 |
+
|
10 |
+
from bliva.common.config import Config
|
11 |
+
from bliva.common.dist_utils import get_rank
|
12 |
+
from bliva.common.registry import registry
|
13 |
+
from bliva.conversation.conversation import Chat, CONV_VISION, CONV_DIRECT
|
14 |
+
|
15 |
+
# imports modules for registration
|
16 |
+
|
17 |
+
from bliva.models import *
|
18 |
+
from bliva.processors import *
|
19 |
+
from bliva.models import load_model_and_preprocess
|
20 |
+
from evaluate import disable_torch_init
|
21 |
+
|
22 |
+
def parse_args():
|
23 |
+
parser = argparse.ArgumentParser(description="Demo")
|
24 |
+
parser.add_argument("--model_name",default='bliva_vicuna', type=str, help='model name')
|
25 |
+
parser.add_argument("--gpu_id", type=int, default=0, help="specify the gpu to load the model.")
|
26 |
+
args = parser.parse_args()
|
27 |
+
return args
|
28 |
+
|
29 |
+
# ========================================
|
30 |
+
# Model Initialization
|
31 |
+
# ========================================
|
32 |
+
|
33 |
+
print('Initializing Chat')
|
34 |
+
args = parse_args()
|
35 |
+
|
36 |
+
if torch.cuda.is_available():
|
37 |
+
device='cuda:{}'.format(args.gpu_id)
|
38 |
+
else:
|
39 |
+
device=torch.device('cpu')
|
40 |
+
|
41 |
+
disable_torch_init()
|
42 |
+
if args.model_name == "blip2_vicuna_instruct":
|
43 |
+
model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="vicuna7b", is_eval=True, device=device)
|
44 |
+
elif args.model_name == "bliva_vicuna":
|
45 |
+
model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="vicuna7b", is_eval=True, device=device)
|
46 |
+
elif args.model_name == "bliva_flant5":
|
47 |
+
model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="flant5xxl", is_eval=True, device=device)
|
48 |
+
else:
|
49 |
+
print("Model not found")
|
50 |
+
|
51 |
+
vis_processor = vis_processors["eval"]
|
52 |
+
|
53 |
+
|
54 |
+
# vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
|
55 |
+
# vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
|
56 |
+
chat = Chat(model, vis_processor, device=device)
|
57 |
+
print('Initialization Finished')
|
58 |
+
|
59 |
+
# ========================================
|
60 |
+
# Gradio Setting
|
61 |
+
# ========================================
|
62 |
+
|
63 |
+
def gradio_reset(chat_state, img_list):
|
64 |
+
if chat_state is not None:
|
65 |
+
chat_state.messages = []
|
66 |
+
if img_list is not None:
|
67 |
+
img_list = []
|
68 |
+
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
|
69 |
+
|
70 |
+
def upload_img(gr_img, text_input, chat_state):
|
71 |
+
if gr_img is None:
|
72 |
+
return None, None, gr.update(interactive=True), chat_state, None
|
73 |
+
chat_state = CONV_DIRECT.copy() #CONV_VISION.copy()
|
74 |
+
img_list = []
|
75 |
+
llm_message = chat.upload_img(gr_img, chat_state, img_list)
|
76 |
+
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
|
77 |
+
|
78 |
+
def gradio_ask(user_message, chatbot, chat_state):
|
79 |
+
if len(user_message) == 0:
|
80 |
+
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
|
81 |
+
chat.ask(user_message, chat_state)
|
82 |
+
chatbot = chatbot + [[user_message, None]]
|
83 |
+
return '', chatbot, chat_state
|
84 |
+
|
85 |
+
|
86 |
+
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
|
87 |
+
llm_message = chat.answer(conv=chat_state,
|
88 |
+
img_list=img_list,
|
89 |
+
num_beams=num_beams,
|
90 |
+
temperature=temperature,
|
91 |
+
max_new_tokens=300,
|
92 |
+
max_length=2000)[0]
|
93 |
+
chatbot[-1][1] = llm_message[0]
|
94 |
+
return chatbot, chat_state, img_list
|
95 |
+
|
96 |
+
title = """<h1 align="center">Demo of BLIVA</h1>"""
|
97 |
+
description = """<h3>This is the demo of BLIVA. Upload your images and start chatting!</h3>"""
|
98 |
+
article = """<p><a href='https://gordonhu608.github.io/bliva/'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/mlpc-ucsd/BLIVA'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
|
99 |
+
"""
|
100 |
+
|
101 |
+
#TODO show examples below
|
102 |
+
|
103 |
+
with gr.Blocks() as demo:
|
104 |
+
gr.Markdown(title)
|
105 |
+
gr.Markdown(description)
|
106 |
+
gr.Markdown(article)
|
107 |
+
|
108 |
+
with gr.Row():
|
109 |
+
with gr.Column(scale=0.5):
|
110 |
+
image = gr.Image(type="pil")
|
111 |
+
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
|
112 |
+
clear = gr.Button("Restart 🔄")
|
113 |
+
|
114 |
+
num_beams = gr.Slider(
|
115 |
+
minimum=1,
|
116 |
+
maximum=10,
|
117 |
+
value=5,
|
118 |
+
step=1,
|
119 |
+
interactive=True,
|
120 |
+
label="beam search numbers)",
|
121 |
+
)
|
122 |
+
|
123 |
+
temperature = gr.Slider(
|
124 |
+
minimum=0.1,
|
125 |
+
maximum=2.0,
|
126 |
+
value=1.0,
|
127 |
+
step=0.1,
|
128 |
+
interactive=True,
|
129 |
+
label="Temperature",
|
130 |
+
)
|
131 |
+
|
132 |
+
with gr.Column():
|
133 |
+
chat_state = gr.State()
|
134 |
+
img_list = gr.State()
|
135 |
+
chatbot = gr.Chatbot(label='BLIVA')
|
136 |
+
text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
|
137 |
+
|
138 |
+
gr.Examples(examples=[
|
139 |
+
[f"images/example.jpg", "Describe this image in detail."],
|
140 |
+
[f"images/img3.jpg", "What is this image about?"],
|
141 |
+
[f"images/img4.jpg", "What is the title of this movie?"],
|
142 |
+
], inputs=[image, text_input])
|
143 |
+
|
144 |
+
upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
|
145 |
+
|
146 |
+
text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
|
147 |
+
gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
|
148 |
+
)
|
149 |
+
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
|
150 |
+
|
151 |
+
demo.launch(enable_queue=True)
|
bliva/__init__.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
|
11 |
+
from omegaconf import OmegaConf
|
12 |
+
|
13 |
+
from bliva.common.registry import registry
|
14 |
+
|
15 |
+
from bliva.models import *
|
16 |
+
from bliva.processors import *
|
17 |
+
|
18 |
+
|
19 |
+
root_dir = os.path.dirname(os.path.abspath(__file__))
|
20 |
+
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
|
21 |
+
|
22 |
+
registry.register_path("library_root", root_dir)
|
23 |
+
repo_root = os.path.join(root_dir, "..")
|
24 |
+
registry.register_path("repo_root", repo_root)
|
25 |
+
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
|
26 |
+
registry.register_path("cache_root", cache_root)
|
27 |
+
|
28 |
+
registry.register("MAX_INT", sys.maxsize)
|
29 |
+
registry.register("SPLIT_NAMES", ["train", "val", "test"])
|
bliva/common/config.py
ADDED
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import json
|
10 |
+
from typing import Dict
|
11 |
+
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from bliva.common.registry import registry
|
14 |
+
|
15 |
+
|
16 |
+
class Config:
|
17 |
+
def __init__(self, args):
|
18 |
+
self.config = {}
|
19 |
+
|
20 |
+
self.args = args
|
21 |
+
|
22 |
+
# Register the config and configuration for setup
|
23 |
+
registry.register("configuration", self)
|
24 |
+
|
25 |
+
user_config = self._build_opt_list(self.args.options)
|
26 |
+
|
27 |
+
config = OmegaConf.load(self.args.cfg_path)
|
28 |
+
|
29 |
+
runner_config = self.build_runner_config(config)
|
30 |
+
model_config = self.build_model_config(config, **user_config)
|
31 |
+
dataset_config = self.build_dataset_config(config)
|
32 |
+
|
33 |
+
# Validate the user-provided runner configuration
|
34 |
+
# model and dataset configuration are supposed to be validated by the respective classes
|
35 |
+
# [TODO] validate the model/dataset configuration
|
36 |
+
# self._validate_runner_config(runner_config)
|
37 |
+
|
38 |
+
# Override the default configuration with user options.
|
39 |
+
self.config = OmegaConf.merge(
|
40 |
+
runner_config, model_config, dataset_config, user_config
|
41 |
+
)
|
42 |
+
|
43 |
+
def _validate_runner_config(self, runner_config):
|
44 |
+
"""
|
45 |
+
This method validates the configuration, such that
|
46 |
+
1) all the user specified options are valid;
|
47 |
+
2) no type mismatches between the user specified options and the config.
|
48 |
+
"""
|
49 |
+
runner_config_validator = create_runner_config_validator()
|
50 |
+
runner_config_validator.validate(runner_config)
|
51 |
+
|
52 |
+
def _build_opt_list(self, opts):
|
53 |
+
opts_dot_list = self._convert_to_dot_list(opts)
|
54 |
+
return OmegaConf.from_dotlist(opts_dot_list)
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def build_model_config(config, **kwargs):
|
58 |
+
model = config.get("model", None)
|
59 |
+
assert model is not None, "Missing model configuration file."
|
60 |
+
|
61 |
+
model_cls = registry.get_model_class(model.arch)
|
62 |
+
assert model_cls is not None, f"Model '{model.arch}' has not been registered."
|
63 |
+
|
64 |
+
model_type = kwargs.get("model.model_type", None)
|
65 |
+
if not model_type:
|
66 |
+
model_type = model.get("model_type", None)
|
67 |
+
# else use the model type selected by user.
|
68 |
+
|
69 |
+
assert model_type is not None, "Missing model_type."
|
70 |
+
|
71 |
+
model_config_path = model_cls.default_config_path(model_type=model_type)
|
72 |
+
|
73 |
+
model_config = OmegaConf.create()
|
74 |
+
# hiararchy override, customized config > default config
|
75 |
+
model_config = OmegaConf.merge(
|
76 |
+
model_config,
|
77 |
+
OmegaConf.load(model_config_path),
|
78 |
+
{"model": config["model"]},
|
79 |
+
)
|
80 |
+
|
81 |
+
return model_config
|
82 |
+
|
83 |
+
@staticmethod
|
84 |
+
def build_runner_config(config):
|
85 |
+
return {"run": config.run}
|
86 |
+
|
87 |
+
@staticmethod
|
88 |
+
def build_dataset_config(config):
|
89 |
+
datasets = config.get("datasets", None)
|
90 |
+
if datasets is None:
|
91 |
+
raise KeyError(
|
92 |
+
"Expecting 'datasets' as the root key for dataset configuration."
|
93 |
+
)
|
94 |
+
|
95 |
+
dataset_config = OmegaConf.create()
|
96 |
+
|
97 |
+
for dataset_name in datasets:
|
98 |
+
print(f"Building dataset config for {dataset_name}")
|
99 |
+
builder_cls = registry.get_builder_class(dataset_name)
|
100 |
+
|
101 |
+
dataset_config_type = datasets[dataset_name].get("type", "default")
|
102 |
+
dataset_config_path = builder_cls.default_config_path(
|
103 |
+
type=dataset_config_type
|
104 |
+
)
|
105 |
+
|
106 |
+
# hiararchy override, customized config > default config
|
107 |
+
dataset_config = OmegaConf.merge(
|
108 |
+
dataset_config,
|
109 |
+
OmegaConf.load(dataset_config_path),
|
110 |
+
{"datasets": {dataset_name: config["datasets"][dataset_name]}},
|
111 |
+
)
|
112 |
+
|
113 |
+
return dataset_config
|
114 |
+
|
115 |
+
def _convert_to_dot_list(self, opts):
|
116 |
+
if opts is None:
|
117 |
+
opts = []
|
118 |
+
|
119 |
+
if len(opts) == 0:
|
120 |
+
return opts
|
121 |
+
|
122 |
+
has_equal = opts[0].find("=") != -1
|
123 |
+
|
124 |
+
if has_equal:
|
125 |
+
return opts
|
126 |
+
|
127 |
+
return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
|
128 |
+
|
129 |
+
def get_config(self):
|
130 |
+
return self.config
|
131 |
+
|
132 |
+
@property
|
133 |
+
def run_cfg(self):
|
134 |
+
return self.config.run
|
135 |
+
|
136 |
+
@property
|
137 |
+
def datasets_cfg(self):
|
138 |
+
return self.config.datasets
|
139 |
+
|
140 |
+
@property
|
141 |
+
def model_cfg(self):
|
142 |
+
return self.config.model
|
143 |
+
|
144 |
+
def pretty_print(self):
|
145 |
+
logging.info("\n===== Running Parameters =====")
|
146 |
+
logging.info(self._convert_node_to_json(self.config.run))
|
147 |
+
|
148 |
+
logging.info("\n====== Dataset Attributes ======")
|
149 |
+
datasets = self.config.datasets
|
150 |
+
|
151 |
+
for dataset in datasets:
|
152 |
+
if dataset in self.config.datasets:
|
153 |
+
logging.info(f"\n======== {dataset} =======")
|
154 |
+
dataset_config = self.config.datasets[dataset]
|
155 |
+
logging.info(self._convert_node_to_json(dataset_config))
|
156 |
+
else:
|
157 |
+
logging.warning(f"No dataset named '{dataset}' in config. Skipping")
|
158 |
+
|
159 |
+
logging.info(f"\n====== Model Attributes ======")
|
160 |
+
logging.info(self._convert_node_to_json(self.config.model))
|
161 |
+
|
162 |
+
def _convert_node_to_json(self, node):
|
163 |
+
container = OmegaConf.to_container(node, resolve=True)
|
164 |
+
return json.dumps(container, indent=4, sort_keys=True)
|
165 |
+
|
166 |
+
def to_dict(self):
|
167 |
+
return OmegaConf.to_container(self.config)
|
168 |
+
|
169 |
+
|
170 |
+
def node_to_dict(node):
|
171 |
+
return OmegaConf.to_container(node)
|
172 |
+
|
173 |
+
|
174 |
+
class ConfigValidator:
|
175 |
+
"""
|
176 |
+
This is a preliminary implementation to centralize and validate the configuration.
|
177 |
+
May be altered in the future.
|
178 |
+
|
179 |
+
A helper class to validate configurations from yaml file.
|
180 |
+
|
181 |
+
This serves the following purposes:
|
182 |
+
1. Ensure all the options in the yaml are defined, raise error if not.
|
183 |
+
2. when type mismatches are found, the validator will raise an error.
|
184 |
+
3. a central place to store and display helpful messages for supported configurations.
|
185 |
+
|
186 |
+
"""
|
187 |
+
|
188 |
+
class _Argument:
|
189 |
+
def __init__(self, name, choices=None, type=None, help=None):
|
190 |
+
self.name = name
|
191 |
+
self.val = None
|
192 |
+
self.choices = choices
|
193 |
+
self.type = type
|
194 |
+
self.help = help
|
195 |
+
|
196 |
+
def __str__(self):
|
197 |
+
s = f"{self.name}={self.val}"
|
198 |
+
if self.type is not None:
|
199 |
+
s += f", ({self.type})"
|
200 |
+
if self.choices is not None:
|
201 |
+
s += f", choices: {self.choices}"
|
202 |
+
if self.help is not None:
|
203 |
+
s += f", ({self.help})"
|
204 |
+
return s
|
205 |
+
|
206 |
+
def __init__(self, description):
|
207 |
+
self.description = description
|
208 |
+
|
209 |
+
self.arguments = dict()
|
210 |
+
|
211 |
+
self.parsed_args = None
|
212 |
+
|
213 |
+
def __getitem__(self, key):
|
214 |
+
assert self.parsed_args is not None, "No arguments parsed yet."
|
215 |
+
|
216 |
+
return self.parsed_args[key]
|
217 |
+
|
218 |
+
def __str__(self) -> str:
|
219 |
+
return self.format_help()
|
220 |
+
|
221 |
+
def add_argument(self, *args, **kwargs):
|
222 |
+
"""
|
223 |
+
Assume the first argument is the name of the argument.
|
224 |
+
"""
|
225 |
+
self.arguments[args[0]] = self._Argument(*args, **kwargs)
|
226 |
+
|
227 |
+
def validate(self, config=None):
|
228 |
+
"""
|
229 |
+
Convert yaml config (dict-like) to list, required by argparse.
|
230 |
+
"""
|
231 |
+
for k, v in config.items():
|
232 |
+
assert (
|
233 |
+
k in self.arguments
|
234 |
+
), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
|
235 |
+
|
236 |
+
if self.arguments[k].type is not None:
|
237 |
+
try:
|
238 |
+
self.arguments[k].val = self.arguments[k].type(v)
|
239 |
+
except ValueError:
|
240 |
+
raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
|
241 |
+
|
242 |
+
if self.arguments[k].choices is not None:
|
243 |
+
assert (
|
244 |
+
v in self.arguments[k].choices
|
245 |
+
), f"""{k} must be one of {self.arguments[k].choices}."""
|
246 |
+
|
247 |
+
return config
|
248 |
+
|
249 |
+
def format_arguments(self):
|
250 |
+
return str([f"{k}" for k in sorted(self.arguments.keys())])
|
251 |
+
|
252 |
+
def format_help(self):
|
253 |
+
# description + key-value pair string for each argument
|
254 |
+
help_msg = str(self.description)
|
255 |
+
return help_msg + ", available arguments: " + self.format_arguments()
|
256 |
+
|
257 |
+
def print_help(self):
|
258 |
+
# display help message
|
259 |
+
print(self.format_help())
|
260 |
+
|
261 |
+
|
262 |
+
def create_runner_config_validator():
|
263 |
+
validator = ConfigValidator(description="Runner configurations")
|
264 |
+
|
265 |
+
validator.add_argument(
|
266 |
+
"runner",
|
267 |
+
type=str,
|
268 |
+
choices=["runner_base", "runner_iter"],
|
269 |
+
help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
|
270 |
+
runner runs based on iters. Default: runner_base""",
|
271 |
+
)
|
272 |
+
# add argumetns for training dataset ratios
|
273 |
+
validator.add_argument(
|
274 |
+
"train_dataset_ratios",
|
275 |
+
type=Dict[str, float],
|
276 |
+
help="""Ratios of training dataset. This is used in iteration-based runner.
|
277 |
+
Do not support for epoch-based runner because how to define an epoch becomes tricky.
|
278 |
+
Default: None""",
|
279 |
+
)
|
280 |
+
validator.add_argument(
|
281 |
+
"max_iters",
|
282 |
+
type=float,
|
283 |
+
help="Maximum number of iterations to run.",
|
284 |
+
)
|
285 |
+
validator.add_argument(
|
286 |
+
"max_epoch",
|
287 |
+
type=int,
|
288 |
+
help="Maximum number of epochs to run.",
|
289 |
+
)
|
290 |
+
# add arguments for iters_per_inner_epoch
|
291 |
+
validator.add_argument(
|
292 |
+
"iters_per_inner_epoch",
|
293 |
+
type=float,
|
294 |
+
help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
|
295 |
+
)
|
296 |
+
lr_scheds_choices = registry.list_lr_schedulers()
|
297 |
+
validator.add_argument(
|
298 |
+
"lr_sched",
|
299 |
+
type=str,
|
300 |
+
choices=lr_scheds_choices,
|
301 |
+
help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
|
302 |
+
)
|
303 |
+
task_choices = registry.list_tasks()
|
304 |
+
validator.add_argument(
|
305 |
+
"task",
|
306 |
+
type=str,
|
307 |
+
choices=task_choices,
|
308 |
+
help="Task to use, from {}".format(task_choices),
|
309 |
+
)
|
310 |
+
# add arguments for init_lr
|
311 |
+
validator.add_argument(
|
312 |
+
"init_lr",
|
313 |
+
type=float,
|
314 |
+
help="Initial learning rate. This will be the learning rate after warmup and before decay.",
|
315 |
+
)
|
316 |
+
# add arguments for min_lr
|
317 |
+
validator.add_argument(
|
318 |
+
"min_lr",
|
319 |
+
type=float,
|
320 |
+
help="Minimum learning rate (after decay).",
|
321 |
+
)
|
322 |
+
# add arguments for warmup_lr
|
323 |
+
validator.add_argument(
|
324 |
+
"warmup_lr",
|
325 |
+
type=float,
|
326 |
+
help="Starting learning rate for warmup.",
|
327 |
+
)
|
328 |
+
# add arguments for learning rate decay rate
|
329 |
+
validator.add_argument(
|
330 |
+
"lr_decay_rate",
|
331 |
+
type=float,
|
332 |
+
help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
|
333 |
+
)
|
334 |
+
# add arguments for weight decay
|
335 |
+
validator.add_argument(
|
336 |
+
"weight_decay",
|
337 |
+
type=float,
|
338 |
+
help="Weight decay rate.",
|
339 |
+
)
|
340 |
+
# add arguments for training batch size
|
341 |
+
validator.add_argument(
|
342 |
+
"batch_size_train",
|
343 |
+
type=int,
|
344 |
+
help="Training batch size.",
|
345 |
+
)
|
346 |
+
# add arguments for evaluation batch size
|
347 |
+
validator.add_argument(
|
348 |
+
"batch_size_eval",
|
349 |
+
type=int,
|
350 |
+
help="Evaluation batch size, including validation and testing.",
|
351 |
+
)
|
352 |
+
# add arguments for number of workers for data loading
|
353 |
+
validator.add_argument(
|
354 |
+
"num_workers",
|
355 |
+
help="Number of workers for data loading.",
|
356 |
+
)
|
357 |
+
# add arguments for warm up steps
|
358 |
+
validator.add_argument(
|
359 |
+
"warmup_steps",
|
360 |
+
type=int,
|
361 |
+
help="Number of warmup steps. Required if a warmup schedule is used.",
|
362 |
+
)
|
363 |
+
# add arguments for random seed
|
364 |
+
validator.add_argument(
|
365 |
+
"seed",
|
366 |
+
type=int,
|
367 |
+
help="Random seed.",
|
368 |
+
)
|
369 |
+
# add arguments for output directory
|
370 |
+
validator.add_argument(
|
371 |
+
"output_dir",
|
372 |
+
type=str,
|
373 |
+
help="Output directory to save checkpoints and logs.",
|
374 |
+
)
|
375 |
+
# add arguments for whether only use evaluation
|
376 |
+
validator.add_argument(
|
377 |
+
"evaluate",
|
378 |
+
help="Whether to only evaluate the model. If true, training will not be performed.",
|
379 |
+
)
|
380 |
+
# add arguments for splits used for training, e.g. ["train", "val"]
|
381 |
+
validator.add_argument(
|
382 |
+
"train_splits",
|
383 |
+
type=list,
|
384 |
+
help="Splits to use for training.",
|
385 |
+
)
|
386 |
+
# add arguments for splits used for validation, e.g. ["val"]
|
387 |
+
validator.add_argument(
|
388 |
+
"valid_splits",
|
389 |
+
type=list,
|
390 |
+
help="Splits to use for validation. If not provided, will skip the validation.",
|
391 |
+
)
|
392 |
+
# add arguments for splits used for testing, e.g. ["test"]
|
393 |
+
validator.add_argument(
|
394 |
+
"test_splits",
|
395 |
+
type=list,
|
396 |
+
help="Splits to use for testing. If not provided, will skip the testing.",
|
397 |
+
)
|
398 |
+
# add arguments for accumulating gradient for iterations
|
399 |
+
validator.add_argument(
|
400 |
+
"accum_grad_iters",
|
401 |
+
type=int,
|
402 |
+
help="Number of iterations to accumulate gradient for.",
|
403 |
+
)
|
404 |
+
|
405 |
+
# ====== distributed training ======
|
406 |
+
validator.add_argument(
|
407 |
+
"device",
|
408 |
+
type=str,
|
409 |
+
choices=["cpu", "cuda"],
|
410 |
+
help="Device to use. Support 'cuda' or 'cpu' as for now.",
|
411 |
+
)
|
412 |
+
validator.add_argument(
|
413 |
+
"world_size",
|
414 |
+
type=int,
|
415 |
+
help="Number of processes participating in the job.",
|
416 |
+
)
|
417 |
+
validator.add_argument("dist_url", type=str)
|
418 |
+
validator.add_argument("distributed", type=bool)
|
419 |
+
# add arguments to opt using distributed sampler during evaluation or not
|
420 |
+
validator.add_argument(
|
421 |
+
"use_dist_eval_sampler",
|
422 |
+
type=bool,
|
423 |
+
help="Whether to use distributed sampler during evaluation or not.",
|
424 |
+
)
|
425 |
+
|
426 |
+
# ====== task specific ======
|
427 |
+
# generation task specific arguments
|
428 |
+
# add arguments for maximal length of text output
|
429 |
+
validator.add_argument(
|
430 |
+
"max_len",
|
431 |
+
type=int,
|
432 |
+
help="Maximal length of text output.",
|
433 |
+
)
|
434 |
+
# add arguments for minimal length of text output
|
435 |
+
validator.add_argument(
|
436 |
+
"min_len",
|
437 |
+
type=int,
|
438 |
+
help="Minimal length of text output.",
|
439 |
+
)
|
440 |
+
# add arguments number of beams
|
441 |
+
validator.add_argument(
|
442 |
+
"num_beams",
|
443 |
+
type=int,
|
444 |
+
help="Number of beams used for beam search.",
|
445 |
+
)
|
446 |
+
|
447 |
+
# vqa task specific arguments
|
448 |
+
# add arguments for number of answer candidates
|
449 |
+
validator.add_argument(
|
450 |
+
"num_ans_candidates",
|
451 |
+
type=int,
|
452 |
+
help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
|
453 |
+
)
|
454 |
+
# add arguments for inference method
|
455 |
+
validator.add_argument(
|
456 |
+
"inference_method",
|
457 |
+
type=str,
|
458 |
+
choices=["genearte", "rank"],
|
459 |
+
help="""Inference method to use for question answering. If rank, requires a answer list.""",
|
460 |
+
)
|
461 |
+
|
462 |
+
# ====== model specific ======
|
463 |
+
validator.add_argument(
|
464 |
+
"k_test",
|
465 |
+
type=int,
|
466 |
+
help="Number of top k most similar samples from ITC/VTC selection to be tested.",
|
467 |
+
)
|
468 |
+
|
469 |
+
return validator
|
bliva/common/dist_utils.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import datetime
|
9 |
+
import functools
|
10 |
+
import os
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
import timm.models.hub as timm_hub
|
15 |
+
|
16 |
+
|
17 |
+
def setup_for_distributed(is_master):
|
18 |
+
"""
|
19 |
+
This function disables printing when not in master process
|
20 |
+
"""
|
21 |
+
import builtins as __builtin__
|
22 |
+
|
23 |
+
builtin_print = __builtin__.print
|
24 |
+
|
25 |
+
def print(*args, **kwargs):
|
26 |
+
force = kwargs.pop("force", False)
|
27 |
+
if is_master or force:
|
28 |
+
builtin_print(*args, **kwargs)
|
29 |
+
|
30 |
+
__builtin__.print = print
|
31 |
+
|
32 |
+
|
33 |
+
def is_dist_avail_and_initialized():
|
34 |
+
if not dist.is_available():
|
35 |
+
return False
|
36 |
+
if not dist.is_initialized():
|
37 |
+
return False
|
38 |
+
return True
|
39 |
+
|
40 |
+
|
41 |
+
def get_world_size():
|
42 |
+
if not is_dist_avail_and_initialized():
|
43 |
+
return 1
|
44 |
+
return dist.get_world_size()
|
45 |
+
|
46 |
+
|
47 |
+
def get_rank():
|
48 |
+
if not is_dist_avail_and_initialized():
|
49 |
+
return 0
|
50 |
+
return dist.get_rank()
|
51 |
+
|
52 |
+
|
53 |
+
def is_main_process():
|
54 |
+
return get_rank() == 0
|
55 |
+
|
56 |
+
|
57 |
+
def init_distributed_mode(args):
|
58 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
59 |
+
args.rank = int(os.environ["RANK"])
|
60 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
61 |
+
args.gpu = int(os.environ["LOCAL_RANK"])
|
62 |
+
elif "SLURM_PROCID" in os.environ:
|
63 |
+
args.rank = int(os.environ["SLURM_PROCID"])
|
64 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
65 |
+
else:
|
66 |
+
print("Not using distributed mode")
|
67 |
+
args.distributed = False
|
68 |
+
return
|
69 |
+
|
70 |
+
args.distributed = True
|
71 |
+
|
72 |
+
torch.cuda.set_device(args.gpu)
|
73 |
+
args.dist_backend = "nccl"
|
74 |
+
print(
|
75 |
+
"| distributed init (rank {}, world {}): {}".format(
|
76 |
+
args.rank, args.world_size, args.dist_url
|
77 |
+
),
|
78 |
+
flush=True,
|
79 |
+
)
|
80 |
+
torch.distributed.init_process_group(
|
81 |
+
backend=args.dist_backend,
|
82 |
+
init_method=args.dist_url,
|
83 |
+
world_size=args.world_size,
|
84 |
+
rank=args.rank,
|
85 |
+
timeout=datetime.timedelta(
|
86 |
+
days=365
|
87 |
+
), # allow auto-downloading and de-compressing
|
88 |
+
)
|
89 |
+
torch.distributed.barrier()
|
90 |
+
setup_for_distributed(args.rank == 0)
|
91 |
+
|
92 |
+
|
93 |
+
def get_dist_info():
|
94 |
+
if torch.__version__ < "1.0":
|
95 |
+
initialized = dist._initialized
|
96 |
+
else:
|
97 |
+
initialized = dist.is_initialized()
|
98 |
+
if initialized:
|
99 |
+
rank = dist.get_rank()
|
100 |
+
world_size = dist.get_world_size()
|
101 |
+
else: # non-distributed training
|
102 |
+
rank = 0
|
103 |
+
world_size = 1
|
104 |
+
return rank, world_size
|
105 |
+
|
106 |
+
|
107 |
+
def main_process(func):
|
108 |
+
@functools.wraps(func)
|
109 |
+
def wrapper(*args, **kwargs):
|
110 |
+
rank, _ = get_dist_info()
|
111 |
+
if rank == 0:
|
112 |
+
return func(*args, **kwargs)
|
113 |
+
|
114 |
+
return wrapper
|
115 |
+
|
116 |
+
|
117 |
+
def download_cached_file(url, check_hash=True, progress=False):
|
118 |
+
"""
|
119 |
+
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
|
120 |
+
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
|
121 |
+
"""
|
122 |
+
|
123 |
+
def get_cached_file_path():
|
124 |
+
# a hack to sync the file path across processes
|
125 |
+
parts = torch.hub.urlparse(url)
|
126 |
+
filename = os.path.basename(parts.path)
|
127 |
+
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
|
128 |
+
|
129 |
+
return cached_file
|
130 |
+
|
131 |
+
if is_main_process():
|
132 |
+
timm_hub.download_cached_file(url, check_hash, progress)
|
133 |
+
|
134 |
+
if is_dist_avail_and_initialized():
|
135 |
+
dist.barrier()
|
136 |
+
|
137 |
+
return get_cached_file_path()
|
bliva/common/gradcam.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from matplotlib import pyplot as plt
|
3 |
+
from scipy.ndimage import filters
|
4 |
+
from skimage import transform as skimage_transform
|
5 |
+
|
6 |
+
|
7 |
+
def getAttMap(img, attMap, blur=True, overlap=True):
|
8 |
+
attMap -= attMap.min()
|
9 |
+
if attMap.max() > 0:
|
10 |
+
attMap /= attMap.max()
|
11 |
+
attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
|
12 |
+
if blur:
|
13 |
+
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
|
14 |
+
attMap -= attMap.min()
|
15 |
+
attMap /= attMap.max()
|
16 |
+
cmap = plt.get_cmap("jet")
|
17 |
+
attMapV = cmap(attMap)
|
18 |
+
attMapV = np.delete(attMapV, 3, 2)
|
19 |
+
if overlap:
|
20 |
+
attMap = (
|
21 |
+
1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
|
22 |
+
+ (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
|
23 |
+
)
|
24 |
+
return attMap
|
bliva/common/logger.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import datetime
|
9 |
+
import logging
|
10 |
+
import time
|
11 |
+
from collections import defaultdict, deque
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.distributed as dist
|
15 |
+
|
16 |
+
from bliva.common import dist_utils
|
17 |
+
|
18 |
+
class SmoothedValue(object):
|
19 |
+
"""Track a series of values and provide access to smoothed values over a
|
20 |
+
window or the global series average.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, window_size=20, fmt=None):
|
24 |
+
if fmt is None:
|
25 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
26 |
+
self.deque = deque(maxlen=window_size)
|
27 |
+
self.total = 0.0
|
28 |
+
self.count = 0
|
29 |
+
self.fmt = fmt
|
30 |
+
|
31 |
+
def update(self, value, n=1):
|
32 |
+
self.deque.append(value)
|
33 |
+
self.count += n
|
34 |
+
self.total += value * n
|
35 |
+
|
36 |
+
def synchronize_between_processes(self):
|
37 |
+
"""
|
38 |
+
Warning: does not synchronize the deque!
|
39 |
+
"""
|
40 |
+
if not dist_utils.is_dist_avail_and_initialized():
|
41 |
+
return
|
42 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
43 |
+
dist.barrier()
|
44 |
+
dist.all_reduce(t)
|
45 |
+
t = t.tolist()
|
46 |
+
self.count = int(t[0])
|
47 |
+
self.total = t[1]
|
48 |
+
|
49 |
+
@property
|
50 |
+
def median(self):
|
51 |
+
d = torch.tensor(list(self.deque))
|
52 |
+
return d.median().item()
|
53 |
+
|
54 |
+
@property
|
55 |
+
def avg(self):
|
56 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
57 |
+
return d.mean().item()
|
58 |
+
|
59 |
+
@property
|
60 |
+
def global_avg(self):
|
61 |
+
return self.total / self.count
|
62 |
+
|
63 |
+
@property
|
64 |
+
def max(self):
|
65 |
+
return max(self.deque)
|
66 |
+
|
67 |
+
@property
|
68 |
+
def value(self):
|
69 |
+
return self.deque[-1]
|
70 |
+
|
71 |
+
def __str__(self):
|
72 |
+
return self.fmt.format(
|
73 |
+
median=self.median,
|
74 |
+
avg=self.avg,
|
75 |
+
global_avg=self.global_avg,
|
76 |
+
max=self.max,
|
77 |
+
value=self.value,
|
78 |
+
)
|
79 |
+
|
80 |
+
|
81 |
+
class MetricLogger(object):
|
82 |
+
def __init__(self, delimiter="\t"):
|
83 |
+
self.meters = defaultdict(SmoothedValue)
|
84 |
+
self.delimiter = delimiter
|
85 |
+
|
86 |
+
def update(self, **kwargs):
|
87 |
+
for k, v in kwargs.items():
|
88 |
+
if isinstance(v, torch.Tensor):
|
89 |
+
v = v.item()
|
90 |
+
assert isinstance(v, (float, int))
|
91 |
+
self.meters[k].update(v)
|
92 |
+
|
93 |
+
def __getattr__(self, attr):
|
94 |
+
if attr in self.meters:
|
95 |
+
return self.meters[attr]
|
96 |
+
if attr in self.__dict__:
|
97 |
+
return self.__dict__[attr]
|
98 |
+
raise AttributeError(
|
99 |
+
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
|
100 |
+
)
|
101 |
+
|
102 |
+
def __str__(self):
|
103 |
+
loss_str = []
|
104 |
+
for name, meter in self.meters.items():
|
105 |
+
loss_str.append("{}: {}".format(name, str(meter)))
|
106 |
+
return self.delimiter.join(loss_str)
|
107 |
+
|
108 |
+
def global_avg(self):
|
109 |
+
loss_str = []
|
110 |
+
for name, meter in self.meters.items():
|
111 |
+
loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
|
112 |
+
return self.delimiter.join(loss_str)
|
113 |
+
|
114 |
+
def synchronize_between_processes(self):
|
115 |
+
for meter in self.meters.values():
|
116 |
+
meter.synchronize_between_processes()
|
117 |
+
|
118 |
+
def add_meter(self, name, meter):
|
119 |
+
self.meters[name] = meter
|
120 |
+
|
121 |
+
def log_every(self, iterable, print_freq, header=None):
|
122 |
+
i = 0
|
123 |
+
if not header:
|
124 |
+
header = ""
|
125 |
+
start_time = time.time()
|
126 |
+
end = time.time()
|
127 |
+
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
128 |
+
data_time = SmoothedValue(fmt="{avg:.4f}")
|
129 |
+
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
130 |
+
log_msg = [
|
131 |
+
header,
|
132 |
+
"[{0" + space_fmt + "}/{1}]",
|
133 |
+
"eta: {eta}",
|
134 |
+
"{meters}",
|
135 |
+
"time: {time}",
|
136 |
+
"data: {data}",
|
137 |
+
]
|
138 |
+
if torch.cuda.is_available():
|
139 |
+
log_msg.append("max mem: {memory:.0f}")
|
140 |
+
log_msg = self.delimiter.join(log_msg)
|
141 |
+
MB = 1024.0 * 1024.0
|
142 |
+
for obj in iterable:
|
143 |
+
data_time.update(time.time() - end)
|
144 |
+
yield obj
|
145 |
+
iter_time.update(time.time() - end)
|
146 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
147 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
148 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
149 |
+
if torch.cuda.is_available():
|
150 |
+
print(
|
151 |
+
log_msg.format(
|
152 |
+
i,
|
153 |
+
len(iterable),
|
154 |
+
eta=eta_string,
|
155 |
+
meters=str(self),
|
156 |
+
time=str(iter_time),
|
157 |
+
data=str(data_time),
|
158 |
+
memory=torch.cuda.max_memory_allocated() / MB,
|
159 |
+
)
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
print(
|
163 |
+
log_msg.format(
|
164 |
+
i,
|
165 |
+
len(iterable),
|
166 |
+
eta=eta_string,
|
167 |
+
meters=str(self),
|
168 |
+
time=str(iter_time),
|
169 |
+
data=str(data_time),
|
170 |
+
)
|
171 |
+
)
|
172 |
+
i += 1
|
173 |
+
end = time.time()
|
174 |
+
total_time = time.time() - start_time
|
175 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
176 |
+
print(
|
177 |
+
"{} Total time: {} ({:.4f} s / it)".format(
|
178 |
+
header, total_time_str, total_time / len(iterable)
|
179 |
+
)
|
180 |
+
)
|
181 |
+
|
182 |
+
|
183 |
+
class AttrDict(dict):
|
184 |
+
def __init__(self, *args, **kwargs):
|
185 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
186 |
+
self.__dict__ = self
|
187 |
+
|
188 |
+
def setup_logger():
|
189 |
+
logging.basicConfig(
|
190 |
+
level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
|
191 |
+
format="%(asctime)s [%(levelname)s] %(message)s",
|
192 |
+
handlers=[logging.StreamHandler()],
|
193 |
+
)
|
bliva/common/optims.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import math
|
9 |
+
|
10 |
+
from bliva.common.registry import registry
|
11 |
+
|
12 |
+
|
13 |
+
@registry.register_lr_scheduler("linear_warmup_step_lr")
|
14 |
+
class LinearWarmupStepLRScheduler:
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
optimizer,
|
18 |
+
max_epoch,
|
19 |
+
min_lr,
|
20 |
+
init_lr,
|
21 |
+
decay_rate=1,
|
22 |
+
warmup_start_lr=-1,
|
23 |
+
warmup_steps=0,
|
24 |
+
**kwargs
|
25 |
+
):
|
26 |
+
self.optimizer = optimizer
|
27 |
+
|
28 |
+
self.max_epoch = max_epoch
|
29 |
+
self.min_lr = min_lr
|
30 |
+
|
31 |
+
self.decay_rate = decay_rate
|
32 |
+
|
33 |
+
self.init_lr = init_lr
|
34 |
+
self.warmup_steps = warmup_steps
|
35 |
+
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
36 |
+
|
37 |
+
def step(self, cur_epoch, cur_step):
|
38 |
+
if cur_epoch == 0:
|
39 |
+
warmup_lr_schedule(
|
40 |
+
step=cur_step,
|
41 |
+
optimizer=self.optimizer,
|
42 |
+
max_step=self.warmup_steps,
|
43 |
+
init_lr=self.warmup_start_lr,
|
44 |
+
max_lr=self.init_lr,
|
45 |
+
)
|
46 |
+
else:
|
47 |
+
step_lr_schedule(
|
48 |
+
epoch=cur_epoch,
|
49 |
+
optimizer=self.optimizer,
|
50 |
+
init_lr=self.init_lr,
|
51 |
+
min_lr=self.min_lr,
|
52 |
+
decay_rate=self.decay_rate,
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
|
57 |
+
class LinearWarmupCosineLRScheduler:
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
optimizer,
|
61 |
+
max_epoch,
|
62 |
+
min_lr,
|
63 |
+
init_lr,
|
64 |
+
warmup_steps=0,
|
65 |
+
warmup_start_lr=-1,
|
66 |
+
**kwargs
|
67 |
+
):
|
68 |
+
self.optimizer = optimizer
|
69 |
+
|
70 |
+
self.max_epoch = max_epoch
|
71 |
+
self.min_lr = min_lr
|
72 |
+
|
73 |
+
self.init_lr = init_lr
|
74 |
+
self.warmup_steps = warmup_steps
|
75 |
+
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
76 |
+
|
77 |
+
def step(self, cur_epoch, cur_step):
|
78 |
+
# assuming the warmup iters less than one epoch
|
79 |
+
if cur_epoch == 0:
|
80 |
+
warmup_lr_schedule(
|
81 |
+
step=cur_step,
|
82 |
+
optimizer=self.optimizer,
|
83 |
+
max_step=self.warmup_steps,
|
84 |
+
init_lr=self.warmup_start_lr,
|
85 |
+
max_lr=self.init_lr,
|
86 |
+
)
|
87 |
+
else:
|
88 |
+
cosine_lr_schedule(
|
89 |
+
epoch=cur_epoch,
|
90 |
+
optimizer=self.optimizer,
|
91 |
+
max_epoch=self.max_epoch,
|
92 |
+
init_lr=self.init_lr,
|
93 |
+
min_lr=self.min_lr,
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
|
98 |
+
"""Decay the learning rate"""
|
99 |
+
lr = (init_lr - min_lr) * 0.5 * (
|
100 |
+
1.0 + math.cos(math.pi * epoch / max_epoch)
|
101 |
+
) + min_lr
|
102 |
+
for param_group in optimizer.param_groups:
|
103 |
+
param_group["lr"] = lr
|
104 |
+
|
105 |
+
|
106 |
+
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
|
107 |
+
"""Warmup the learning rate"""
|
108 |
+
lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
|
109 |
+
for param_group in optimizer.param_groups:
|
110 |
+
param_group["lr"] = lr
|
111 |
+
|
112 |
+
|
113 |
+
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
|
114 |
+
"""Decay the learning rate"""
|
115 |
+
lr = max(min_lr, init_lr * (decay_rate**epoch))
|
116 |
+
for param_group in optimizer.param_groups:
|
117 |
+
param_group["lr"] = lr
|
bliva/common/registry.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
|
9 |
+
class Registry:
|
10 |
+
mapping = {
|
11 |
+
"builder_name_mapping": {},
|
12 |
+
"task_name_mapping": {},
|
13 |
+
"processor_name_mapping": {},
|
14 |
+
"model_name_mapping": {},
|
15 |
+
"lr_scheduler_name_mapping": {},
|
16 |
+
"runner_name_mapping": {},
|
17 |
+
"state": {},
|
18 |
+
"paths": {},
|
19 |
+
}
|
20 |
+
|
21 |
+
@classmethod
|
22 |
+
def register_model(cls, name):
|
23 |
+
r"""Register a task to registry with key 'name'
|
24 |
+
|
25 |
+
Args:
|
26 |
+
name: Key with which the task will be registered.
|
27 |
+
|
28 |
+
Usage:
|
29 |
+
|
30 |
+
from bliva.common.registry import registry
|
31 |
+
"""
|
32 |
+
|
33 |
+
def wrap(model_cls):
|
34 |
+
from bliva.models import BaseModel
|
35 |
+
|
36 |
+
assert issubclass(
|
37 |
+
model_cls, BaseModel
|
38 |
+
), "All models must inherit BaseModel class"
|
39 |
+
if name in cls.mapping["model_name_mapping"]:
|
40 |
+
raise KeyError(
|
41 |
+
"Name '{}' already registered for {}.".format(
|
42 |
+
name, cls.mapping["model_name_mapping"][name]
|
43 |
+
)
|
44 |
+
)
|
45 |
+
cls.mapping["model_name_mapping"][name] = model_cls
|
46 |
+
return model_cls
|
47 |
+
|
48 |
+
return wrap
|
49 |
+
|
50 |
+
@classmethod
|
51 |
+
def register_processor(cls, name):
|
52 |
+
r"""Register a processor to registry with key 'name'
|
53 |
+
|
54 |
+
Args:
|
55 |
+
name: Key with which the task will be registered.
|
56 |
+
|
57 |
+
Usage:
|
58 |
+
|
59 |
+
from bliva.common.registry import registry
|
60 |
+
"""
|
61 |
+
|
62 |
+
def wrap(processor_cls):
|
63 |
+
from bliva.processors import BaseProcessor
|
64 |
+
|
65 |
+
assert issubclass(
|
66 |
+
processor_cls, BaseProcessor
|
67 |
+
), "All processors must inherit BaseProcessor class"
|
68 |
+
if name in cls.mapping["processor_name_mapping"]:
|
69 |
+
raise KeyError(
|
70 |
+
"Name '{}' already registered for {}.".format(
|
71 |
+
name, cls.mapping["processor_name_mapping"][name]
|
72 |
+
)
|
73 |
+
)
|
74 |
+
cls.mapping["processor_name_mapping"][name] = processor_cls
|
75 |
+
return processor_cls
|
76 |
+
|
77 |
+
return wrap
|
78 |
+
|
79 |
+
@classmethod
|
80 |
+
def register_lr_scheduler(cls, name):
|
81 |
+
r"""Register a model to registry with key 'name'
|
82 |
+
|
83 |
+
Args:
|
84 |
+
name: Key with which the task will be registered.
|
85 |
+
|
86 |
+
Usage:
|
87 |
+
|
88 |
+
from bliva.common.registry import registry
|
89 |
+
"""
|
90 |
+
|
91 |
+
def wrap(lr_sched_cls):
|
92 |
+
if name in cls.mapping["lr_scheduler_name_mapping"]:
|
93 |
+
raise KeyError(
|
94 |
+
"Name '{}' already registered for {}.".format(
|
95 |
+
name, cls.mapping["lr_scheduler_name_mapping"][name]
|
96 |
+
)
|
97 |
+
)
|
98 |
+
cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
|
99 |
+
return lr_sched_cls
|
100 |
+
|
101 |
+
return wrap
|
102 |
+
|
103 |
+
@classmethod
|
104 |
+
def register_runner(cls, name):
|
105 |
+
r"""Register a model to registry with key 'name'
|
106 |
+
|
107 |
+
Args:
|
108 |
+
name: Key with which the task will be registered.
|
109 |
+
|
110 |
+
Usage:
|
111 |
+
|
112 |
+
from bliva.common.registry import registry
|
113 |
+
"""
|
114 |
+
|
115 |
+
def wrap(runner_cls):
|
116 |
+
if name in cls.mapping["runner_name_mapping"]:
|
117 |
+
raise KeyError(
|
118 |
+
"Name '{}' already registered for {}.".format(
|
119 |
+
name, cls.mapping["runner_name_mapping"][name]
|
120 |
+
)
|
121 |
+
)
|
122 |
+
cls.mapping["runner_name_mapping"][name] = runner_cls
|
123 |
+
return runner_cls
|
124 |
+
|
125 |
+
return wrap
|
126 |
+
|
127 |
+
@classmethod
|
128 |
+
def register_path(cls, name, path):
|
129 |
+
r"""Register a path to registry with key 'name'
|
130 |
+
|
131 |
+
Args:
|
132 |
+
name: Key with which the path will be registered.
|
133 |
+
|
134 |
+
Usage:
|
135 |
+
|
136 |
+
from bliva.common.registry import registry
|
137 |
+
"""
|
138 |
+
assert isinstance(path, str), "All path must be str."
|
139 |
+
if name in cls.mapping["paths"]:
|
140 |
+
raise KeyError("Name '{}' already registered.".format(name))
|
141 |
+
cls.mapping["paths"][name] = path
|
142 |
+
|
143 |
+
@classmethod
|
144 |
+
def register(cls, name, obj):
|
145 |
+
r"""Register an item to registry with key 'name'
|
146 |
+
|
147 |
+
Args:
|
148 |
+
name: Key with which the item will be registered.
|
149 |
+
|
150 |
+
Usage::
|
151 |
+
|
152 |
+
from bliva.common.registry import registry
|
153 |
+
|
154 |
+
registry.register("config", {})
|
155 |
+
"""
|
156 |
+
path = name.split(".")
|
157 |
+
current = cls.mapping["state"]
|
158 |
+
|
159 |
+
for part in path[:-1]:
|
160 |
+
if part not in current:
|
161 |
+
current[part] = {}
|
162 |
+
current = current[part]
|
163 |
+
|
164 |
+
current[path[-1]] = obj
|
165 |
+
|
166 |
+
# @classmethod
|
167 |
+
# def get_trainer_class(cls, name):
|
168 |
+
# return cls.mapping["trainer_name_mapping"].get(name, None)
|
169 |
+
|
170 |
+
@classmethod
|
171 |
+
def get_builder_class(cls, name):
|
172 |
+
return cls.mapping["builder_name_mapping"].get(name, None)
|
173 |
+
|
174 |
+
@classmethod
|
175 |
+
def get_model_class(cls, name):
|
176 |
+
return cls.mapping["model_name_mapping"].get(name, None)
|
177 |
+
|
178 |
+
@classmethod
|
179 |
+
def get_task_class(cls, name):
|
180 |
+
return cls.mapping["task_name_mapping"].get(name, None)
|
181 |
+
|
182 |
+
@classmethod
|
183 |
+
def get_processor_class(cls, name):
|
184 |
+
return cls.mapping["processor_name_mapping"].get(name, None)
|
185 |
+
|
186 |
+
@classmethod
|
187 |
+
def get_lr_scheduler_class(cls, name):
|
188 |
+
return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
|
189 |
+
|
190 |
+
@classmethod
|
191 |
+
def get_runner_class(cls, name):
|
192 |
+
return cls.mapping["runner_name_mapping"].get(name, None)
|
193 |
+
|
194 |
+
@classmethod
|
195 |
+
def list_runners(cls):
|
196 |
+
return sorted(cls.mapping["runner_name_mapping"].keys())
|
197 |
+
|
198 |
+
@classmethod
|
199 |
+
def list_models(cls):
|
200 |
+
return sorted(cls.mapping["model_name_mapping"].keys())
|
201 |
+
|
202 |
+
@classmethod
|
203 |
+
def list_tasks(cls):
|
204 |
+
return sorted(cls.mapping["task_name_mapping"].keys())
|
205 |
+
|
206 |
+
@classmethod
|
207 |
+
def list_processors(cls):
|
208 |
+
return sorted(cls.mapping["processor_name_mapping"].keys())
|
209 |
+
|
210 |
+
@classmethod
|
211 |
+
def list_lr_schedulers(cls):
|
212 |
+
return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
|
213 |
+
|
214 |
+
@classmethod
|
215 |
+
def list_datasets(cls):
|
216 |
+
return sorted(cls.mapping["builder_name_mapping"].keys())
|
217 |
+
|
218 |
+
@classmethod
|
219 |
+
def get_path(cls, name):
|
220 |
+
return cls.mapping["paths"].get(name, None)
|
221 |
+
|
222 |
+
@classmethod
|
223 |
+
def get(cls, name, default=None, no_warning=False):
|
224 |
+
r"""Get an item from registry with key 'name'
|
225 |
+
|
226 |
+
Args:
|
227 |
+
name (string): Key whose value needs to be retrieved.
|
228 |
+
default: If passed and key is not in registry, default value will
|
229 |
+
be returned with a warning. Default: None
|
230 |
+
no_warning (bool): If passed as True, warning when key doesn't exist
|
231 |
+
will not be generated. Useful for MMF's
|
232 |
+
internal operations. Default: False
|
233 |
+
"""
|
234 |
+
original_name = name
|
235 |
+
name = name.split(".")
|
236 |
+
value = cls.mapping["state"]
|
237 |
+
for subname in name:
|
238 |
+
value = value.get(subname, default)
|
239 |
+
if value is default:
|
240 |
+
break
|
241 |
+
|
242 |
+
if (
|
243 |
+
"writer" in cls.mapping["state"]
|
244 |
+
and value == default
|
245 |
+
and no_warning is False
|
246 |
+
):
|
247 |
+
cls.mapping["state"]["writer"].warning(
|
248 |
+
"Key {} is not present in registry, returning default value "
|
249 |
+
"of {}".format(original_name, default)
|
250 |
+
)
|
251 |
+
return value
|
252 |
+
|
253 |
+
@classmethod
|
254 |
+
def unregister(cls, name):
|
255 |
+
r"""Remove an item from registry with key 'name'
|
256 |
+
|
257 |
+
Args:
|
258 |
+
name: Key which needs to be removed.
|
259 |
+
Usage::
|
260 |
+
|
261 |
+
from mmf.common.registry import registry
|
262 |
+
|
263 |
+
config = registry.unregister("config")
|
264 |
+
"""
|
265 |
+
return cls.mapping["state"].pop(name, None)
|
266 |
+
|
267 |
+
|
268 |
+
registry = Registry()
|
bliva/common/utils.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import io
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import pickle
|
13 |
+
import re
|
14 |
+
import shutil
|
15 |
+
import urllib
|
16 |
+
import urllib.error
|
17 |
+
import urllib.request
|
18 |
+
from typing import Optional
|
19 |
+
from urllib.parse import urlparse
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import pandas as pd
|
23 |
+
import yaml
|
24 |
+
from iopath.common.download import download
|
25 |
+
from iopath.common.file_io import file_lock, g_pathmgr
|
26 |
+
from bliva.common.registry import registry
|
27 |
+
from torch.utils.model_zoo import tqdm
|
28 |
+
from torchvision.datasets.utils import (
|
29 |
+
check_integrity,
|
30 |
+
download_file_from_google_drive,
|
31 |
+
extract_archive,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def now():
|
36 |
+
from datetime import datetime
|
37 |
+
|
38 |
+
return datetime.now().strftime("%Y%m%d%H%M")[:-1]
|
39 |
+
|
40 |
+
|
41 |
+
def is_url(url_or_filename):
|
42 |
+
parsed = urlparse(url_or_filename)
|
43 |
+
return parsed.scheme in ("http", "https")
|
44 |
+
|
45 |
+
|
46 |
+
def get_cache_path(rel_path):
|
47 |
+
return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
|
48 |
+
|
49 |
+
|
50 |
+
def get_abs_path(rel_path):
|
51 |
+
return os.path.join(registry.get_path("library_root"), rel_path)
|
52 |
+
|
53 |
+
|
54 |
+
def load_json(filename):
|
55 |
+
with open(filename, "r") as f:
|
56 |
+
return json.load(f)
|
57 |
+
|
58 |
+
|
59 |
+
# The following are adapted from torchvision and vissl
|
60 |
+
# torchvision: https://github.com/pytorch/vision
|
61 |
+
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
|
62 |
+
|
63 |
+
|
64 |
+
def makedir(dir_path):
|
65 |
+
"""
|
66 |
+
Create the directory if it does not exist.
|
67 |
+
"""
|
68 |
+
is_success = False
|
69 |
+
try:
|
70 |
+
if not g_pathmgr.exists(dir_path):
|
71 |
+
g_pathmgr.mkdirs(dir_path)
|
72 |
+
is_success = True
|
73 |
+
except BaseException:
|
74 |
+
print(f"Error creating directory: {dir_path}")
|
75 |
+
return is_success
|
76 |
+
|
77 |
+
|
78 |
+
def get_redirected_url(url: str):
|
79 |
+
"""
|
80 |
+
Given a URL, returns the URL it redirects to or the
|
81 |
+
original URL in case of no indirection
|
82 |
+
"""
|
83 |
+
import requests
|
84 |
+
|
85 |
+
with requests.Session() as session:
|
86 |
+
with session.get(url, stream=True, allow_redirects=True) as response:
|
87 |
+
if response.history:
|
88 |
+
return response.url
|
89 |
+
else:
|
90 |
+
return url
|
91 |
+
|
92 |
+
|
93 |
+
def to_google_drive_download_url(view_url: str) -> str:
|
94 |
+
"""
|
95 |
+
Utility function to transform a view URL of google drive
|
96 |
+
to a download URL for google drive
|
97 |
+
Example input:
|
98 |
+
https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
|
99 |
+
Example output:
|
100 |
+
https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
|
101 |
+
"""
|
102 |
+
splits = view_url.split("/")
|
103 |
+
assert splits[-1] == "view"
|
104 |
+
file_id = splits[-2]
|
105 |
+
return f"https://drive.google.com/uc?export=download&id={file_id}"
|
106 |
+
|
107 |
+
|
108 |
+
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
|
109 |
+
"""
|
110 |
+
Download a file from google drive
|
111 |
+
Downloading an URL from google drive requires confirmation when
|
112 |
+
the file of the size is too big (google drive notifies that
|
113 |
+
anti-viral checks cannot be performed on such files)
|
114 |
+
"""
|
115 |
+
import requests
|
116 |
+
|
117 |
+
with requests.Session() as session:
|
118 |
+
|
119 |
+
# First get the confirmation token and append it to the URL
|
120 |
+
with session.get(url, stream=True, allow_redirects=True) as response:
|
121 |
+
for k, v in response.cookies.items():
|
122 |
+
if k.startswith("download_warning"):
|
123 |
+
url = url + "&confirm=" + v
|
124 |
+
|
125 |
+
# Then download the content of the file
|
126 |
+
with session.get(url, stream=True, verify=True) as response:
|
127 |
+
makedir(output_path)
|
128 |
+
path = os.path.join(output_path, output_file_name)
|
129 |
+
total_size = int(response.headers.get("Content-length", 0))
|
130 |
+
with open(path, "wb") as file:
|
131 |
+
from tqdm import tqdm
|
132 |
+
|
133 |
+
with tqdm(total=total_size) as progress_bar:
|
134 |
+
for block in response.iter_content(
|
135 |
+
chunk_size=io.DEFAULT_BUFFER_SIZE
|
136 |
+
):
|
137 |
+
file.write(block)
|
138 |
+
progress_bar.update(len(block))
|
139 |
+
|
140 |
+
|
141 |
+
def _get_google_drive_file_id(url: str) -> Optional[str]:
|
142 |
+
parts = urlparse(url)
|
143 |
+
|
144 |
+
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
|
145 |
+
return None
|
146 |
+
|
147 |
+
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
|
148 |
+
if match is None:
|
149 |
+
return None
|
150 |
+
|
151 |
+
return match.group("id")
|
152 |
+
|
153 |
+
|
154 |
+
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
|
155 |
+
with open(filename, "wb") as fh:
|
156 |
+
with urllib.request.urlopen(
|
157 |
+
urllib.request.Request(url, headers={"User-Agent": "vissl"})
|
158 |
+
) as response:
|
159 |
+
with tqdm(total=response.length) as pbar:
|
160 |
+
for chunk in iter(lambda: response.read(chunk_size), ""):
|
161 |
+
if not chunk:
|
162 |
+
break
|
163 |
+
pbar.update(chunk_size)
|
164 |
+
fh.write(chunk)
|
165 |
+
|
166 |
+
|
167 |
+
def download_url(
|
168 |
+
url: str,
|
169 |
+
root: str,
|
170 |
+
filename: Optional[str] = None,
|
171 |
+
md5: Optional[str] = None,
|
172 |
+
) -> None:
|
173 |
+
"""Download a file from a url and place it in root.
|
174 |
+
Args:
|
175 |
+
url (str): URL to download file from
|
176 |
+
root (str): Directory to place downloaded file in
|
177 |
+
filename (str, optional): Name to save the file under.
|
178 |
+
If None, use the basename of the URL.
|
179 |
+
md5 (str, optional): MD5 checksum of the download. If None, do not check
|
180 |
+
"""
|
181 |
+
root = os.path.expanduser(root)
|
182 |
+
if not filename:
|
183 |
+
filename = os.path.basename(url)
|
184 |
+
fpath = os.path.join(root, filename)
|
185 |
+
|
186 |
+
makedir(root)
|
187 |
+
|
188 |
+
# check if file is already present locally
|
189 |
+
if check_integrity(fpath, md5):
|
190 |
+
print("Using downloaded and verified file: " + fpath)
|
191 |
+
return
|
192 |
+
|
193 |
+
# expand redirect chain if needed
|
194 |
+
url = get_redirected_url(url)
|
195 |
+
|
196 |
+
# check if file is located on Google Drive
|
197 |
+
file_id = _get_google_drive_file_id(url)
|
198 |
+
if file_id is not None:
|
199 |
+
return download_file_from_google_drive(file_id, root, filename, md5)
|
200 |
+
|
201 |
+
# download the file
|
202 |
+
try:
|
203 |
+
print("Downloading " + url + " to " + fpath)
|
204 |
+
_urlretrieve(url, fpath)
|
205 |
+
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
|
206 |
+
if url[:5] == "https":
|
207 |
+
url = url.replace("https:", "http:")
|
208 |
+
print(
|
209 |
+
"Failed download. Trying https -> http instead."
|
210 |
+
" Downloading " + url + " to " + fpath
|
211 |
+
)
|
212 |
+
_urlretrieve(url, fpath)
|
213 |
+
else:
|
214 |
+
raise e
|
215 |
+
|
216 |
+
# check integrity of downloaded file
|
217 |
+
if not check_integrity(fpath, md5):
|
218 |
+
raise RuntimeError("File not found or corrupted.")
|
219 |
+
|
220 |
+
|
221 |
+
def download_and_extract_archive(
|
222 |
+
url: str,
|
223 |
+
download_root: str,
|
224 |
+
extract_root: Optional[str] = None,
|
225 |
+
filename: Optional[str] = None,
|
226 |
+
md5: Optional[str] = None,
|
227 |
+
remove_finished: bool = False,
|
228 |
+
) -> None:
|
229 |
+
download_root = os.path.expanduser(download_root)
|
230 |
+
if extract_root is None:
|
231 |
+
extract_root = download_root
|
232 |
+
if not filename:
|
233 |
+
filename = os.path.basename(url)
|
234 |
+
|
235 |
+
download_url(url, download_root, filename, md5)
|
236 |
+
|
237 |
+
archive = os.path.join(download_root, filename)
|
238 |
+
print("Extracting {} to {}".format(archive, extract_root))
|
239 |
+
extract_archive(archive, extract_root, remove_finished)
|
240 |
+
|
241 |
+
|
242 |
+
def cache_url(url: str, cache_dir: str) -> str:
|
243 |
+
"""
|
244 |
+
This implementation downloads the remote resource and caches it locally.
|
245 |
+
The resource will only be downloaded if not previously requested.
|
246 |
+
"""
|
247 |
+
parsed_url = urlparse(url)
|
248 |
+
dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
|
249 |
+
makedir(dirname)
|
250 |
+
filename = url.split("/")[-1]
|
251 |
+
cached = os.path.join(dirname, filename)
|
252 |
+
with file_lock(cached):
|
253 |
+
if not os.path.isfile(cached):
|
254 |
+
logging.info(f"Downloading {url} to {cached} ...")
|
255 |
+
cached = download(url, dirname, filename=filename)
|
256 |
+
logging.info(f"URL {url} cached in {cached}")
|
257 |
+
return cached
|
258 |
+
|
259 |
+
|
260 |
+
# TODO (prigoyal): convert this into RAII-style API
|
261 |
+
def create_file_symlink(file1, file2):
|
262 |
+
"""
|
263 |
+
Simply create the symlinks for a given file1 to file2.
|
264 |
+
Useful during model checkpointing to symlinks to the
|
265 |
+
latest successful checkpoint.
|
266 |
+
"""
|
267 |
+
try:
|
268 |
+
if g_pathmgr.exists(file2):
|
269 |
+
g_pathmgr.rm(file2)
|
270 |
+
g_pathmgr.symlink(file1, file2)
|
271 |
+
except Exception as e:
|
272 |
+
logging.info(f"Could NOT create symlink. Error: {e}")
|
273 |
+
|
274 |
+
|
275 |
+
def save_file(data, filename, append_to_json=True, verbose=True):
|
276 |
+
"""
|
277 |
+
Common i/o utility to handle saving data to various file formats.
|
278 |
+
Supported:
|
279 |
+
.pkl, .pickle, .npy, .json
|
280 |
+
Specifically for .json, users have the option to either append (default)
|
281 |
+
or rewrite by passing in Boolean value to append_to_json.
|
282 |
+
"""
|
283 |
+
if verbose:
|
284 |
+
logging.info(f"Saving data to file: {filename}")
|
285 |
+
file_ext = os.path.splitext(filename)[1]
|
286 |
+
if file_ext in [".pkl", ".pickle"]:
|
287 |
+
with g_pathmgr.open(filename, "wb") as fopen:
|
288 |
+
pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
|
289 |
+
elif file_ext == ".npy":
|
290 |
+
with g_pathmgr.open(filename, "wb") as fopen:
|
291 |
+
np.save(fopen, data)
|
292 |
+
elif file_ext == ".json":
|
293 |
+
if append_to_json:
|
294 |
+
with g_pathmgr.open(filename, "a") as fopen:
|
295 |
+
fopen.write(json.dumps(data, sort_keys=True) + "\n")
|
296 |
+
fopen.flush()
|
297 |
+
else:
|
298 |
+
with g_pathmgr.open(filename, "w") as fopen:
|
299 |
+
fopen.write(json.dumps(data, sort_keys=True) + "\n")
|
300 |
+
fopen.flush()
|
301 |
+
elif file_ext == ".yaml":
|
302 |
+
with g_pathmgr.open(filename, "w") as fopen:
|
303 |
+
dump = yaml.dump(data)
|
304 |
+
fopen.write(dump)
|
305 |
+
fopen.flush()
|
306 |
+
else:
|
307 |
+
raise Exception(f"Saving {file_ext} is not supported yet")
|
308 |
+
|
309 |
+
if verbose:
|
310 |
+
logging.info(f"Saved data to file: {filename}")
|
311 |
+
|
312 |
+
|
313 |
+
def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
|
314 |
+
"""
|
315 |
+
Common i/o utility to handle loading data from various file formats.
|
316 |
+
Supported:
|
317 |
+
.pkl, .pickle, .npy, .json
|
318 |
+
For the npy files, we support reading the files in mmap_mode.
|
319 |
+
If the mmap_mode of reading is not successful, we load data without the
|
320 |
+
mmap_mode.
|
321 |
+
"""
|
322 |
+
if verbose:
|
323 |
+
logging.info(f"Loading data from file: {filename}")
|
324 |
+
|
325 |
+
file_ext = os.path.splitext(filename)[1]
|
326 |
+
if file_ext == ".txt":
|
327 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
328 |
+
data = fopen.readlines()
|
329 |
+
elif file_ext in [".pkl", ".pickle"]:
|
330 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
331 |
+
data = pickle.load(fopen, encoding="latin1")
|
332 |
+
elif file_ext == ".npy":
|
333 |
+
if mmap_mode:
|
334 |
+
try:
|
335 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
336 |
+
data = np.load(
|
337 |
+
fopen,
|
338 |
+
allow_pickle=allow_pickle,
|
339 |
+
encoding="latin1",
|
340 |
+
mmap_mode=mmap_mode,
|
341 |
+
)
|
342 |
+
except ValueError as e:
|
343 |
+
logging.info(
|
344 |
+
f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
|
345 |
+
)
|
346 |
+
data = np.load(
|
347 |
+
filename,
|
348 |
+
allow_pickle=allow_pickle,
|
349 |
+
encoding="latin1",
|
350 |
+
mmap_mode=mmap_mode,
|
351 |
+
)
|
352 |
+
logging.info("Successfully loaded without g_pathmgr")
|
353 |
+
except Exception:
|
354 |
+
logging.info("Could not mmap without g_pathmgr. Trying without mmap")
|
355 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
356 |
+
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
|
357 |
+
else:
|
358 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
359 |
+
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
|
360 |
+
elif file_ext == ".json":
|
361 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
362 |
+
data = json.load(fopen)
|
363 |
+
elif file_ext == ".yaml":
|
364 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
365 |
+
data = yaml.load(fopen, Loader=yaml.FullLoader)
|
366 |
+
elif file_ext == ".csv":
|
367 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
368 |
+
data = pd.read_csv(fopen)
|
369 |
+
else:
|
370 |
+
raise Exception(f"Reading from {file_ext} is not supported yet")
|
371 |
+
return data
|
372 |
+
|
373 |
+
|
374 |
+
def abspath(resource_path: str):
|
375 |
+
"""
|
376 |
+
Make a path absolute, but take into account prefixes like
|
377 |
+
"http://" or "manifold://"
|
378 |
+
"""
|
379 |
+
regex = re.compile(r"^\w+://")
|
380 |
+
if regex.match(resource_path) is None:
|
381 |
+
return os.path.abspath(resource_path)
|
382 |
+
else:
|
383 |
+
return resource_path
|
384 |
+
|
385 |
+
|
386 |
+
def makedir(dir_path):
|
387 |
+
"""
|
388 |
+
Create the directory if it does not exist.
|
389 |
+
"""
|
390 |
+
is_success = False
|
391 |
+
try:
|
392 |
+
if not g_pathmgr.exists(dir_path):
|
393 |
+
g_pathmgr.mkdirs(dir_path)
|
394 |
+
is_success = True
|
395 |
+
except BaseException:
|
396 |
+
logging.info(f"Error creating directory: {dir_path}")
|
397 |
+
return is_success
|
398 |
+
|
399 |
+
|
400 |
+
def is_url(input_url):
|
401 |
+
"""
|
402 |
+
Check if an input string is a url. look for http(s):// and ignoring the case
|
403 |
+
"""
|
404 |
+
is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
|
405 |
+
return is_url
|
406 |
+
|
407 |
+
|
408 |
+
def cleanup_dir(dir):
|
409 |
+
"""
|
410 |
+
Utility for deleting a directory. Useful for cleaning the storage space
|
411 |
+
that contains various training artifacts like checkpoints, data etc.
|
412 |
+
"""
|
413 |
+
if os.path.exists(dir):
|
414 |
+
logging.info(f"Deleting directory: {dir}")
|
415 |
+
shutil.rmtree(dir)
|
416 |
+
logging.info(f"Deleted contents of directory: {dir}")
|
417 |
+
|
418 |
+
|
419 |
+
def get_file_size(filename):
|
420 |
+
"""
|
421 |
+
Given a file, get the size of file in MB
|
422 |
+
"""
|
423 |
+
size_in_mb = os.path.getsize(filename) / float(1024**2)
|
424 |
+
return size_in_mb
|
bliva/configs/default.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
env:
|
3 |
+
# For default users
|
4 |
+
# cache_root: "cache"
|
5 |
+
# For internal use with persistent storage
|
6 |
+
cache_root: "~/.cache/bliva"
|
7 |
+
|
bliva/configs/models/bliva_flant5xxl.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
model:
|
3 |
+
arch: flant5xxl
|
4 |
+
load_finetuned: True
|
5 |
+
load_pretrained: False
|
6 |
+
|
7 |
+
pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_flanxxl_trimmed.pth"
|
8 |
+
finetuned: ''
|
9 |
+
|
10 |
+
# vit encoder
|
11 |
+
image_size: 224
|
12 |
+
drop_path_rate: 0
|
13 |
+
use_grad_checkpoint: False
|
14 |
+
vit_precision: "fp16"
|
15 |
+
freeze_vit: True
|
16 |
+
|
17 |
+
# Q-Former
|
18 |
+
num_query_token: 32
|
19 |
+
|
20 |
+
# T5
|
21 |
+
t5_model: "google/flan-t5-xxl"
|
22 |
+
|
23 |
+
# generation configs
|
24 |
+
prompt: ""
|
25 |
+
|
26 |
+
|
27 |
+
preprocess:
|
28 |
+
vis_processor:
|
29 |
+
train:
|
30 |
+
name: "blip_image_train"
|
31 |
+
image_size: 224
|
32 |
+
eval:
|
33 |
+
name: "blip_image_eval"
|
34 |
+
image_size: 224
|
35 |
+
text_processor:
|
36 |
+
train:
|
37 |
+
name: "blip_caption"
|
38 |
+
eval:
|
39 |
+
name: "blip_caption"
|
bliva/configs/models/bliva_vicuna7b.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
model:
|
3 |
+
arch: vicuna7b
|
4 |
+
load_finetuned: True
|
5 |
+
load_pretrained: False
|
6 |
+
|
7 |
+
pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth"
|
8 |
+
finetuned: 'bliva_vicuna7b.pth'
|
9 |
+
|
10 |
+
# vit encoder
|
11 |
+
image_size: 224 #336
|
12 |
+
drop_path_rate: 0
|
13 |
+
use_grad_checkpoint: False
|
14 |
+
vit_precision: "fp16"
|
15 |
+
freeze_vit: True
|
16 |
+
|
17 |
+
# Q-Former
|
18 |
+
num_query_token: 32
|
19 |
+
|
20 |
+
# path to Vicuna checkpoint
|
21 |
+
llm_model: "hf_vicuna_7b"
|
22 |
+
|
23 |
+
# generation configs
|
24 |
+
prompt: ""
|
25 |
+
|
26 |
+
|
27 |
+
preprocess:
|
28 |
+
vis_processor:
|
29 |
+
train:
|
30 |
+
name: "blip2_image_train"
|
31 |
+
image_size: 224
|
32 |
+
eval:
|
33 |
+
name: "blip_image_eval"
|
34 |
+
image_size: 224
|
35 |
+
text_processor:
|
36 |
+
train:
|
37 |
+
name: "blip_caption"
|
38 |
+
eval:
|
39 |
+
name: "blip_caption"
|
bliva/conversation/__init__.py
ADDED
File without changes
|
bliva/conversation/conversation.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import time
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
|
7 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
8 |
+
|
9 |
+
import dataclasses
|
10 |
+
from enum import auto, Enum
|
11 |
+
from typing import List, Tuple, Any
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
class SeparatorStyle(Enum):
|
16 |
+
"""Different separator style."""
|
17 |
+
SINGLE = auto()
|
18 |
+
TWO = auto()
|
19 |
+
THREE = auto()
|
20 |
+
|
21 |
+
|
22 |
+
@dataclasses.dataclass
|
23 |
+
class Conversation:
|
24 |
+
"""A class that keeps all conversation history."""
|
25 |
+
system: str
|
26 |
+
roles: List[str]
|
27 |
+
messages: List[List[str]]
|
28 |
+
offset: int
|
29 |
+
# system_img: List[Image.Image] = []
|
30 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
31 |
+
sep: str = "###"
|
32 |
+
sep2: str = None
|
33 |
+
|
34 |
+
skip_next: bool = False
|
35 |
+
conv_id: Any = None
|
36 |
+
|
37 |
+
def get_prompt(self):
|
38 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
39 |
+
ret = self.system + self.sep
|
40 |
+
for role, message in self.messages:
|
41 |
+
if message:
|
42 |
+
ret += role + ": " + message + self.sep
|
43 |
+
else:
|
44 |
+
ret += role + ":"
|
45 |
+
return ret
|
46 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
47 |
+
seps = [self.sep, self.sep2]
|
48 |
+
ret = self.system + seps[0]
|
49 |
+
for i, (role, message) in enumerate(self.messages):
|
50 |
+
if message:
|
51 |
+
ret += role + ": " + message + seps[i % 2]
|
52 |
+
else:
|
53 |
+
ret += role + ":"
|
54 |
+
return ret
|
55 |
+
elif self.sep_style == SeparatorStyle.THREE:
|
56 |
+
ret = self.system
|
57 |
+
for i, (role, message) in enumerate(self.messages):
|
58 |
+
if message:
|
59 |
+
if type(message) == list:
|
60 |
+
message = message[0]
|
61 |
+
ret += role + ": " + message
|
62 |
+
else:
|
63 |
+
ret += role + ":"
|
64 |
+
return ret
|
65 |
+
else:
|
66 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
67 |
+
|
68 |
+
def append_message(self, role, message):
|
69 |
+
self.messages.append([role, message])
|
70 |
+
|
71 |
+
def to_gradio_chatbot(self):
|
72 |
+
print('to_gradio_chatbot')
|
73 |
+
print(self.messages)
|
74 |
+
ret = []
|
75 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
76 |
+
if i % 2 == 0:
|
77 |
+
ret.append([msg, None])
|
78 |
+
else:
|
79 |
+
ret[-1][-1] = msg
|
80 |
+
return ret
|
81 |
+
|
82 |
+
def copy(self):
|
83 |
+
return Conversation(
|
84 |
+
system=self.system,
|
85 |
+
# system_img=self.system_img,
|
86 |
+
roles=self.roles,
|
87 |
+
messages=[[x, y] for x, y in self.messages],
|
88 |
+
offset=self.offset,
|
89 |
+
sep_style=self.sep_style,
|
90 |
+
sep=self.sep,
|
91 |
+
sep2=self.sep2,
|
92 |
+
conv_id=self.conv_id)
|
93 |
+
|
94 |
+
def dict(self):
|
95 |
+
return {
|
96 |
+
"system": self.system,
|
97 |
+
# "system_img": self.system_img,
|
98 |
+
"roles": self.roles,
|
99 |
+
"messages": self.messages,
|
100 |
+
"offset": self.offset,
|
101 |
+
"sep": self.sep,
|
102 |
+
"sep2": self.sep2,
|
103 |
+
"conv_id": self.conv_id,
|
104 |
+
}
|
105 |
+
|
106 |
+
|
107 |
+
class StoppingCriteriaSub(StoppingCriteria):
|
108 |
+
|
109 |
+
def __init__(self, stops=[], encounters=1):
|
110 |
+
super().__init__()
|
111 |
+
self.stops = stops
|
112 |
+
|
113 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
114 |
+
for stop in self.stops:
|
115 |
+
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
116 |
+
return True
|
117 |
+
|
118 |
+
return False
|
119 |
+
|
120 |
+
|
121 |
+
CONV_VISION = Conversation(
|
122 |
+
system="A chat between human who asks question and you give helpful, detailed, and insightful answers to his question.",
|
123 |
+
roles=(" Question", " Answer"),
|
124 |
+
messages=[],
|
125 |
+
offset=2,
|
126 |
+
sep_style=SeparatorStyle.THREE,
|
127 |
+
sep="###",
|
128 |
+
)
|
129 |
+
|
130 |
+
CONV_DIRECT= Conversation(
|
131 |
+
system="",
|
132 |
+
roles=("", ""),
|
133 |
+
messages=[],
|
134 |
+
offset=2,
|
135 |
+
sep_style=SeparatorStyle.THREE,
|
136 |
+
sep="###",
|
137 |
+
)
|
138 |
+
|
139 |
+
class Chat:
|
140 |
+
def __init__(self, model, vis_processor, device='cuda:0'):
|
141 |
+
self.device = device
|
142 |
+
self.model = model
|
143 |
+
self.vis_processor = vis_processor
|
144 |
+
|
145 |
+
def ask(self, text, conv):
|
146 |
+
#conv.messages = [] #hack not keeping history.
|
147 |
+
conv.append_message(conv.roles[0], text)
|
148 |
+
|
149 |
+
def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
|
150 |
+
repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
|
151 |
+
conv.append_message(conv.roles[1], None)
|
152 |
+
|
153 |
+
question = conv.get_prompt()
|
154 |
+
image = img_list[0] #torch.stack(img_list).to(self.device)
|
155 |
+
output_text = self.model.generate({"image": image, "prompt": question}, num_beams=num_beams, temperature=temperature)
|
156 |
+
|
157 |
+
conv.messages[-1][1] = output_text
|
158 |
+
return output_text, ''
|
159 |
+
|
160 |
+
def upload_img(self, image, conv, img_list):
|
161 |
+
if isinstance(image, str): # is a image path
|
162 |
+
raw_image = Image.open(image).convert('RGB')
|
163 |
+
image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
|
164 |
+
elif isinstance(image, Image.Image):
|
165 |
+
raw_image = image
|
166 |
+
raw_image = raw_image.convert('RGB')
|
167 |
+
image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
|
168 |
+
elif isinstance(image, torch.Tensor):
|
169 |
+
if len(image.shape) == 3:
|
170 |
+
image = image.unsqueeze(0)
|
171 |
+
image = image.to(self.device)
|
172 |
+
|
173 |
+
#image_emb, _ = self.model.encode_img(image)
|
174 |
+
img_list.append(image)
|
175 |
+
#conv.append_message(conv.roles[0], "")
|
176 |
+
msg = "Received."
|
177 |
+
# self.conv.append_message(self.conv.roles[1], msg)
|
178 |
+
return msg
|
179 |
+
|
180 |
+
|
bliva/models/Qformer.py
ADDED
@@ -0,0 +1,1216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
* Copyright (c) 2023, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
* Based on huggingface code base
|
8 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
9 |
+
"""
|
10 |
+
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
import warnings
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple, Dict, Any
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import Tensor, device, dtype, nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import CrossEntropyLoss
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from transformers.activations import ACT2FN
|
25 |
+
from transformers.file_utils import (
|
26 |
+
ModelOutput,
|
27 |
+
)
|
28 |
+
from transformers.modeling_outputs import (
|
29 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
31 |
+
CausalLMOutputWithCrossAttentions,
|
32 |
+
MaskedLMOutput,
|
33 |
+
MultipleChoiceModelOutput,
|
34 |
+
NextSentencePredictorOutput,
|
35 |
+
QuestionAnsweringModelOutput,
|
36 |
+
SequenceClassifierOutput,
|
37 |
+
TokenClassifierOutput,
|
38 |
+
)
|
39 |
+
from transformers.modeling_utils import (
|
40 |
+
PreTrainedModel,
|
41 |
+
apply_chunking_to_forward,
|
42 |
+
find_pruneable_heads_and_indices,
|
43 |
+
prune_linear_layer,
|
44 |
+
)
|
45 |
+
from transformers.utils import logging
|
46 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__)
|
49 |
+
|
50 |
+
|
51 |
+
class BertEmbeddings(nn.Module):
|
52 |
+
"""Construct the embeddings from word and position embeddings."""
|
53 |
+
|
54 |
+
def __init__(self, config):
|
55 |
+
super().__init__()
|
56 |
+
self.word_embeddings = nn.Embedding(
|
57 |
+
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
|
58 |
+
)
|
59 |
+
self.position_embeddings = nn.Embedding(
|
60 |
+
config.max_position_embeddings, config.hidden_size
|
61 |
+
)
|
62 |
+
|
63 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
64 |
+
# any TensorFlow checkpoint file
|
65 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
66 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
67 |
+
|
68 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
69 |
+
self.register_buffer(
|
70 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
|
71 |
+
)
|
72 |
+
self.position_embedding_type = getattr(
|
73 |
+
config, "position_embedding_type", "absolute"
|
74 |
+
)
|
75 |
+
|
76 |
+
self.config = config
|
77 |
+
|
78 |
+
def forward(
|
79 |
+
self,
|
80 |
+
input_ids=None,
|
81 |
+
position_ids=None,
|
82 |
+
query_embeds=None,
|
83 |
+
past_key_values_length=0,
|
84 |
+
):
|
85 |
+
if input_ids is not None:
|
86 |
+
seq_length = input_ids.size()[1]
|
87 |
+
else:
|
88 |
+
seq_length = 0
|
89 |
+
|
90 |
+
if position_ids is None:
|
91 |
+
position_ids = self.position_ids[
|
92 |
+
:, past_key_values_length : seq_length + past_key_values_length
|
93 |
+
].clone()
|
94 |
+
|
95 |
+
if input_ids is not None:
|
96 |
+
embeddings = self.word_embeddings(input_ids)
|
97 |
+
if self.position_embedding_type == "absolute":
|
98 |
+
position_embeddings = self.position_embeddings(position_ids)
|
99 |
+
embeddings = embeddings + position_embeddings
|
100 |
+
|
101 |
+
if query_embeds is not None:
|
102 |
+
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
103 |
+
else:
|
104 |
+
embeddings = query_embeds
|
105 |
+
|
106 |
+
embeddings = self.LayerNorm(embeddings)
|
107 |
+
embeddings = self.dropout(embeddings)
|
108 |
+
return embeddings
|
109 |
+
|
110 |
+
|
111 |
+
class BertSelfAttention(nn.Module):
|
112 |
+
def __init__(self, config, is_cross_attention):
|
113 |
+
super().__init__()
|
114 |
+
self.config = config
|
115 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
116 |
+
config, "embedding_size"
|
117 |
+
):
|
118 |
+
raise ValueError(
|
119 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
120 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
121 |
+
)
|
122 |
+
|
123 |
+
self.num_attention_heads = config.num_attention_heads
|
124 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
125 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
126 |
+
|
127 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
128 |
+
if is_cross_attention:
|
129 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
130 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
131 |
+
else:
|
132 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
133 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
134 |
+
|
135 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
136 |
+
self.position_embedding_type = getattr(
|
137 |
+
config, "position_embedding_type", "absolute"
|
138 |
+
)
|
139 |
+
if (
|
140 |
+
self.position_embedding_type == "relative_key"
|
141 |
+
or self.position_embedding_type == "relative_key_query"
|
142 |
+
):
|
143 |
+
self.max_position_embeddings = config.max_position_embeddings
|
144 |
+
self.distance_embedding = nn.Embedding(
|
145 |
+
2 * config.max_position_embeddings - 1, self.attention_head_size
|
146 |
+
)
|
147 |
+
self.save_attention = False
|
148 |
+
|
149 |
+
def save_attn_gradients(self, attn_gradients):
|
150 |
+
self.attn_gradients = attn_gradients
|
151 |
+
|
152 |
+
def get_attn_gradients(self):
|
153 |
+
return self.attn_gradients
|
154 |
+
|
155 |
+
def save_attention_map(self, attention_map):
|
156 |
+
self.attention_map = attention_map
|
157 |
+
|
158 |
+
def get_attention_map(self):
|
159 |
+
return self.attention_map
|
160 |
+
|
161 |
+
def transpose_for_scores(self, x):
|
162 |
+
new_x_shape = x.size()[:-1] + (
|
163 |
+
self.num_attention_heads,
|
164 |
+
self.attention_head_size,
|
165 |
+
)
|
166 |
+
x = x.view(*new_x_shape)
|
167 |
+
return x.permute(0, 2, 1, 3)
|
168 |
+
|
169 |
+
def forward(
|
170 |
+
self,
|
171 |
+
hidden_states,
|
172 |
+
attention_mask=None,
|
173 |
+
head_mask=None,
|
174 |
+
encoder_hidden_states=None,
|
175 |
+
encoder_attention_mask=None,
|
176 |
+
past_key_value=None,
|
177 |
+
output_attentions=False,
|
178 |
+
):
|
179 |
+
|
180 |
+
# If this is instantiated as a cross-attention module, the keys
|
181 |
+
# and values come from an encoder; the attention mask needs to be
|
182 |
+
# such that the encoder's padding tokens are not attended to.
|
183 |
+
is_cross_attention = encoder_hidden_states is not None
|
184 |
+
|
185 |
+
if is_cross_attention:
|
186 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
187 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
188 |
+
attention_mask = encoder_attention_mask
|
189 |
+
elif past_key_value is not None:
|
190 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
191 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
192 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
193 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
194 |
+
else:
|
195 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
196 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
197 |
+
|
198 |
+
mixed_query_layer = self.query(hidden_states)
|
199 |
+
|
200 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
201 |
+
|
202 |
+
past_key_value = (key_layer, value_layer)
|
203 |
+
|
204 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
205 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
206 |
+
|
207 |
+
if (
|
208 |
+
self.position_embedding_type == "relative_key"
|
209 |
+
or self.position_embedding_type == "relative_key_query"
|
210 |
+
):
|
211 |
+
seq_length = hidden_states.size()[1]
|
212 |
+
position_ids_l = torch.arange(
|
213 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
214 |
+
).view(-1, 1)
|
215 |
+
position_ids_r = torch.arange(
|
216 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
217 |
+
).view(1, -1)
|
218 |
+
distance = position_ids_l - position_ids_r
|
219 |
+
positional_embedding = self.distance_embedding(
|
220 |
+
distance + self.max_position_embeddings - 1
|
221 |
+
)
|
222 |
+
positional_embedding = positional_embedding.to(
|
223 |
+
dtype=query_layer.dtype
|
224 |
+
) # fp16 compatibility
|
225 |
+
|
226 |
+
if self.position_embedding_type == "relative_key":
|
227 |
+
relative_position_scores = torch.einsum(
|
228 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
229 |
+
)
|
230 |
+
attention_scores = attention_scores + relative_position_scores
|
231 |
+
elif self.position_embedding_type == "relative_key_query":
|
232 |
+
relative_position_scores_query = torch.einsum(
|
233 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
234 |
+
)
|
235 |
+
relative_position_scores_key = torch.einsum(
|
236 |
+
"bhrd,lrd->bhlr", key_layer, positional_embedding
|
237 |
+
)
|
238 |
+
attention_scores = (
|
239 |
+
attention_scores
|
240 |
+
+ relative_position_scores_query
|
241 |
+
+ relative_position_scores_key
|
242 |
+
)
|
243 |
+
|
244 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
245 |
+
if attention_mask is not None:
|
246 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
247 |
+
attention_scores = attention_scores + attention_mask
|
248 |
+
|
249 |
+
# Normalize the attention scores to probabilities.
|
250 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
251 |
+
|
252 |
+
if is_cross_attention and self.save_attention:
|
253 |
+
self.save_attention_map(attention_probs)
|
254 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
255 |
+
|
256 |
+
# This is actually dropping out entire tokens to attend to, which might
|
257 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
258 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
259 |
+
|
260 |
+
# Mask heads if we want to
|
261 |
+
if head_mask is not None:
|
262 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
263 |
+
|
264 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
265 |
+
|
266 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
267 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
268 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
269 |
+
|
270 |
+
outputs = (
|
271 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
272 |
+
)
|
273 |
+
|
274 |
+
outputs = outputs + (past_key_value,)
|
275 |
+
return outputs
|
276 |
+
|
277 |
+
|
278 |
+
class BertSelfOutput(nn.Module):
|
279 |
+
def __init__(self, config):
|
280 |
+
super().__init__()
|
281 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
282 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
283 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
284 |
+
|
285 |
+
def forward(self, hidden_states, input_tensor):
|
286 |
+
hidden_states = self.dense(hidden_states)
|
287 |
+
hidden_states = self.dropout(hidden_states)
|
288 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
289 |
+
return hidden_states
|
290 |
+
|
291 |
+
|
292 |
+
class BertAttention(nn.Module):
|
293 |
+
def __init__(self, config, is_cross_attention=False):
|
294 |
+
super().__init__()
|
295 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
296 |
+
self.output = BertSelfOutput(config)
|
297 |
+
self.pruned_heads = set()
|
298 |
+
|
299 |
+
def prune_heads(self, heads):
|
300 |
+
if len(heads) == 0:
|
301 |
+
return
|
302 |
+
heads, index = find_pruneable_heads_and_indices(
|
303 |
+
heads,
|
304 |
+
self.self.num_attention_heads,
|
305 |
+
self.self.attention_head_size,
|
306 |
+
self.pruned_heads,
|
307 |
+
)
|
308 |
+
|
309 |
+
# Prune linear layers
|
310 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
311 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
312 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
313 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
314 |
+
|
315 |
+
# Update hyper params and store pruned heads
|
316 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
317 |
+
self.self.all_head_size = (
|
318 |
+
self.self.attention_head_size * self.self.num_attention_heads
|
319 |
+
)
|
320 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
321 |
+
|
322 |
+
def forward(
|
323 |
+
self,
|
324 |
+
hidden_states,
|
325 |
+
attention_mask=None,
|
326 |
+
head_mask=None,
|
327 |
+
encoder_hidden_states=None,
|
328 |
+
encoder_attention_mask=None,
|
329 |
+
past_key_value=None,
|
330 |
+
output_attentions=False,
|
331 |
+
):
|
332 |
+
self_outputs = self.self(
|
333 |
+
hidden_states,
|
334 |
+
attention_mask,
|
335 |
+
head_mask,
|
336 |
+
encoder_hidden_states,
|
337 |
+
encoder_attention_mask,
|
338 |
+
past_key_value,
|
339 |
+
output_attentions,
|
340 |
+
)
|
341 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
342 |
+
|
343 |
+
outputs = (attention_output,) + self_outputs[
|
344 |
+
1:
|
345 |
+
] # add attentions if we output them
|
346 |
+
return outputs
|
347 |
+
|
348 |
+
|
349 |
+
class BertIntermediate(nn.Module):
|
350 |
+
def __init__(self, config):
|
351 |
+
super().__init__()
|
352 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
353 |
+
if isinstance(config.hidden_act, str):
|
354 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
355 |
+
else:
|
356 |
+
self.intermediate_act_fn = config.hidden_act
|
357 |
+
|
358 |
+
def forward(self, hidden_states):
|
359 |
+
hidden_states = self.dense(hidden_states)
|
360 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
361 |
+
return hidden_states
|
362 |
+
|
363 |
+
|
364 |
+
class BertOutput(nn.Module):
|
365 |
+
def __init__(self, config):
|
366 |
+
super().__init__()
|
367 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
368 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
369 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
370 |
+
|
371 |
+
def forward(self, hidden_states, input_tensor):
|
372 |
+
hidden_states = self.dense(hidden_states)
|
373 |
+
hidden_states = self.dropout(hidden_states)
|
374 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
375 |
+
return hidden_states
|
376 |
+
|
377 |
+
|
378 |
+
class BertLayer(nn.Module):
|
379 |
+
def __init__(self, config, layer_num):
|
380 |
+
super().__init__()
|
381 |
+
self.config = config
|
382 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
383 |
+
self.seq_len_dim = 1
|
384 |
+
self.attention = BertAttention(config)
|
385 |
+
self.layer_num = layer_num
|
386 |
+
if (
|
387 |
+
self.config.add_cross_attention
|
388 |
+
and layer_num % self.config.cross_attention_freq == 0
|
389 |
+
):
|
390 |
+
self.crossattention = BertAttention(
|
391 |
+
config, is_cross_attention=self.config.add_cross_attention
|
392 |
+
)
|
393 |
+
self.has_cross_attention = True
|
394 |
+
else:
|
395 |
+
self.has_cross_attention = False
|
396 |
+
self.intermediate = BertIntermediate(config)
|
397 |
+
self.output = BertOutput(config)
|
398 |
+
|
399 |
+
self.intermediate_query = BertIntermediate(config)
|
400 |
+
self.output_query = BertOutput(config)
|
401 |
+
|
402 |
+
def forward(
|
403 |
+
self,
|
404 |
+
hidden_states,
|
405 |
+
attention_mask=None,
|
406 |
+
head_mask=None,
|
407 |
+
encoder_hidden_states=None,
|
408 |
+
encoder_attention_mask=None,
|
409 |
+
past_key_value=None,
|
410 |
+
output_attentions=False,
|
411 |
+
query_length=0,
|
412 |
+
):
|
413 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
414 |
+
self_attn_past_key_value = (
|
415 |
+
past_key_value[:2] if past_key_value is not None else None
|
416 |
+
)
|
417 |
+
self_attention_outputs = self.attention(
|
418 |
+
hidden_states,
|
419 |
+
attention_mask,
|
420 |
+
head_mask,
|
421 |
+
output_attentions=output_attentions,
|
422 |
+
past_key_value=self_attn_past_key_value,
|
423 |
+
)
|
424 |
+
attention_output = self_attention_outputs[0]
|
425 |
+
outputs = self_attention_outputs[1:-1]
|
426 |
+
|
427 |
+
present_key_value = self_attention_outputs[-1]
|
428 |
+
|
429 |
+
if query_length > 0:
|
430 |
+
query_attention_output = attention_output[:, :query_length, :]
|
431 |
+
|
432 |
+
if self.has_cross_attention:
|
433 |
+
assert (
|
434 |
+
encoder_hidden_states is not None
|
435 |
+
), "encoder_hidden_states must be given for cross-attention layers"
|
436 |
+
cross_attention_outputs = self.crossattention(
|
437 |
+
query_attention_output,
|
438 |
+
attention_mask,
|
439 |
+
head_mask,
|
440 |
+
encoder_hidden_states,
|
441 |
+
encoder_attention_mask,
|
442 |
+
output_attentions=output_attentions,
|
443 |
+
)
|
444 |
+
query_attention_output = cross_attention_outputs[0]
|
445 |
+
outputs = (
|
446 |
+
outputs + cross_attention_outputs[1:-1]
|
447 |
+
) # add cross attentions if we output attention weights
|
448 |
+
|
449 |
+
layer_output = apply_chunking_to_forward(
|
450 |
+
self.feed_forward_chunk_query,
|
451 |
+
self.chunk_size_feed_forward,
|
452 |
+
self.seq_len_dim,
|
453 |
+
query_attention_output,
|
454 |
+
)
|
455 |
+
if attention_output.shape[1] > query_length:
|
456 |
+
layer_output_text = apply_chunking_to_forward(
|
457 |
+
self.feed_forward_chunk,
|
458 |
+
self.chunk_size_feed_forward,
|
459 |
+
self.seq_len_dim,
|
460 |
+
attention_output[:, query_length:, :],
|
461 |
+
)
|
462 |
+
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
463 |
+
else:
|
464 |
+
layer_output = apply_chunking_to_forward(
|
465 |
+
self.feed_forward_chunk,
|
466 |
+
self.chunk_size_feed_forward,
|
467 |
+
self.seq_len_dim,
|
468 |
+
attention_output,
|
469 |
+
)
|
470 |
+
outputs = (layer_output,) + outputs
|
471 |
+
|
472 |
+
outputs = outputs + (present_key_value,)
|
473 |
+
|
474 |
+
return outputs
|
475 |
+
|
476 |
+
def feed_forward_chunk(self, attention_output):
|
477 |
+
intermediate_output = self.intermediate(attention_output)
|
478 |
+
layer_output = self.output(intermediate_output, attention_output)
|
479 |
+
return layer_output
|
480 |
+
|
481 |
+
def feed_forward_chunk_query(self, attention_output):
|
482 |
+
intermediate_output = self.intermediate_query(attention_output)
|
483 |
+
layer_output = self.output_query(intermediate_output, attention_output)
|
484 |
+
return layer_output
|
485 |
+
|
486 |
+
|
487 |
+
class BertEncoder(nn.Module):
|
488 |
+
def __init__(self, config):
|
489 |
+
super().__init__()
|
490 |
+
self.config = config
|
491 |
+
self.layer = nn.ModuleList(
|
492 |
+
[BertLayer(config, i) for i in range(config.num_hidden_layers)]
|
493 |
+
)
|
494 |
+
|
495 |
+
def forward(
|
496 |
+
self,
|
497 |
+
hidden_states,
|
498 |
+
attention_mask=None,
|
499 |
+
head_mask=None,
|
500 |
+
encoder_hidden_states=None,
|
501 |
+
encoder_attention_mask=None,
|
502 |
+
past_key_values=None,
|
503 |
+
use_cache=None,
|
504 |
+
output_attentions=False,
|
505 |
+
output_hidden_states=False,
|
506 |
+
return_dict=True,
|
507 |
+
query_length=0,
|
508 |
+
):
|
509 |
+
all_hidden_states = () if output_hidden_states else None
|
510 |
+
all_self_attentions = () if output_attentions else None
|
511 |
+
all_cross_attentions = (
|
512 |
+
() if output_attentions and self.config.add_cross_attention else None
|
513 |
+
)
|
514 |
+
|
515 |
+
next_decoder_cache = () if use_cache else None
|
516 |
+
|
517 |
+
for i in range(self.config.num_hidden_layers):
|
518 |
+
layer_module = self.layer[i]
|
519 |
+
if output_hidden_states:
|
520 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
521 |
+
|
522 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
523 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
524 |
+
|
525 |
+
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
526 |
+
|
527 |
+
if use_cache:
|
528 |
+
logger.warn(
|
529 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
530 |
+
)
|
531 |
+
use_cache = False
|
532 |
+
|
533 |
+
def create_custom_forward(module):
|
534 |
+
def custom_forward(*inputs):
|
535 |
+
return module(
|
536 |
+
*inputs, past_key_value, output_attentions, query_length
|
537 |
+
)
|
538 |
+
|
539 |
+
return custom_forward
|
540 |
+
|
541 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
542 |
+
create_custom_forward(layer_module),
|
543 |
+
hidden_states,
|
544 |
+
attention_mask,
|
545 |
+
layer_head_mask,
|
546 |
+
encoder_hidden_states,
|
547 |
+
encoder_attention_mask,
|
548 |
+
)
|
549 |
+
else:
|
550 |
+
layer_outputs = layer_module(
|
551 |
+
hidden_states,
|
552 |
+
attention_mask,
|
553 |
+
layer_head_mask,
|
554 |
+
encoder_hidden_states,
|
555 |
+
encoder_attention_mask,
|
556 |
+
past_key_value,
|
557 |
+
output_attentions,
|
558 |
+
query_length,
|
559 |
+
)
|
560 |
+
|
561 |
+
hidden_states = layer_outputs[0]
|
562 |
+
if use_cache:
|
563 |
+
next_decoder_cache += (layer_outputs[-1],)
|
564 |
+
if output_attentions:
|
565 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
566 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
567 |
+
|
568 |
+
if output_hidden_states:
|
569 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
570 |
+
|
571 |
+
if not return_dict:
|
572 |
+
return tuple(
|
573 |
+
v
|
574 |
+
for v in [
|
575 |
+
hidden_states,
|
576 |
+
next_decoder_cache,
|
577 |
+
all_hidden_states,
|
578 |
+
all_self_attentions,
|
579 |
+
all_cross_attentions,
|
580 |
+
]
|
581 |
+
if v is not None
|
582 |
+
)
|
583 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
584 |
+
last_hidden_state=hidden_states,
|
585 |
+
past_key_values=next_decoder_cache,
|
586 |
+
hidden_states=all_hidden_states,
|
587 |
+
attentions=all_self_attentions,
|
588 |
+
cross_attentions=all_cross_attentions,
|
589 |
+
)
|
590 |
+
|
591 |
+
|
592 |
+
class BertPooler(nn.Module):
|
593 |
+
def __init__(self, config):
|
594 |
+
super().__init__()
|
595 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
596 |
+
self.activation = nn.Tanh()
|
597 |
+
|
598 |
+
def forward(self, hidden_states):
|
599 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
600 |
+
# to the first token.
|
601 |
+
first_token_tensor = hidden_states[:, 0]
|
602 |
+
pooled_output = self.dense(first_token_tensor)
|
603 |
+
pooled_output = self.activation(pooled_output)
|
604 |
+
return pooled_output
|
605 |
+
|
606 |
+
|
607 |
+
class BertPredictionHeadTransform(nn.Module):
|
608 |
+
def __init__(self, config):
|
609 |
+
super().__init__()
|
610 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
611 |
+
if isinstance(config.hidden_act, str):
|
612 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
613 |
+
else:
|
614 |
+
self.transform_act_fn = config.hidden_act
|
615 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
616 |
+
|
617 |
+
def forward(self, hidden_states):
|
618 |
+
hidden_states = self.dense(hidden_states)
|
619 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
620 |
+
hidden_states = self.LayerNorm(hidden_states)
|
621 |
+
return hidden_states
|
622 |
+
|
623 |
+
|
624 |
+
class BertLMPredictionHead(nn.Module):
|
625 |
+
def __init__(self, config):
|
626 |
+
super().__init__()
|
627 |
+
self.transform = BertPredictionHeadTransform(config)
|
628 |
+
|
629 |
+
# The output weights are the same as the input embeddings, but there is
|
630 |
+
# an output-only bias for each token.
|
631 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
632 |
+
|
633 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
634 |
+
|
635 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
636 |
+
self.decoder.bias = self.bias
|
637 |
+
|
638 |
+
def forward(self, hidden_states):
|
639 |
+
hidden_states = self.transform(hidden_states)
|
640 |
+
hidden_states = self.decoder(hidden_states)
|
641 |
+
return hidden_states
|
642 |
+
|
643 |
+
|
644 |
+
class BertOnlyMLMHead(nn.Module):
|
645 |
+
def __init__(self, config):
|
646 |
+
super().__init__()
|
647 |
+
self.predictions = BertLMPredictionHead(config)
|
648 |
+
|
649 |
+
def forward(self, sequence_output):
|
650 |
+
prediction_scores = self.predictions(sequence_output)
|
651 |
+
return prediction_scores
|
652 |
+
|
653 |
+
|
654 |
+
class BertPreTrainedModel(PreTrainedModel):
|
655 |
+
"""
|
656 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
657 |
+
models.
|
658 |
+
"""
|
659 |
+
|
660 |
+
config_class = BertConfig
|
661 |
+
base_model_prefix = "bert"
|
662 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
663 |
+
|
664 |
+
def _init_weights(self, module):
|
665 |
+
"""Initialize the weights"""
|
666 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
667 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
668 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
669 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
670 |
+
elif isinstance(module, nn.LayerNorm):
|
671 |
+
module.bias.data.zero_()
|
672 |
+
module.weight.data.fill_(1.0)
|
673 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
674 |
+
module.bias.data.zero_()
|
675 |
+
|
676 |
+
|
677 |
+
class BertModel(BertPreTrainedModel):
|
678 |
+
"""
|
679 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
680 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
681 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
682 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
683 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
684 |
+
input to the forward pass.
|
685 |
+
"""
|
686 |
+
|
687 |
+
def __init__(self, config, add_pooling_layer=False):
|
688 |
+
super().__init__(config)
|
689 |
+
self.config = config
|
690 |
+
|
691 |
+
self.embeddings = BertEmbeddings(config)
|
692 |
+
|
693 |
+
self.encoder = BertEncoder(config)
|
694 |
+
|
695 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
696 |
+
|
697 |
+
self.init_weights()
|
698 |
+
|
699 |
+
def get_input_embeddings(self):
|
700 |
+
return self.embeddings.word_embeddings
|
701 |
+
|
702 |
+
def set_input_embeddings(self, value):
|
703 |
+
self.embeddings.word_embeddings = value
|
704 |
+
|
705 |
+
def _prune_heads(self, heads_to_prune):
|
706 |
+
"""
|
707 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
708 |
+
class PreTrainedModel
|
709 |
+
"""
|
710 |
+
for layer, heads in heads_to_prune.items():
|
711 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
712 |
+
|
713 |
+
def get_extended_attention_mask(
|
714 |
+
self,
|
715 |
+
attention_mask: Tensor,
|
716 |
+
input_shape: Tuple[int],
|
717 |
+
device: device,
|
718 |
+
is_decoder: bool,
|
719 |
+
has_query: bool = False,
|
720 |
+
) -> Tensor:
|
721 |
+
"""
|
722 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
723 |
+
|
724 |
+
Arguments:
|
725 |
+
attention_mask (:obj:`torch.Tensor`):
|
726 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
727 |
+
input_shape (:obj:`Tuple[int]`):
|
728 |
+
The shape of the input to the model.
|
729 |
+
device: (:obj:`torch.device`):
|
730 |
+
The device of the input to the model.
|
731 |
+
|
732 |
+
Returns:
|
733 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
734 |
+
"""
|
735 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
736 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
737 |
+
if attention_mask.dim() == 3:
|
738 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
739 |
+
elif attention_mask.dim() == 2:
|
740 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
741 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
742 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
743 |
+
if is_decoder:
|
744 |
+
batch_size, seq_length = input_shape
|
745 |
+
|
746 |
+
seq_ids = torch.arange(seq_length, device=device)
|
747 |
+
causal_mask = (
|
748 |
+
seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
|
749 |
+
<= seq_ids[None, :, None]
|
750 |
+
)
|
751 |
+
|
752 |
+
# add a prefix ones mask to the causal mask
|
753 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
754 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
755 |
+
|
756 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
757 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
758 |
+
if has_query: # UniLM style attention mask
|
759 |
+
causal_mask = torch.cat(
|
760 |
+
[
|
761 |
+
torch.zeros(
|
762 |
+
(batch_size, prefix_seq_len, seq_length),
|
763 |
+
device=device,
|
764 |
+
dtype=causal_mask.dtype,
|
765 |
+
),
|
766 |
+
causal_mask,
|
767 |
+
],
|
768 |
+
axis=1,
|
769 |
+
)
|
770 |
+
causal_mask = torch.cat(
|
771 |
+
[
|
772 |
+
torch.ones(
|
773 |
+
(batch_size, causal_mask.shape[1], prefix_seq_len),
|
774 |
+
device=device,
|
775 |
+
dtype=causal_mask.dtype,
|
776 |
+
),
|
777 |
+
causal_mask,
|
778 |
+
],
|
779 |
+
axis=-1,
|
780 |
+
)
|
781 |
+
extended_attention_mask = (
|
782 |
+
causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
783 |
+
)
|
784 |
+
else:
|
785 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
786 |
+
else:
|
787 |
+
raise ValueError(
|
788 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
789 |
+
input_shape, attention_mask.shape
|
790 |
+
)
|
791 |
+
)
|
792 |
+
|
793 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
794 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
795 |
+
# positions we want to attend and -10000.0 for masked positions.
|
796 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
797 |
+
# effectively the same as removing these entirely.
|
798 |
+
extended_attention_mask = extended_attention_mask.to(
|
799 |
+
dtype=self.dtype
|
800 |
+
) # fp16 compatibility
|
801 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
802 |
+
return extended_attention_mask
|
803 |
+
|
804 |
+
def forward(
|
805 |
+
self,
|
806 |
+
input_ids=None,
|
807 |
+
attention_mask=None,
|
808 |
+
position_ids=None,
|
809 |
+
head_mask=None,
|
810 |
+
query_embeds=None,
|
811 |
+
encoder_hidden_states=None,
|
812 |
+
encoder_attention_mask=None,
|
813 |
+
past_key_values=None,
|
814 |
+
use_cache=None,
|
815 |
+
output_attentions=None,
|
816 |
+
output_hidden_states=None,
|
817 |
+
return_dict=None,
|
818 |
+
is_decoder=False,
|
819 |
+
):
|
820 |
+
r"""
|
821 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
822 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
823 |
+
the model is configured as a decoder.
|
824 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
825 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
826 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
827 |
+
- 1 for tokens that are **not masked**,
|
828 |
+
- 0 for tokens that are **masked**.
|
829 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
830 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
831 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
832 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
833 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
834 |
+
use_cache (:obj:`bool`, `optional`):
|
835 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
836 |
+
decoding (see :obj:`past_key_values`).
|
837 |
+
"""
|
838 |
+
output_attentions = (
|
839 |
+
output_attentions
|
840 |
+
if output_attentions is not None
|
841 |
+
else self.config.output_attentions
|
842 |
+
)
|
843 |
+
output_hidden_states = (
|
844 |
+
output_hidden_states
|
845 |
+
if output_hidden_states is not None
|
846 |
+
else self.config.output_hidden_states
|
847 |
+
)
|
848 |
+
return_dict = (
|
849 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
850 |
+
)
|
851 |
+
|
852 |
+
# use_cache = use_cache if use_cache is not None else self.config.use_cache
|
853 |
+
|
854 |
+
if input_ids is None:
|
855 |
+
assert (
|
856 |
+
query_embeds is not None
|
857 |
+
), "You have to specify query_embeds when input_ids is None"
|
858 |
+
|
859 |
+
# past_key_values_length
|
860 |
+
past_key_values_length = (
|
861 |
+
past_key_values[0][0].shape[2] - self.config.query_length
|
862 |
+
if past_key_values is not None
|
863 |
+
else 0
|
864 |
+
)
|
865 |
+
|
866 |
+
query_length = query_embeds.shape[1] if query_embeds is not None else 0
|
867 |
+
|
868 |
+
embedding_output = self.embeddings(
|
869 |
+
input_ids=input_ids,
|
870 |
+
position_ids=position_ids,
|
871 |
+
query_embeds=query_embeds,
|
872 |
+
past_key_values_length=past_key_values_length,
|
873 |
+
)
|
874 |
+
|
875 |
+
input_shape = embedding_output.size()[:-1]
|
876 |
+
batch_size, seq_length = input_shape
|
877 |
+
device = embedding_output.device
|
878 |
+
|
879 |
+
if attention_mask is None:
|
880 |
+
attention_mask = torch.ones(
|
881 |
+
((batch_size, seq_length + past_key_values_length)), device=device
|
882 |
+
)
|
883 |
+
|
884 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
885 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
886 |
+
if is_decoder:
|
887 |
+
extended_attention_mask = self.get_extended_attention_mask(
|
888 |
+
attention_mask,
|
889 |
+
input_ids.shape,
|
890 |
+
device,
|
891 |
+
is_decoder,
|
892 |
+
has_query=(query_embeds is not None),
|
893 |
+
)
|
894 |
+
else:
|
895 |
+
extended_attention_mask = self.get_extended_attention_mask(
|
896 |
+
attention_mask, input_shape, device, is_decoder
|
897 |
+
)
|
898 |
+
|
899 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
900 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
901 |
+
if encoder_hidden_states is not None:
|
902 |
+
if type(encoder_hidden_states) == list:
|
903 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
|
904 |
+
0
|
905 |
+
].size()
|
906 |
+
else:
|
907 |
+
(
|
908 |
+
encoder_batch_size,
|
909 |
+
encoder_sequence_length,
|
910 |
+
_,
|
911 |
+
) = encoder_hidden_states.size()
|
912 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
913 |
+
|
914 |
+
if type(encoder_attention_mask) == list:
|
915 |
+
encoder_extended_attention_mask = [
|
916 |
+
self.invert_attention_mask(mask) for mask in encoder_attention_mask
|
917 |
+
]
|
918 |
+
elif encoder_attention_mask is None:
|
919 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
920 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
921 |
+
encoder_attention_mask
|
922 |
+
)
|
923 |
+
else:
|
924 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
925 |
+
encoder_attention_mask
|
926 |
+
)
|
927 |
+
else:
|
928 |
+
encoder_extended_attention_mask = None
|
929 |
+
|
930 |
+
# Prepare head mask if needed
|
931 |
+
# 1.0 in head_mask indicate we keep the head
|
932 |
+
# attention_probs has shape bsz x n_heads x N x N
|
933 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
934 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
935 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
936 |
+
|
937 |
+
encoder_outputs = self.encoder(
|
938 |
+
embedding_output,
|
939 |
+
attention_mask=extended_attention_mask,
|
940 |
+
head_mask=head_mask,
|
941 |
+
encoder_hidden_states=encoder_hidden_states,
|
942 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
943 |
+
past_key_values=past_key_values,
|
944 |
+
use_cache=use_cache,
|
945 |
+
output_attentions=output_attentions,
|
946 |
+
output_hidden_states=output_hidden_states,
|
947 |
+
return_dict=return_dict,
|
948 |
+
query_length=query_length,
|
949 |
+
)
|
950 |
+
sequence_output = encoder_outputs[0]
|
951 |
+
pooled_output = (
|
952 |
+
self.pooler(sequence_output) if self.pooler is not None else None
|
953 |
+
)
|
954 |
+
|
955 |
+
if not return_dict:
|
956 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
957 |
+
|
958 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
959 |
+
last_hidden_state=sequence_output,
|
960 |
+
pooler_output=pooled_output,
|
961 |
+
past_key_values=encoder_outputs.past_key_values,
|
962 |
+
hidden_states=encoder_outputs.hidden_states,
|
963 |
+
attentions=encoder_outputs.attentions,
|
964 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
965 |
+
)
|
966 |
+
|
967 |
+
|
968 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
969 |
+
|
970 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
971 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
972 |
+
|
973 |
+
def __init__(self, config):
|
974 |
+
super().__init__(config)
|
975 |
+
|
976 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
977 |
+
self.cls = BertOnlyMLMHead(config)
|
978 |
+
|
979 |
+
self.init_weights()
|
980 |
+
|
981 |
+
def get_output_embeddings(self):
|
982 |
+
return self.cls.predictions.decoder
|
983 |
+
|
984 |
+
def set_output_embeddings(self, new_embeddings):
|
985 |
+
self.cls.predictions.decoder = new_embeddings
|
986 |
+
|
987 |
+
def forward(
|
988 |
+
self,
|
989 |
+
input_ids=None,
|
990 |
+
attention_mask=None,
|
991 |
+
position_ids=None,
|
992 |
+
head_mask=None,
|
993 |
+
query_embeds=None,
|
994 |
+
encoder_hidden_states=None,
|
995 |
+
encoder_attention_mask=None,
|
996 |
+
labels=None,
|
997 |
+
past_key_values=None,
|
998 |
+
use_cache=True,
|
999 |
+
output_attentions=None,
|
1000 |
+
output_hidden_states=None,
|
1001 |
+
return_dict=None,
|
1002 |
+
return_logits=False,
|
1003 |
+
is_decoder=True,
|
1004 |
+
reduction="mean",
|
1005 |
+
):
|
1006 |
+
r"""
|
1007 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
1008 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
1009 |
+
the model is configured as a decoder.
|
1010 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1011 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
1012 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
1013 |
+
- 1 for tokens that are **not masked**,
|
1014 |
+
- 0 for tokens that are **masked**.
|
1015 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1016 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
1017 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
1018 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
1019 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1020 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1021 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
1022 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
1023 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
1024 |
+
use_cache (:obj:`bool`, `optional`):
|
1025 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
1026 |
+
decoding (see :obj:`past_key_values`).
|
1027 |
+
Returns:
|
1028 |
+
Example::
|
1029 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
1030 |
+
>>> import torch
|
1031 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
1032 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
1033 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
1034 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
1035 |
+
>>> outputs = model(**inputs)
|
1036 |
+
>>> prediction_logits = outputs.logits
|
1037 |
+
"""
|
1038 |
+
return_dict = (
|
1039 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1040 |
+
)
|
1041 |
+
if labels is not None:
|
1042 |
+
use_cache = False
|
1043 |
+
if past_key_values is not None:
|
1044 |
+
query_embeds = None
|
1045 |
+
|
1046 |
+
outputs = self.bert(
|
1047 |
+
input_ids,
|
1048 |
+
attention_mask=attention_mask,
|
1049 |
+
position_ids=position_ids,
|
1050 |
+
head_mask=head_mask,
|
1051 |
+
query_embeds=query_embeds,
|
1052 |
+
encoder_hidden_states=encoder_hidden_states,
|
1053 |
+
encoder_attention_mask=encoder_attention_mask,
|
1054 |
+
past_key_values=past_key_values,
|
1055 |
+
use_cache=use_cache,
|
1056 |
+
output_attentions=output_attentions,
|
1057 |
+
output_hidden_states=output_hidden_states,
|
1058 |
+
return_dict=return_dict,
|
1059 |
+
is_decoder=is_decoder,
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
sequence_output = outputs[0]
|
1063 |
+
if query_embeds is not None:
|
1064 |
+
sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
|
1065 |
+
|
1066 |
+
prediction_scores = self.cls(sequence_output)
|
1067 |
+
|
1068 |
+
if return_logits:
|
1069 |
+
return prediction_scores[:, :-1, :].contiguous()
|
1070 |
+
|
1071 |
+
lm_loss = None
|
1072 |
+
if labels is not None:
|
1073 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
1074 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
1075 |
+
labels = labels[:, 1:].contiguous()
|
1076 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
1077 |
+
lm_loss = loss_fct(
|
1078 |
+
shifted_prediction_scores.view(-1, self.config.vocab_size),
|
1079 |
+
labels.view(-1),
|
1080 |
+
)
|
1081 |
+
if reduction == "none":
|
1082 |
+
lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
|
1083 |
+
|
1084 |
+
if not return_dict:
|
1085 |
+
output = (prediction_scores,) + outputs[2:]
|
1086 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
1087 |
+
|
1088 |
+
return CausalLMOutputWithCrossAttentions(
|
1089 |
+
loss=lm_loss,
|
1090 |
+
logits=prediction_scores,
|
1091 |
+
past_key_values=outputs.past_key_values,
|
1092 |
+
hidden_states=outputs.hidden_states,
|
1093 |
+
attentions=outputs.attentions,
|
1094 |
+
cross_attentions=outputs.cross_attentions,
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
def prepare_inputs_for_generation(
|
1098 |
+
self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
|
1099 |
+
):
|
1100 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
1101 |
+
if attention_mask is None:
|
1102 |
+
attention_mask = input_ids.new_ones(input_ids.shape)
|
1103 |
+
query_mask = input_ids.new_ones(query_embeds.shape[:-1])
|
1104 |
+
attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
|
1105 |
+
|
1106 |
+
# cut decoder_input_ids if past is used
|
1107 |
+
if past is not None:
|
1108 |
+
input_ids = input_ids[:, -1:]
|
1109 |
+
|
1110 |
+
return {
|
1111 |
+
"input_ids": input_ids,
|
1112 |
+
"query_embeds": query_embeds,
|
1113 |
+
"attention_mask": attention_mask,
|
1114 |
+
"past_key_values": past,
|
1115 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
1116 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
1117 |
+
"is_decoder": True,
|
1118 |
+
}
|
1119 |
+
|
1120 |
+
def _reorder_cache(self, past, beam_idx):
|
1121 |
+
reordered_past = ()
|
1122 |
+
for layer_past in past:
|
1123 |
+
reordered_past += (
|
1124 |
+
tuple(
|
1125 |
+
past_state.index_select(0, beam_idx) for past_state in layer_past
|
1126 |
+
),
|
1127 |
+
)
|
1128 |
+
return reordered_past
|
1129 |
+
|
1130 |
+
|
1131 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
1132 |
+
|
1133 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1134 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
1135 |
+
|
1136 |
+
def __init__(self, config):
|
1137 |
+
super().__init__(config)
|
1138 |
+
|
1139 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1140 |
+
self.cls = BertOnlyMLMHead(config)
|
1141 |
+
|
1142 |
+
self.init_weights()
|
1143 |
+
|
1144 |
+
def get_output_embeddings(self):
|
1145 |
+
return self.cls.predictions.decoder
|
1146 |
+
|
1147 |
+
def set_output_embeddings(self, new_embeddings):
|
1148 |
+
self.cls.predictions.decoder = new_embeddings
|
1149 |
+
|
1150 |
+
def forward(
|
1151 |
+
self,
|
1152 |
+
input_ids=None,
|
1153 |
+
attention_mask=None,
|
1154 |
+
position_ids=None,
|
1155 |
+
head_mask=None,
|
1156 |
+
query_embeds=None,
|
1157 |
+
encoder_hidden_states=None,
|
1158 |
+
encoder_attention_mask=None,
|
1159 |
+
labels=None,
|
1160 |
+
output_attentions=None,
|
1161 |
+
output_hidden_states=None,
|
1162 |
+
return_dict=None,
|
1163 |
+
return_logits=False,
|
1164 |
+
is_decoder=False,
|
1165 |
+
):
|
1166 |
+
r"""
|
1167 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1168 |
+
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
1169 |
+
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
1170 |
+
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
1171 |
+
"""
|
1172 |
+
|
1173 |
+
return_dict = (
|
1174 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1175 |
+
)
|
1176 |
+
|
1177 |
+
outputs = self.bert(
|
1178 |
+
input_ids,
|
1179 |
+
attention_mask=attention_mask,
|
1180 |
+
position_ids=position_ids,
|
1181 |
+
head_mask=head_mask,
|
1182 |
+
query_embeds=query_embeds,
|
1183 |
+
encoder_hidden_states=encoder_hidden_states,
|
1184 |
+
encoder_attention_mask=encoder_attention_mask,
|
1185 |
+
output_attentions=output_attentions,
|
1186 |
+
output_hidden_states=output_hidden_states,
|
1187 |
+
return_dict=return_dict,
|
1188 |
+
is_decoder=is_decoder,
|
1189 |
+
)
|
1190 |
+
|
1191 |
+
if query_embeds is not None:
|
1192 |
+
sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
|
1193 |
+
prediction_scores = self.cls(sequence_output)
|
1194 |
+
|
1195 |
+
if return_logits:
|
1196 |
+
return prediction_scores
|
1197 |
+
|
1198 |
+
masked_lm_loss = None
|
1199 |
+
if labels is not None:
|
1200 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
1201 |
+
masked_lm_loss = loss_fct(
|
1202 |
+
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
|
1203 |
+
)
|
1204 |
+
|
1205 |
+
if not return_dict:
|
1206 |
+
output = (prediction_scores,) + outputs[2:]
|
1207 |
+
return (
|
1208 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1209 |
+
)
|
1210 |
+
|
1211 |
+
return MaskedLMOutput(
|
1212 |
+
loss=masked_lm_loss,
|
1213 |
+
logits=prediction_scores,
|
1214 |
+
hidden_states=outputs.hidden_states,
|
1215 |
+
attentions=outputs.attentions,
|
1216 |
+
)
|
bliva/models/__init__.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import torch
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
from bliva.common.registry import registry
|
12 |
+
|
13 |
+
from bliva.models.base_model import BaseModel
|
14 |
+
|
15 |
+
from bliva.models.blip2 import Blip2Base
|
16 |
+
|
17 |
+
from bliva.models.bliva_flant5xxl import BLIVAFlanT5
|
18 |
+
from bliva.models.bliva_vicuna7b import BLIVAVicuna
|
19 |
+
|
20 |
+
from bliva.models.vit import VisionTransformerEncoder
|
21 |
+
|
22 |
+
|
23 |
+
from bliva.processors.base_processor import BaseProcessor
|
24 |
+
|
25 |
+
|
26 |
+
__all__ = [
|
27 |
+
"load_model",
|
28 |
+
"BaseModel",
|
29 |
+
"Blip2Base",
|
30 |
+
"BLIVAFlanT5",
|
31 |
+
"BLIVAVicuna",
|
32 |
+
]
|
33 |
+
|
34 |
+
|
35 |
+
def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
|
36 |
+
"""
|
37 |
+
Load supported models.
|
38 |
+
|
39 |
+
To list all available models and types in registry:
|
40 |
+
>>> from bliva.models import model_zoo
|
41 |
+
>>> print(model_zoo)
|
42 |
+
|
43 |
+
Args:
|
44 |
+
name (str): name of the model.
|
45 |
+
model_type (str): type of the model.
|
46 |
+
is_eval (bool): whether the model is in eval mode. Default: False.
|
47 |
+
device (str): device to use. Default: "cpu".
|
48 |
+
checkpoint (str): path or to checkpoint. Default: None.
|
49 |
+
Note that expecting the checkpoint to have the same keys in state_dict as the model.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
model (torch.nn.Module): model.
|
53 |
+
"""
|
54 |
+
|
55 |
+
model = registry.get_model_class(name).from_pretrained(model_type=model_type)
|
56 |
+
|
57 |
+
if checkpoint is not None:
|
58 |
+
model.load_checkpoint(checkpoint)
|
59 |
+
|
60 |
+
if is_eval:
|
61 |
+
model.eval()
|
62 |
+
|
63 |
+
if device == "cpu":
|
64 |
+
model = model.float()
|
65 |
+
|
66 |
+
return model.to(device)
|
67 |
+
|
68 |
+
|
69 |
+
def load_preprocess(config):
|
70 |
+
"""
|
71 |
+
Load preprocessor configs and construct preprocessors.
|
72 |
+
|
73 |
+
If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
config (dict): preprocessor configs.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
vis_processors (dict): preprocessors for visual inputs.
|
80 |
+
txt_processors (dict): preprocessors for text inputs.
|
81 |
+
|
82 |
+
Key is "train" or "eval" for processors used in training and evaluation respectively.
|
83 |
+
"""
|
84 |
+
|
85 |
+
def _build_proc_from_cfg(cfg):
|
86 |
+
return (
|
87 |
+
registry.get_processor_class(cfg.name).from_config(cfg)
|
88 |
+
if cfg is not None
|
89 |
+
else BaseProcessor()
|
90 |
+
)
|
91 |
+
|
92 |
+
vis_processors = dict()
|
93 |
+
txt_processors = dict()
|
94 |
+
|
95 |
+
vis_proc_cfg = config.get("vis_processor")
|
96 |
+
txt_proc_cfg = config.get("text_processor")
|
97 |
+
|
98 |
+
if vis_proc_cfg is not None:
|
99 |
+
vis_train_cfg = vis_proc_cfg.get("train")
|
100 |
+
vis_eval_cfg = vis_proc_cfg.get("eval")
|
101 |
+
else:
|
102 |
+
vis_train_cfg = None
|
103 |
+
vis_eval_cfg = None
|
104 |
+
|
105 |
+
vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
|
106 |
+
vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
|
107 |
+
|
108 |
+
if txt_proc_cfg is not None:
|
109 |
+
txt_train_cfg = txt_proc_cfg.get("train")
|
110 |
+
txt_eval_cfg = txt_proc_cfg.get("eval")
|
111 |
+
else:
|
112 |
+
txt_train_cfg = None
|
113 |
+
txt_eval_cfg = None
|
114 |
+
|
115 |
+
txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
|
116 |
+
txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
|
117 |
+
|
118 |
+
return vis_processors, txt_processors
|
119 |
+
|
120 |
+
|
121 |
+
def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
|
122 |
+
"""
|
123 |
+
Load model and its related preprocessors.
|
124 |
+
|
125 |
+
List all available models and types in registry:
|
126 |
+
>>> from bliva.models import model_zoo
|
127 |
+
>>> print(model_zoo)
|
128 |
+
|
129 |
+
Args:
|
130 |
+
name (str): name of the model.
|
131 |
+
model_type (str): type of the model.
|
132 |
+
is_eval (bool): whether the model is in eval mode. Default: False.
|
133 |
+
device (str): device to use. Default: "cpu".
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
model (torch.nn.Module): model.
|
137 |
+
vis_processors (dict): preprocessors for visual inputs.
|
138 |
+
txt_processors (dict): preprocessors for text inputs.
|
139 |
+
"""
|
140 |
+
model_cls = registry.get_model_class(name)
|
141 |
+
|
142 |
+
# load model
|
143 |
+
model = model_cls.from_pretrained(model_type=model_type)
|
144 |
+
|
145 |
+
if is_eval:
|
146 |
+
model.eval()
|
147 |
+
|
148 |
+
# load preprocess
|
149 |
+
cfg = OmegaConf.load(model_cls.default_config_path(model_type))
|
150 |
+
if cfg is not None:
|
151 |
+
preprocess_cfg = cfg.preprocess
|
152 |
+
|
153 |
+
vis_processors, txt_processors = load_preprocess(preprocess_cfg)
|
154 |
+
else:
|
155 |
+
vis_processors, txt_processors = None, None
|
156 |
+
logging.info(
|
157 |
+
f"""No default preprocess for model {name} ({model_type}).
|
158 |
+
This can happen if the model is not finetuned on downstream datasets,
|
159 |
+
or it is not intended for direct use without finetuning.
|
160 |
+
"""
|
161 |
+
)
|
162 |
+
|
163 |
+
if device == "cpu" or device == torch.device("cpu"):
|
164 |
+
model = model.float()
|
165 |
+
|
166 |
+
return model.to(device), vis_processors, txt_processors
|
167 |
+
|
168 |
+
|
169 |
+
class ModelZoo:
|
170 |
+
"""
|
171 |
+
A utility class to create string representation of available model architectures and types.
|
172 |
+
|
173 |
+
>>> from bliva.models import model_zoo
|
174 |
+
>>> # list all available models
|
175 |
+
>>> print(model_zoo)
|
176 |
+
>>> # show total number of models
|
177 |
+
>>> print(len(model_zoo))
|
178 |
+
"""
|
179 |
+
|
180 |
+
def __init__(self) -> None:
|
181 |
+
self.model_zoo = {
|
182 |
+
k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
|
183 |
+
for k, v in registry.mapping["model_name_mapping"].items()
|
184 |
+
}
|
185 |
+
|
186 |
+
def __str__(self) -> str:
|
187 |
+
return (
|
188 |
+
"=" * 50
|
189 |
+
+ "\n"
|
190 |
+
+ f"{'Architectures':<30} {'Types'}\n"
|
191 |
+
+ "=" * 50
|
192 |
+
+ "\n"
|
193 |
+
+ "\n".join(
|
194 |
+
[
|
195 |
+
f"{name:<30} {', '.join(types)}"
|
196 |
+
for name, types in self.model_zoo.items()
|
197 |
+
]
|
198 |
+
)
|
199 |
+
)
|
200 |
+
|
201 |
+
def __iter__(self):
|
202 |
+
return iter(self.model_zoo.items())
|
203 |
+
|
204 |
+
def __len__(self):
|
205 |
+
return sum([len(v) for v in self.model_zoo.values()])
|
206 |
+
|
207 |
+
|
208 |
+
model_zoo = ModelZoo()
|
bliva/models/base_model.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from bliva.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
|
15 |
+
from bliva.common.utils import get_abs_path, is_url
|
16 |
+
from omegaconf import OmegaConf
|
17 |
+
|
18 |
+
|
19 |
+
class BaseModel(nn.Module):
|
20 |
+
"""Base class for models."""
|
21 |
+
|
22 |
+
def __init__(self):
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
@property
|
26 |
+
def device(self):
|
27 |
+
return list(self.parameters())[0].device
|
28 |
+
|
29 |
+
def load_checkpoint(self, url_or_filename):
|
30 |
+
"""
|
31 |
+
Load from a finetuned checkpoint.
|
32 |
+
|
33 |
+
This should expect no mismatch in the model keys and the checkpoint keys.
|
34 |
+
"""
|
35 |
+
|
36 |
+
if is_url(url_or_filename):
|
37 |
+
cached_file = download_cached_file(
|
38 |
+
url_or_filename, check_hash=False, progress=True
|
39 |
+
)
|
40 |
+
checkpoint = torch.load(cached_file, map_location="cuda") #hack cpu
|
41 |
+
elif os.path.isfile(url_or_filename):
|
42 |
+
checkpoint = torch.load(url_or_filename, map_location="cuda") #hack cpu
|
43 |
+
else:
|
44 |
+
raise RuntimeError("checkpoint url or path is invalid")
|
45 |
+
|
46 |
+
if "model" in checkpoint.keys():
|
47 |
+
state_dict = checkpoint["model"]
|
48 |
+
else:
|
49 |
+
state_dict = checkpoint
|
50 |
+
|
51 |
+
msg = self.load_state_dict(state_dict, strict=False)
|
52 |
+
|
53 |
+
logging.info("Missing keys {}".format(msg.missing_keys))
|
54 |
+
logging.info("load checkpoint from %s" % url_or_filename)
|
55 |
+
|
56 |
+
return msg
|
57 |
+
|
58 |
+
@classmethod
|
59 |
+
def from_pretrained(cls, model_type):
|
60 |
+
"""
|
61 |
+
Build a pretrained model from default configuration file, specified by model_type.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
- model_type (str): model type, specifying architecture and checkpoints.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
- model (nn.Module): pretrained or finetuned model, depending on the configuration.
|
68 |
+
"""
|
69 |
+
model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
|
70 |
+
model = cls.from_config(model_cfg)
|
71 |
+
|
72 |
+
return model
|
73 |
+
|
74 |
+
@classmethod
|
75 |
+
def default_config_path(cls, model_type):
|
76 |
+
assert (
|
77 |
+
model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
|
78 |
+
), "Unknown model type {}".format(model_type)
|
79 |
+
return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
|
80 |
+
|
81 |
+
def load_checkpoint_from_config(self, cfg, **kwargs):
|
82 |
+
"""
|
83 |
+
Load checkpoint as specified in the config file.
|
84 |
+
|
85 |
+
If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
|
86 |
+
When loading the pretrained model, each task-specific architecture may define their
|
87 |
+
own load_from_pretrained() method.
|
88 |
+
"""
|
89 |
+
load_finetuned = cfg.get("load_finetuned", True)
|
90 |
+
#load_finetuned = False
|
91 |
+
if load_finetuned:
|
92 |
+
finetune_path = cfg.get("finetuned", None)
|
93 |
+
assert (
|
94 |
+
finetune_path is not None
|
95 |
+
), "Found load_finetuned is True, but finetune_path is None."
|
96 |
+
self.load_checkpoint(url_or_filename=finetune_path)
|
97 |
+
else:
|
98 |
+
load_pretrained = cfg.get("load_pretrained", True)
|
99 |
+
if load_pretrained:
|
100 |
+
# load pre-trained weights
|
101 |
+
pretrain_path = cfg.get("pretrained", None)
|
102 |
+
assert "Found load_finetuned is False, but pretrain_path is None."
|
103 |
+
self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
|
104 |
+
|
105 |
+
|
106 |
+
def before_evaluation(self, **kwargs):
|
107 |
+
pass
|
108 |
+
|
109 |
+
def show_n_params(self, return_str=True):
|
110 |
+
tot = 0
|
111 |
+
for p in self.parameters():
|
112 |
+
w = 1
|
113 |
+
for x in p.shape:
|
114 |
+
w *= x
|
115 |
+
tot += w
|
116 |
+
if return_str:
|
117 |
+
if tot >= 1e6:
|
118 |
+
return "{:.1f}M".format(tot / 1e6)
|
119 |
+
else:
|
120 |
+
return "{:.1f}K".format(tot / 1e3)
|
121 |
+
else:
|
122 |
+
return tot
|
123 |
+
|
124 |
+
|
125 |
+
class BaseEncoder(nn.Module):
|
126 |
+
"""
|
127 |
+
Base class for primitive encoders, such as ViT, TimeSformer, etc.
|
128 |
+
"""
|
129 |
+
|
130 |
+
def __init__(self):
|
131 |
+
super().__init__()
|
132 |
+
|
133 |
+
def forward_features(self, samples, **kwargs):
|
134 |
+
raise NotImplementedError
|
135 |
+
|
136 |
+
@property
|
137 |
+
def device(self):
|
138 |
+
return list(self.parameters())[0].device
|
139 |
+
|
140 |
+
|
141 |
+
class SharedQueueMixin:
|
142 |
+
@torch.no_grad()
|
143 |
+
def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
|
144 |
+
# gather keys before updating queue
|
145 |
+
image_feats = concat_all_gather(image_feat)
|
146 |
+
text_feats = concat_all_gather(text_feat)
|
147 |
+
|
148 |
+
batch_size = image_feats.shape[0]
|
149 |
+
|
150 |
+
ptr = int(self.queue_ptr)
|
151 |
+
assert self.queue_size % batch_size == 0 # for simplicity
|
152 |
+
|
153 |
+
# replace the keys at ptr (dequeue and enqueue)
|
154 |
+
self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
|
155 |
+
self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
|
156 |
+
|
157 |
+
if idxs is not None:
|
158 |
+
idxs = concat_all_gather(idxs)
|
159 |
+
self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
|
160 |
+
|
161 |
+
ptr = (ptr + batch_size) % self.queue_size # move pointer
|
162 |
+
self.queue_ptr[0] = ptr
|
163 |
+
|
164 |
+
|
165 |
+
class MomentumDistilationMixin:
|
166 |
+
@torch.no_grad()
|
167 |
+
def copy_params(self):
|
168 |
+
for model_pair in self.model_pairs:
|
169 |
+
for param, param_m in zip(
|
170 |
+
model_pair[0].parameters(), model_pair[1].parameters()
|
171 |
+
):
|
172 |
+
param_m.data.copy_(param.data) # initialize
|
173 |
+
param_m.requires_grad = False # not update by gradient
|
174 |
+
|
175 |
+
@torch.no_grad()
|
176 |
+
def _momentum_update(self):
|
177 |
+
for model_pair in self.model_pairs:
|
178 |
+
for param, param_m in zip(
|
179 |
+
model_pair[0].parameters(), model_pair[1].parameters()
|
180 |
+
):
|
181 |
+
param_m.data = param_m.data * self.momentum + param.data * (
|
182 |
+
1.0 - self.momentum
|
183 |
+
)
|
184 |
+
|
185 |
+
|
186 |
+
class GatherLayer(torch.autograd.Function):
|
187 |
+
"""
|
188 |
+
Gather tensors from all workers with support for backward propagation:
|
189 |
+
This implementation does not cut the gradients as torch.distributed.all_gather does.
|
190 |
+
"""
|
191 |
+
|
192 |
+
@staticmethod
|
193 |
+
def forward(ctx, x):
|
194 |
+
output = [
|
195 |
+
torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
|
196 |
+
]
|
197 |
+
torch.distributed.all_gather(output, x)
|
198 |
+
return tuple(output)
|
199 |
+
|
200 |
+
@staticmethod
|
201 |
+
def backward(ctx, *grads):
|
202 |
+
all_gradients = torch.stack(grads)
|
203 |
+
torch.distributed.all_reduce(all_gradients)
|
204 |
+
return all_gradients[torch.distributed.get_rank()]
|
205 |
+
|
206 |
+
|
207 |
+
def all_gather_with_grad(tensors):
|
208 |
+
"""
|
209 |
+
Performs all_gather operation on the provided tensors.
|
210 |
+
Graph remains connected for backward grad computation.
|
211 |
+
"""
|
212 |
+
# Queue the gathered tensors
|
213 |
+
world_size = torch.distributed.get_world_size()
|
214 |
+
# There is no need for reduction in the single-proc case
|
215 |
+
if world_size == 1:
|
216 |
+
return tensors
|
217 |
+
|
218 |
+
# tensor_all = GatherLayer.apply(tensors)
|
219 |
+
tensor_all = GatherLayer.apply(tensors)
|
220 |
+
|
221 |
+
return torch.cat(tensor_all, dim=0)
|
222 |
+
|
223 |
+
|
224 |
+
@torch.no_grad()
|
225 |
+
def concat_all_gather(tensor):
|
226 |
+
"""
|
227 |
+
Performs all_gather operation on the provided tensors.
|
228 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
229 |
+
"""
|
230 |
+
# if use distributed training
|
231 |
+
if not is_dist_avail_and_initialized():
|
232 |
+
return tensor
|
233 |
+
|
234 |
+
tensors_gather = [
|
235 |
+
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
|
236 |
+
]
|
237 |
+
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
238 |
+
|
239 |
+
output = torch.cat(tensors_gather, dim=0)
|
240 |
+
return output
|
241 |
+
|
242 |
+
|
243 |
+
def tile(x, dim, n_tile):
|
244 |
+
init_dim = x.size(dim)
|
245 |
+
repeat_idx = [1] * x.dim()
|
246 |
+
repeat_idx[dim] = n_tile
|
247 |
+
x = x.repeat(*(repeat_idx))
|
248 |
+
order_index = torch.LongTensor(
|
249 |
+
np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
|
250 |
+
)
|
251 |
+
return torch.index_select(x, dim, order_index.to(x.device))
|
bliva/models/blip2.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2023, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
import contextlib
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
import time
|
11 |
+
import datetime
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.distributed as dist
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
import bliva.common.dist_utils as dist_utils
|
19 |
+
from bliva.common.dist_utils import download_cached_file
|
20 |
+
from bliva.common.utils import is_url
|
21 |
+
from bliva.common.logger import MetricLogger
|
22 |
+
from bliva.models.base_model import BaseModel
|
23 |
+
from bliva.models.Qformer import BertConfig, BertLMHeadModel
|
24 |
+
from bliva.models.eva_vit import create_eva_vit_g
|
25 |
+
from bliva.models.clip_vit import create_clip_vit_L
|
26 |
+
from transformers import BertTokenizer
|
27 |
+
|
28 |
+
|
29 |
+
class Blip2Base(BaseModel):
|
30 |
+
@classmethod
|
31 |
+
def init_tokenizer(cls, truncation_side="right"):
|
32 |
+
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side)
|
33 |
+
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
34 |
+
return tokenizer
|
35 |
+
|
36 |
+
def maybe_autocast(self, dtype=torch.float16):
|
37 |
+
# if on cpu, don't use autocast
|
38 |
+
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16
|
39 |
+
enable_autocast = self.device != torch.device("cpu")
|
40 |
+
|
41 |
+
if enable_autocast:
|
42 |
+
return torch.cuda.amp.autocast(dtype=dtype)
|
43 |
+
else:
|
44 |
+
return contextlib.nullcontext()
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
|
48 |
+
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
|
49 |
+
encoder_config.encoder_width = vision_width
|
50 |
+
# insert cross-attention layer every other block
|
51 |
+
encoder_config.add_cross_attention = True
|
52 |
+
encoder_config.cross_attention_freq = cross_attention_freq
|
53 |
+
encoder_config.query_length = num_query_token
|
54 |
+
Qformer = BertLMHeadModel.from_pretrained(
|
55 |
+
"bert-base-uncased", config=encoder_config
|
56 |
+
)
|
57 |
+
query_tokens = nn.Parameter(
|
58 |
+
torch.zeros(1, num_query_token, encoder_config.hidden_size)
|
59 |
+
)
|
60 |
+
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) #0.02
|
61 |
+
return Qformer, query_tokens
|
62 |
+
|
63 |
+
def init_vision_encoder(
|
64 |
+
self, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
|
65 |
+
):
|
66 |
+
assert model_name in [
|
67 |
+
"eva_clip_g",
|
68 |
+
"eva2_clip_L",
|
69 |
+
"clip_L",
|
70 |
+
'cpe_eva_clip_g'
|
71 |
+
], "vit model must be eva_clip_g, eva2_clip_L or clip_L or cpe_eva_clip_g"
|
72 |
+
if model_name == "eva_clip_g":
|
73 |
+
visual_encoder = create_eva_vit_g(
|
74 |
+
img_size, drop_path_rate, use_grad_checkpoint, precision
|
75 |
+
)
|
76 |
+
# elif model_name == "eva2_clip_L":
|
77 |
+
# visual_encoder = create_eva2_vit_L(
|
78 |
+
# img_size, drop_path_rate, use_grad_checkpoint, precision
|
79 |
+
# )
|
80 |
+
elif model_name == "clip_L":
|
81 |
+
visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
|
82 |
+
|
83 |
+
ln_vision = LayerNorm(visual_encoder.num_features)
|
84 |
+
self.vit_name = model_name
|
85 |
+
return visual_encoder, ln_vision
|
86 |
+
|
87 |
+
def load_from_pretrained(self, url_or_filename):
|
88 |
+
if is_url(url_or_filename):
|
89 |
+
cached_file = download_cached_file(
|
90 |
+
url_or_filename, check_hash=False, progress=True
|
91 |
+
)
|
92 |
+
checkpoint = torch.load(cached_file, map_location="cpu")
|
93 |
+
elif os.path.isfile(url_or_filename):
|
94 |
+
checkpoint = torch.load(url_or_filename, map_location="cpu")
|
95 |
+
else:
|
96 |
+
raise RuntimeError("checkpoint url or path is invalid")
|
97 |
+
|
98 |
+
state_dict = checkpoint["model"]
|
99 |
+
|
100 |
+
msg = self.load_state_dict(state_dict, strict=False)
|
101 |
+
|
102 |
+
# logging.info("Missing keys {}".format(msg.missing_keys))
|
103 |
+
logging.info("load checkpoint from %s" % url_or_filename)
|
104 |
+
|
105 |
+
return msg
|
106 |
+
|
107 |
+
def get_optimizer_params(self, weight_decay, lr_scale=1):
|
108 |
+
if self.vit_name == "eva_clip_g":
|
109 |
+
vit_num_layers = self.visual_encoder.get_num_layer()
|
110 |
+
lr_scales = list(lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2))
|
111 |
+
|
112 |
+
parameter_group_names = {}
|
113 |
+
parameter_group_vars = {}
|
114 |
+
|
115 |
+
for name, param in self.named_parameters():
|
116 |
+
if not param.requires_grad:
|
117 |
+
continue # frozen weights
|
118 |
+
if len(param.shape) == 1 or name.endswith(".bias"):
|
119 |
+
group_name = "no_decay"
|
120 |
+
this_weight_decay = 0.
|
121 |
+
else:
|
122 |
+
group_name = "decay"
|
123 |
+
this_weight_decay = weight_decay
|
124 |
+
if 'visual_encoder' in name:
|
125 |
+
layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.',''))
|
126 |
+
group_name = "vit_layer_%d_%s" % (layer_id, group_name)
|
127 |
+
else:
|
128 |
+
layer_id = None
|
129 |
+
|
130 |
+
if group_name not in parameter_group_names:
|
131 |
+
if layer_id is not None:
|
132 |
+
scale = lr_scales[layer_id]
|
133 |
+
else:
|
134 |
+
scale = 1
|
135 |
+
parameter_group_names[group_name] = {
|
136 |
+
"weight_decay": this_weight_decay,
|
137 |
+
"params": [],
|
138 |
+
"lr_scale": scale
|
139 |
+
}
|
140 |
+
parameter_group_vars[group_name] = {
|
141 |
+
"weight_decay": this_weight_decay,
|
142 |
+
"params": [],
|
143 |
+
"lr_scale": scale
|
144 |
+
}
|
145 |
+
parameter_group_vars[group_name]["params"].append(param)
|
146 |
+
parameter_group_names[group_name]["params"].append(name)
|
147 |
+
# import json
|
148 |
+
# print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
|
149 |
+
optim_params = list(parameter_group_vars.values())
|
150 |
+
return optim_params
|
151 |
+
else:
|
152 |
+
return super().get_optimizer_params(weight_decay,lr_scale)
|
153 |
+
|
154 |
+
def _lemmatize(self, answers):
|
155 |
+
def apply(answer):
|
156 |
+
doc = self.lemmatizer(answer)
|
157 |
+
|
158 |
+
words = []
|
159 |
+
for token in doc:
|
160 |
+
if token.pos_ in ["NOUN", "VERB"]:
|
161 |
+
words.append(token.lemma_)
|
162 |
+
else:
|
163 |
+
words.append(token.text)
|
164 |
+
answer = " ".join(words)
|
165 |
+
|
166 |
+
return answer
|
167 |
+
|
168 |
+
return [apply(answer) for answer in answers]
|
169 |
+
|
170 |
+
@property
|
171 |
+
def lemmatizer(self):
|
172 |
+
if self._lemmatizer is None:
|
173 |
+
try:
|
174 |
+
import spacy
|
175 |
+
|
176 |
+
self._lemmatizer = spacy.load("en_core_web_sm")
|
177 |
+
except ImportError:
|
178 |
+
logging.error(
|
179 |
+
"""
|
180 |
+
Please install spacy and en_core_web_sm model to apply lemmatization.
|
181 |
+
python -m spacy download en_core_web_sm
|
182 |
+
OR
|
183 |
+
import spacy.cli
|
184 |
+
spacy.cli.download("en_core_web_sm")
|
185 |
+
"""
|
186 |
+
)
|
187 |
+
exit(1)
|
188 |
+
|
189 |
+
return self._lemmatizer
|
190 |
+
|
191 |
+
def disabled_train(self, mode=True):
|
192 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
193 |
+
does not change anymore."""
|
194 |
+
return self
|
195 |
+
|
196 |
+
|
197 |
+
class LayerNorm(nn.LayerNorm):
|
198 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
199 |
+
|
200 |
+
def forward(self, x: torch.Tensor):
|
201 |
+
orig_type = x.dtype
|
202 |
+
ret = super().forward(x.type(torch.float32))
|
203 |
+
return ret.type(orig_type)
|
204 |
+
|
205 |
+
|
206 |
+
def compute_sim_matrix(model, data_loader, **kwargs):
|
207 |
+
k_test = kwargs.pop("k_test")
|
208 |
+
|
209 |
+
metric_logger = MetricLogger(delimiter=" ")
|
210 |
+
header = "Evaluation:"
|
211 |
+
|
212 |
+
logging.info("Computing features for evaluation...")
|
213 |
+
start_time = time.time()
|
214 |
+
|
215 |
+
texts = data_loader.dataset.text
|
216 |
+
num_text = len(texts)
|
217 |
+
text_bs = 256
|
218 |
+
text_ids = []
|
219 |
+
text_embeds = []
|
220 |
+
text_atts = []
|
221 |
+
for i in range(0, num_text, text_bs):
|
222 |
+
text = texts[i : min(num_text, i + text_bs)]
|
223 |
+
text_input = model.tokenizer(
|
224 |
+
text,
|
225 |
+
padding="max_length",
|
226 |
+
truncation=True,
|
227 |
+
max_length=35,
|
228 |
+
return_tensors="pt",
|
229 |
+
).to(model.device)
|
230 |
+
text_feat = model.forward_text(text_input)
|
231 |
+
text_embed = F.normalize(model.text_proj(text_feat))
|
232 |
+
text_embeds.append(text_embed)
|
233 |
+
text_ids.append(text_input.input_ids)
|
234 |
+
text_atts.append(text_input.attention_mask)
|
235 |
+
|
236 |
+
text_embeds = torch.cat(text_embeds, dim=0)
|
237 |
+
text_ids = torch.cat(text_ids, dim=0)
|
238 |
+
text_atts = torch.cat(text_atts, dim=0)
|
239 |
+
|
240 |
+
vit_feats = []
|
241 |
+
image_embeds = []
|
242 |
+
for samples in data_loader:
|
243 |
+
image = samples["image"]
|
244 |
+
|
245 |
+
image = image.to(model.device)
|
246 |
+
image_feat, vit_feat = model.forward_image(image)
|
247 |
+
image_embed = model.vision_proj(image_feat)
|
248 |
+
image_embed = F.normalize(image_embed, dim=-1)
|
249 |
+
|
250 |
+
vit_feats.append(vit_feat.cpu())
|
251 |
+
image_embeds.append(image_embed)
|
252 |
+
|
253 |
+
vit_feats = torch.cat(vit_feats, dim=0)
|
254 |
+
image_embeds = torch.cat(image_embeds, dim=0)
|
255 |
+
|
256 |
+
sims_matrix = []
|
257 |
+
for image_embed in image_embeds:
|
258 |
+
sim_q2t = image_embed @ text_embeds.t()
|
259 |
+
sim_i2t, _ = sim_q2t.max(0)
|
260 |
+
sims_matrix.append(sim_i2t)
|
261 |
+
sims_matrix = torch.stack(sims_matrix, dim=0)
|
262 |
+
|
263 |
+
score_matrix_i2t = torch.full(
|
264 |
+
(len(data_loader.dataset.image), len(texts)), -100.0
|
265 |
+
).to(model.device)
|
266 |
+
|
267 |
+
num_tasks = dist_utils.get_world_size()
|
268 |
+
rank = dist_utils.get_rank()
|
269 |
+
step = sims_matrix.size(0) // num_tasks + 1
|
270 |
+
start = rank * step
|
271 |
+
end = min(sims_matrix.size(0), start + step)
|
272 |
+
|
273 |
+
for i, sims in enumerate(
|
274 |
+
metric_logger.log_every(sims_matrix[start:end], 50, header)
|
275 |
+
):
|
276 |
+
topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
|
277 |
+
image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
|
278 |
+
score = model.compute_itm(
|
279 |
+
image_inputs=image_inputs,
|
280 |
+
text_ids=text_ids[topk_idx],
|
281 |
+
text_atts=text_atts[topk_idx],
|
282 |
+
).float()
|
283 |
+
score_matrix_i2t[start + i, topk_idx] = score + topk_sim
|
284 |
+
|
285 |
+
sims_matrix = sims_matrix.t()
|
286 |
+
score_matrix_t2i = torch.full(
|
287 |
+
(len(texts), len(data_loader.dataset.image)), -100.0
|
288 |
+
).to(model.device)
|
289 |
+
|
290 |
+
step = sims_matrix.size(0) // num_tasks + 1
|
291 |
+
start = rank * step
|
292 |
+
end = min(sims_matrix.size(0), start + step)
|
293 |
+
|
294 |
+
for i, sims in enumerate(
|
295 |
+
metric_logger.log_every(sims_matrix[start:end], 50, header)
|
296 |
+
):
|
297 |
+
topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
|
298 |
+
image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
|
299 |
+
score = model.compute_itm(
|
300 |
+
image_inputs=image_inputs,
|
301 |
+
text_ids=text_ids[start + i].repeat(k_test, 1),
|
302 |
+
text_atts=text_atts[start + i].repeat(k_test, 1),
|
303 |
+
).float()
|
304 |
+
score_matrix_t2i[start + i, topk_idx] = score + topk_sim
|
305 |
+
|
306 |
+
if dist_utils.is_dist_avail_and_initialized():
|
307 |
+
dist.barrier()
|
308 |
+
torch.distributed.all_reduce(
|
309 |
+
score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
|
310 |
+
)
|
311 |
+
torch.distributed.all_reduce(
|
312 |
+
score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
|
313 |
+
)
|
314 |
+
|
315 |
+
total_time = time.time() - start_time
|
316 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
317 |
+
logging.info("Evaluation time {}".format(total_time_str))
|
318 |
+
|
319 |
+
return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
|
bliva/models/bliva_flant5xxl.py
ADDED
@@ -0,0 +1,803 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import string
|
3 |
+
import random
|
4 |
+
import copy
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.cuda.amp import autocast as autocast
|
9 |
+
from transformers import T5TokenizerFast
|
10 |
+
|
11 |
+
from bliva.common.registry import registry
|
12 |
+
from bliva.models.blip2 import Blip2Base, disabled_train
|
13 |
+
from bliva.models.modeling_t5 import T5Config, T5ForConditionalGeneration
|
14 |
+
from transformers.modeling_outputs import BaseModelOutput
|
15 |
+
|
16 |
+
|
17 |
+
@registry.register_model("bliva_flant5")
|
18 |
+
class BLIVAFlanT5(Blip2Base):
|
19 |
+
|
20 |
+
PRETRAINED_MODEL_CONFIG_DICT = {
|
21 |
+
"flant5xxl": "configs/models/bliva_flant5xxl.yaml",
|
22 |
+
}
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
vit_model="eva_clip_g",
|
27 |
+
img_size=224,
|
28 |
+
drop_path_rate=0,
|
29 |
+
use_grad_checkpoint=False,
|
30 |
+
vit_precision="fp16",
|
31 |
+
freeze_vit=True,
|
32 |
+
num_query_token=32,
|
33 |
+
t5_model="google/flan-t5-xl",
|
34 |
+
prompt="",
|
35 |
+
max_txt_len=128,
|
36 |
+
max_output_txt_len=256,
|
37 |
+
apply_lemmatizer=False,
|
38 |
+
num_few_shot_examples=0,
|
39 |
+
few_shot_prob=0,
|
40 |
+
qformer_text_input=True,
|
41 |
+
):
|
42 |
+
"""
|
43 |
+
apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas.
|
44 |
+
"""
|
45 |
+
super().__init__()
|
46 |
+
|
47 |
+
self.tokenizer = self.init_tokenizer(truncation_side="left")
|
48 |
+
|
49 |
+
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
50 |
+
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
51 |
+
)
|
52 |
+
if freeze_vit:
|
53 |
+
for name, param in self.visual_encoder.named_parameters():
|
54 |
+
param.requires_grad = False
|
55 |
+
self.visual_encoder = self.visual_encoder.eval()
|
56 |
+
self.visual_encoder.train = disabled_train
|
57 |
+
logging.info("freeze vision encoder")
|
58 |
+
|
59 |
+
self.Qformer, self.query_tokens = self.init_Qformer(
|
60 |
+
num_query_token, self.visual_encoder.num_features
|
61 |
+
)
|
62 |
+
|
63 |
+
if not qformer_text_input:
|
64 |
+
self.Qformer.bert.embeddings.word_embeddings = None
|
65 |
+
self.Qformer.bert.embeddings.position_embeddings = None
|
66 |
+
for layer in self.Qformer.bert.encoder.layer:
|
67 |
+
layer.output = None
|
68 |
+
layer.intermediate = None
|
69 |
+
else:
|
70 |
+
self.Qformer.resize_token_embeddings(len(self.tokenizer))
|
71 |
+
self.Qformer.cls = None
|
72 |
+
|
73 |
+
self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='left')
|
74 |
+
self.t5_output_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='right')
|
75 |
+
|
76 |
+
t5_config = T5Config.from_pretrained(t5_model)
|
77 |
+
t5_config.dense_act_fn = "gelu"
|
78 |
+
self.t5_model = T5ForConditionalGeneration.from_pretrained(
|
79 |
+
t5_model, config=t5_config
|
80 |
+
)
|
81 |
+
|
82 |
+
for name, param in self.t5_model.named_parameters():
|
83 |
+
param.requires_grad = False
|
84 |
+
param.data = param.data.bfloat16()
|
85 |
+
|
86 |
+
self.t5_proj = nn.Linear(
|
87 |
+
self.Qformer.config.hidden_size, self.t5_model.config.hidden_size
|
88 |
+
)
|
89 |
+
|
90 |
+
self.max_txt_len = max_txt_len
|
91 |
+
self.max_output_txt_len = max_output_txt_len
|
92 |
+
self.prompt = prompt
|
93 |
+
|
94 |
+
self._apply_lemmatizer = apply_lemmatizer
|
95 |
+
self._lemmatizer = None
|
96 |
+
|
97 |
+
self.num_few_shot_examples = num_few_shot_examples
|
98 |
+
self.few_shot_prob = few_shot_prob
|
99 |
+
|
100 |
+
self.qformer_text_input = qformer_text_input
|
101 |
+
self.vision_project = nn.Linear(self.visual_encoder.num_features, self.t5_model.config.hidden_size)
|
102 |
+
|
103 |
+
def forward(self, samples):
|
104 |
+
|
105 |
+
image = samples["image"]
|
106 |
+
image_features= self.visual_encoder.get_intermediate_layers(image)[-2] # [batch_size, 257, 1408]
|
107 |
+
image_features = image_features[:, 1:]
|
108 |
+
add_feature_llm = self.vision_project(image_features)
|
109 |
+
atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
|
110 |
+
|
111 |
+
with self.maybe_autocast():
|
112 |
+
image_embeds = self.ln_vision(self.visual_encoder(image))
|
113 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
114 |
+
|
115 |
+
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
116 |
+
if self.qformer_text_input:
|
117 |
+
text_Qformer = self.tokenizer(
|
118 |
+
samples["text_input"],
|
119 |
+
padding='longest',
|
120 |
+
truncation=True,
|
121 |
+
max_length=self.max_txt_len,
|
122 |
+
return_tensors="pt",
|
123 |
+
).to(image.device)
|
124 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
|
125 |
+
Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1)
|
126 |
+
|
127 |
+
query_output = self.Qformer.bert(
|
128 |
+
text_Qformer.input_ids,
|
129 |
+
attention_mask=Qformer_atts,
|
130 |
+
query_embeds=query_tokens,
|
131 |
+
encoder_hidden_states=image_embeds,
|
132 |
+
encoder_attention_mask=image_atts,
|
133 |
+
return_dict=True,
|
134 |
+
)
|
135 |
+
else:
|
136 |
+
query_output = self.Qformer.bert(
|
137 |
+
query_embeds=query_tokens,
|
138 |
+
encoder_hidden_states=image_embeds,
|
139 |
+
encoder_attention_mask=image_atts,
|
140 |
+
return_dict=True,
|
141 |
+
)
|
142 |
+
|
143 |
+
inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
144 |
+
atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
|
145 |
+
|
146 |
+
fs_embeds, fs_atts = None, None
|
147 |
+
if self.few_shot_prob > 0 and "few_shot_samples" in samples.keys():
|
148 |
+
fs_embeds, fs_atts = self.prepare_few_shot_embeds(samples['few_shot_samples'])
|
149 |
+
|
150 |
+
with self.maybe_autocast(dtype=torch.bfloat16):
|
151 |
+
input_tokens = self.t5_tokenizer(
|
152 |
+
samples["text_input"],
|
153 |
+
padding="longest",
|
154 |
+
truncation=True,
|
155 |
+
max_length=self.max_txt_len,
|
156 |
+
return_tensors="pt",
|
157 |
+
).to(image.device)
|
158 |
+
output_tokens = self.t5_output_tokenizer(
|
159 |
+
samples["text_output"],
|
160 |
+
padding="longest",
|
161 |
+
truncation=True,
|
162 |
+
max_length=self.max_output_txt_len,
|
163 |
+
return_tensors="pt",
|
164 |
+
).to(image.device)
|
165 |
+
|
166 |
+
encoder_atts = torch.cat([atts_t5, atts_add_feature_llm, input_tokens.attention_mask], dim=1)
|
167 |
+
|
168 |
+
targets = output_tokens.input_ids.masked_fill(
|
169 |
+
output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100
|
170 |
+
)
|
171 |
+
|
172 |
+
inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
|
173 |
+
inputs_embeds = torch.cat([inputs_t5, add_feature_llm, inputs_embeds], dim=1)
|
174 |
+
|
175 |
+
if fs_embeds is not None:
|
176 |
+
inputs_embeds = torch.cat([fs_embeds, inputs_embeds], dim=1)
|
177 |
+
encoder_atts = torch.cat([fs_atts, encoder_atts], dim=1)
|
178 |
+
|
179 |
+
outputs = self.t5_model(
|
180 |
+
inputs_embeds=inputs_embeds,
|
181 |
+
attention_mask=encoder_atts,
|
182 |
+
decoder_attention_mask=output_tokens.attention_mask,
|
183 |
+
return_dict=True,
|
184 |
+
labels=targets,
|
185 |
+
)
|
186 |
+
loss = outputs.loss
|
187 |
+
|
188 |
+
return {"loss": loss}
|
189 |
+
|
190 |
+
def prepare_few_shot_embeds(self, samples):
|
191 |
+
this_n_fs = random.choices(
|
192 |
+
list(range(self.num_few_shot_examples + 1)),
|
193 |
+
weights=[1 - self.few_shot_prob] + [self.few_shot_prob / self.num_few_shot_examples] * self.num_few_shot_examples
|
194 |
+
)[0]
|
195 |
+
|
196 |
+
if this_n_fs == 0:
|
197 |
+
return None, None
|
198 |
+
|
199 |
+
images = []
|
200 |
+
text_input = []
|
201 |
+
for sample in samples:
|
202 |
+
for n in range(this_n_fs):
|
203 |
+
images.append(sample['image'][n])
|
204 |
+
text_input.append(sample['text_input'][n])
|
205 |
+
images = torch.stack(images, dim=0)
|
206 |
+
|
207 |
+
image = images
|
208 |
+
|
209 |
+
with self.maybe_autocast():
|
210 |
+
image_embeds = self.ln_vision(self.visual_encoder(image))
|
211 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
|
212 |
+
image.device
|
213 |
+
)
|
214 |
+
|
215 |
+
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
216 |
+
if self.qformer_text_input:
|
217 |
+
text_Qformer = self.tokenizer(
|
218 |
+
text_input,
|
219 |
+
padding='longest',
|
220 |
+
truncation=True,
|
221 |
+
max_length=self.max_txt_len,
|
222 |
+
return_tensors="pt",
|
223 |
+
).to(image.device)
|
224 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
|
225 |
+
Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1)
|
226 |
+
query_output = self.Qformer.bert(
|
227 |
+
text_Qformer.input_ids,
|
228 |
+
attention_mask = Qformer_atts,
|
229 |
+
query_embeds=query_tokens,
|
230 |
+
encoder_hidden_states=image_embeds,
|
231 |
+
encoder_attention_mask=image_atts,
|
232 |
+
return_dict=True,
|
233 |
+
)
|
234 |
+
else:
|
235 |
+
query_output = self.Qformer.bert(
|
236 |
+
query_embeds=query_tokens,
|
237 |
+
encoder_hidden_states=image_embeds,
|
238 |
+
encoder_attention_mask=image_atts,
|
239 |
+
return_dict=True,
|
240 |
+
)
|
241 |
+
|
242 |
+
inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
243 |
+
atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
|
244 |
+
|
245 |
+
with self.maybe_autocast(dtype=torch.bfloat16):
|
246 |
+
input_tokens = self.t5_tokenizer(
|
247 |
+
text_input,
|
248 |
+
padding="longest",
|
249 |
+
truncation=True,
|
250 |
+
max_length=self.max_txt_len,
|
251 |
+
return_tensors="pt",
|
252 |
+
).to(image.device)
|
253 |
+
|
254 |
+
encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
|
255 |
+
|
256 |
+
inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
|
257 |
+
inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
|
258 |
+
|
259 |
+
if this_n_fs > 1:
|
260 |
+
encoder_atts = encoder_atts.reshape(encoder_atts.size(0) // this_n_fs, encoder_atts.size(1) * this_n_fs)
|
261 |
+
inputs_embeds = inputs_embeds.reshape(inputs_embeds.size(0) // this_n_fs, inputs_embeds.size(1) * this_n_fs, inputs_embeds.size(2))
|
262 |
+
|
263 |
+
return inputs_embeds, encoder_atts
|
264 |
+
|
265 |
+
@torch.no_grad()
|
266 |
+
def generate(
|
267 |
+
self,
|
268 |
+
samples,
|
269 |
+
use_nucleus_sampling=False,
|
270 |
+
num_beams=5,
|
271 |
+
max_length=256,
|
272 |
+
min_length=1,
|
273 |
+
top_p=0.9,
|
274 |
+
repetition_penalty=1.5,
|
275 |
+
length_penalty=1.0,
|
276 |
+
num_captions=1,
|
277 |
+
temperature=1,
|
278 |
+
):
|
279 |
+
if "prompt" in samples.keys():
|
280 |
+
prompt = samples["prompt"]
|
281 |
+
else:
|
282 |
+
prompt = self.prompt
|
283 |
+
|
284 |
+
image = samples["image"]
|
285 |
+
|
286 |
+
bs = image.size(0)
|
287 |
+
|
288 |
+
if isinstance(prompt, str):
|
289 |
+
prompt = [prompt] * bs
|
290 |
+
else:
|
291 |
+
assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
|
292 |
+
|
293 |
+
# For TextCaps
|
294 |
+
if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
|
295 |
+
prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]
|
296 |
+
if 'context' in samples.keys() and samples['context'] != '':
|
297 |
+
prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
|
298 |
+
print('using context')
|
299 |
+
query_tokens = self.query_tokens.expand(bs, -1, -1)
|
300 |
+
if self.qformer_text_input:
|
301 |
+
# remove ocr tokens in q_former (for eval textvqa)
|
302 |
+
# qformer_prompt = prompt
|
303 |
+
# qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt]
|
304 |
+
|
305 |
+
text_Qformer = self.tokenizer(
|
306 |
+
prompt,
|
307 |
+
padding='longest',
|
308 |
+
truncation=True,
|
309 |
+
max_length=self.max_txt_len,
|
310 |
+
return_tensors="pt",
|
311 |
+
).to(image.device)
|
312 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
|
313 |
+
Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1)
|
314 |
+
|
315 |
+
# For video data
|
316 |
+
if image.dim() == 5:
|
317 |
+
inputs_t5, atts_t5 = [], []
|
318 |
+
add_inputs_llm, add_atts_llm = [], []
|
319 |
+
for j in range(image.size(2)):
|
320 |
+
this_frame = image[:,:,j,:,:]
|
321 |
+
with self.maybe_autocast():
|
322 |
+
frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
|
323 |
+
frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
324 |
+
frame_features =self.visual_encoder.get_intermediate_layers(this_frame)[-2]
|
325 |
+
|
326 |
+
frame_features = frame_features[:, 1:]
|
327 |
+
|
328 |
+
add_feature_llm = self.vision_project(frame_features)
|
329 |
+
atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
|
330 |
+
|
331 |
+
if self.qformer_text_input:
|
332 |
+
frame_query_output = self.Qformer.bert(
|
333 |
+
text_Qformer.input_ids,
|
334 |
+
attention_mask = Qformer_atts,
|
335 |
+
query_embeds=query_tokens,
|
336 |
+
encoder_hidden_states=frame_embeds,
|
337 |
+
encoder_attention_mask=frame_atts,
|
338 |
+
return_dict=True,
|
339 |
+
)
|
340 |
+
else:
|
341 |
+
frame_query_output = self.Qformer.bert(
|
342 |
+
query_embeds=query_tokens,
|
343 |
+
encoder_hidden_states=frame_embeds,
|
344 |
+
encoder_attention_mask=frame_atts,
|
345 |
+
return_dict=True,
|
346 |
+
)
|
347 |
+
|
348 |
+
frame_inputs_t5 = self.t5_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
349 |
+
frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
|
350 |
+
inputs_t5.append(frame_inputs_t5)
|
351 |
+
atts_t5.append(frame_atts_t5)
|
352 |
+
add_inputs_llm.append(add_feature_llm)
|
353 |
+
add_atts_llm.append(atts_add_feature_llm)
|
354 |
+
inputs_t5 = torch.cat(inputs_t5, dim=1)
|
355 |
+
atts_t5 = torch.cat(atts_t5, dim=1)
|
356 |
+
add_feature_llm = torch.cat(add_inputs_llm, dim=1)
|
357 |
+
atts_add_feature_llm = torch.cat(add_atts_llm, dim=1)
|
358 |
+
else:
|
359 |
+
with self.maybe_autocast():
|
360 |
+
image_embeds = self.ln_vision(self.visual_encoder(image))
|
361 |
+
image_features= self.visual_encoder.get_intermediate_layers(image)[-2]
|
362 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
363 |
+
|
364 |
+
image_features = image_features[:, 1:]
|
365 |
+
add_feature_llm = self.vision_project(image_features)
|
366 |
+
atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
|
367 |
+
if self.qformer_text_input:
|
368 |
+
query_output = self.Qformer.bert(
|
369 |
+
text_Qformer.input_ids,
|
370 |
+
attention_mask=Qformer_atts,
|
371 |
+
query_embeds=query_tokens,
|
372 |
+
encoder_hidden_states=image_embeds,
|
373 |
+
encoder_attention_mask=image_atts,
|
374 |
+
return_dict=True,
|
375 |
+
)
|
376 |
+
else:
|
377 |
+
query_output = self.Qformer.bert(
|
378 |
+
query_embeds=query_tokens,
|
379 |
+
encoder_hidden_states=image_embeds,
|
380 |
+
encoder_attention_mask=image_atts,
|
381 |
+
return_dict=True,
|
382 |
+
)
|
383 |
+
|
384 |
+
inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
385 |
+
atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
|
386 |
+
|
387 |
+
input_tokens = self.t5_tokenizer(
|
388 |
+
prompt,
|
389 |
+
padding="longest",
|
390 |
+
return_tensors="pt"
|
391 |
+
).to(image.device)
|
392 |
+
|
393 |
+
encoder_atts = torch.cat([atts_t5, atts_add_feature_llm,input_tokens.attention_mask], dim=1)
|
394 |
+
|
395 |
+
with self.maybe_autocast(dtype=torch.bfloat16):
|
396 |
+
inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
|
397 |
+
inputs_embeds = torch.cat([inputs_t5, add_feature_llm, inputs_embeds], dim=1)
|
398 |
+
|
399 |
+
outputs = self.t5_model.generate(
|
400 |
+
inputs_embeds=inputs_embeds,
|
401 |
+
attention_mask=encoder_atts,
|
402 |
+
do_sample=use_nucleus_sampling,
|
403 |
+
top_p=top_p,
|
404 |
+
temperature=temperature,
|
405 |
+
num_beams=num_beams,
|
406 |
+
max_new_tokens=max_length,
|
407 |
+
min_length=min_length,
|
408 |
+
repetition_penalty=repetition_penalty,
|
409 |
+
length_penalty=length_penalty,
|
410 |
+
num_return_sequences=num_captions,
|
411 |
+
)
|
412 |
+
output_text = self.t5_tokenizer.batch_decode(
|
413 |
+
outputs, skip_special_tokens=True
|
414 |
+
)
|
415 |
+
|
416 |
+
return output_text
|
417 |
+
|
418 |
+
def predict_answers(
|
419 |
+
self,
|
420 |
+
samples,
|
421 |
+
num_beams=5,
|
422 |
+
inference_method="generate",
|
423 |
+
max_len=10,
|
424 |
+
min_len=1,
|
425 |
+
num_ans_candidates=128,
|
426 |
+
answer_list=None,
|
427 |
+
prompt="",
|
428 |
+
length_penalty=-1,
|
429 |
+
**kwargs
|
430 |
+
):
|
431 |
+
if isinstance(samples["text_input"], str):
|
432 |
+
samples["text_input"] = [samples["text_input"]]
|
433 |
+
|
434 |
+
if prompt:
|
435 |
+
if prompt.count("{}") == 2:
|
436 |
+
if 'ocr_tokens' in samples:
|
437 |
+
text_input = [
|
438 |
+
prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["text_input"][i])
|
439 |
+
for i in range(len(samples["text_input"]))]
|
440 |
+
elif 'choices' in samples:
|
441 |
+
text_input = []
|
442 |
+
for i in range(len(samples["text_input"])):
|
443 |
+
this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])]
|
444 |
+
this_choices = " ".join(this_choices)
|
445 |
+
text_input.append(prompt.format(samples["text_input"][i], this_choices))
|
446 |
+
else:
|
447 |
+
text_input = [prompt.format(question) for question in samples["text_input"]]
|
448 |
+
else:
|
449 |
+
text_input = samples["text_input"]
|
450 |
+
|
451 |
+
samples["prompt"] = text_input
|
452 |
+
|
453 |
+
output_text = self.generate(
|
454 |
+
samples,
|
455 |
+
num_beams=num_beams,
|
456 |
+
max_length=max_len,
|
457 |
+
min_length=min_len,
|
458 |
+
length_penalty=length_penalty
|
459 |
+
)
|
460 |
+
|
461 |
+
if self._apply_lemmatizer or ("apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]):
|
462 |
+
output_text = self._lemmatize(output_text)
|
463 |
+
|
464 |
+
return output_text
|
465 |
+
|
466 |
+
def predict_class(
|
467 |
+
self,
|
468 |
+
samples,
|
469 |
+
candidates,
|
470 |
+
n_segments=1,
|
471 |
+
):
|
472 |
+
# If candidates is a list of lists, each sample has its candidates, then we need to iterate one by one
|
473 |
+
if type(candidates[0]) == list:
|
474 |
+
results = []
|
475 |
+
|
476 |
+
for i in range(samples["image"].size(0)):
|
477 |
+
this_sample = {
|
478 |
+
"image": samples["image"][i].unsqueeze(0),
|
479 |
+
"prompt": samples["prompt"][i],
|
480 |
+
}
|
481 |
+
|
482 |
+
if "text_input" in samples.keys():
|
483 |
+
this_sample["text_input"] = [samples["text_input"][i]]
|
484 |
+
|
485 |
+
if 'context' in samples.keys():
|
486 |
+
this_sample['context'] = [samples["context"][i]]
|
487 |
+
|
488 |
+
if 'history' in samples.keys():
|
489 |
+
this_sample['history'] = [samples["history"][i]]
|
490 |
+
|
491 |
+
if 'caption' in samples.keys():
|
492 |
+
this_sample['caption'] = [samples["caption"][i]]
|
493 |
+
|
494 |
+
this_result = self._predict_class(this_sample, candidates[i], n_segments)
|
495 |
+
results.append(this_result)
|
496 |
+
|
497 |
+
try:
|
498 |
+
results = torch.cat(results, dim=0)
|
499 |
+
except:
|
500 |
+
results = [res.tolist()[0] for res in results]
|
501 |
+
|
502 |
+
return results
|
503 |
+
|
504 |
+
return self._predict_class(samples, candidates, n_segments)
|
505 |
+
|
506 |
+
def _predict_class(
|
507 |
+
self,
|
508 |
+
samples,
|
509 |
+
candidates,
|
510 |
+
n_segments=1,
|
511 |
+
):
|
512 |
+
"""
|
513 |
+
Args:
|
514 |
+
samples (dict): A dictionary containing the following keys:
|
515 |
+
- image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
|
516 |
+
- prompt: the instruction
|
517 |
+
candidates:
|
518 |
+
(list): A list of candidate class names;
|
519 |
+
n_segments:
|
520 |
+
(int): Split the candidates into n_segments and predict one by one. This is useful when the number of candidates is too large.
|
521 |
+
Returns:
|
522 |
+
output_class: predicted class index
|
523 |
+
"""
|
524 |
+
|
525 |
+
image = samples["image"]
|
526 |
+
prompt = samples["prompt"]
|
527 |
+
|
528 |
+
bs = image.size(0)
|
529 |
+
|
530 |
+
if isinstance(prompt, str):
|
531 |
+
prompt = [prompt] * bs
|
532 |
+
else:
|
533 |
+
assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
|
534 |
+
|
535 |
+
if "text_input" in samples.keys():
|
536 |
+
if type(samples["text_input"][0]) == list:
|
537 |
+
prompt = [prompt[i].format(*samples["text_input"][i]) for i in range(len(prompt))]
|
538 |
+
else:
|
539 |
+
prompt = [prompt[i].format(samples["text_input"][i]) for i in range(len(prompt))]
|
540 |
+
|
541 |
+
# scienceqa
|
542 |
+
if 'context' in samples.keys() and samples['context'] != '':
|
543 |
+
prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
|
544 |
+
|
545 |
+
# visual dialog
|
546 |
+
if 'history' in samples.keys() and samples['history'][0] != '':
|
547 |
+
prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(len(prompt))]
|
548 |
+
|
549 |
+
if 'caption' in samples.keys() and samples['caption'][0] != '':
|
550 |
+
prompt = [f'This image has the caption "{samples["caption"][i]}". {prompt[i]}' for i in range(len(prompt))]
|
551 |
+
|
552 |
+
query_tokens = self.query_tokens.expand(bs, -1, -1)
|
553 |
+
if self.qformer_text_input:
|
554 |
+
text_Qformer = self.tokenizer(
|
555 |
+
prompt,
|
556 |
+
padding='longest',
|
557 |
+
truncation=True,
|
558 |
+
max_length=self.max_txt_len,
|
559 |
+
return_tensors="pt"
|
560 |
+
).to(image.device)
|
561 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
|
562 |
+
Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask], dim=1)
|
563 |
+
|
564 |
+
if image.dim() == 5:
|
565 |
+
inputs_t5, atts_t5 = [], []
|
566 |
+
add_inputs_llm, add_atts_llm = [], []
|
567 |
+
for j in range(image.size(2)):
|
568 |
+
this_frame = image[:,:,j,:,:]
|
569 |
+
with self.maybe_autocast():
|
570 |
+
frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
|
571 |
+
frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
572 |
+
frame_features =self.visual_encoder.get_intermediate_layers(this_frame)[-2]
|
573 |
+
|
574 |
+
frame_features = frame_features[:, 1:]
|
575 |
+
|
576 |
+
add_feature_llm = self.vision_project(frame_features)
|
577 |
+
atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
|
578 |
+
if self.qformer_text_input:
|
579 |
+
frame_query_output = self.Qformer.bert(
|
580 |
+
text_Qformer.input_ids,
|
581 |
+
attention_mask=Qformer_atts,
|
582 |
+
query_embeds=query_tokens,
|
583 |
+
encoder_hidden_states=frame_embeds,
|
584 |
+
encoder_attention_mask=frame_atts,
|
585 |
+
return_dict=True,
|
586 |
+
)
|
587 |
+
else:
|
588 |
+
frame_query_output = self.Qformer.bert(
|
589 |
+
query_embeds=query_tokens,
|
590 |
+
encoder_hidden_states=frame_embeds,
|
591 |
+
encoder_attention_mask=frame_atts,
|
592 |
+
return_dict=True,
|
593 |
+
)
|
594 |
+
|
595 |
+
frame_inputs_t5 = self.t5_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
596 |
+
frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
|
597 |
+
inputs_t5.append(frame_inputs_t5)
|
598 |
+
atts_t5.append(frame_atts_t5)
|
599 |
+
add_inputs_llm.append(add_feature_llm)
|
600 |
+
add_atts_llm.append(atts_add_feature_llm)
|
601 |
+
inputs_t5 = torch.cat(inputs_t5, dim=1)
|
602 |
+
atts_t5 = torch.cat(atts_t5, dim=1)
|
603 |
+
add_feature_llm = torch.cat(add_inputs_llm, dim=1)
|
604 |
+
atts_add_feature_llm = torch.cat(add_atts_llm, dim=1)
|
605 |
+
else:
|
606 |
+
with self.maybe_autocast():
|
607 |
+
image_embeds = self.ln_vision(self.visual_encoder(image))
|
608 |
+
image_features= self.visual_encoder.get_intermediate_layers(image)[-2]
|
609 |
+
|
610 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
611 |
+
|
612 |
+
image_features = image_features[:, 1:]
|
613 |
+
add_feature_llm = self.vision_project(image_features)
|
614 |
+
atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
|
615 |
+
|
616 |
+
if self.qformer_text_input:
|
617 |
+
query_output = self.Qformer.bert(
|
618 |
+
text_Qformer.input_ids,
|
619 |
+
attention_mask=Qformer_atts,
|
620 |
+
query_embeds=query_tokens,
|
621 |
+
encoder_hidden_states=image_embeds,
|
622 |
+
encoder_attention_mask=image_atts,
|
623 |
+
return_dict=True,
|
624 |
+
)
|
625 |
+
else:
|
626 |
+
query_output = self.Qformer.bert(
|
627 |
+
query_embeds=query_tokens,
|
628 |
+
encoder_hidden_states=image_embeds,
|
629 |
+
encoder_attention_mask=image_atts,
|
630 |
+
return_dict=True,
|
631 |
+
)
|
632 |
+
|
633 |
+
inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
634 |
+
atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
|
635 |
+
|
636 |
+
input_tokens = self.t5_tokenizer(
|
637 |
+
prompt, padding="longest", return_tensors="pt"
|
638 |
+
).to(image.device)
|
639 |
+
output_tokens = self.t5_tokenizer(
|
640 |
+
candidates, padding="longest", return_tensors="pt"
|
641 |
+
).to(image.device)
|
642 |
+
|
643 |
+
encoder_atts = torch.cat([atts_t5, atts_add_feature_llm, input_tokens.attention_mask], dim=1)
|
644 |
+
|
645 |
+
n_cands = len(candidates)
|
646 |
+
|
647 |
+
with self.maybe_autocast(dtype=torch.bfloat16):
|
648 |
+
inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
|
649 |
+
inputs_embeds = torch.cat([inputs_t5,add_feature_llm, inputs_embeds], dim=1)
|
650 |
+
|
651 |
+
encoder_outputs = self.t5_model.encoder(
|
652 |
+
inputs_embeds=inputs_embeds,
|
653 |
+
attention_mask=encoder_atts,
|
654 |
+
)
|
655 |
+
|
656 |
+
all_losses = []
|
657 |
+
for n in range(n_segments):
|
658 |
+
seg_len = n_cands // n_segments
|
659 |
+
if n == (n_segments - 1):
|
660 |
+
seg_len = n_cands - seg_len * (n_segments - 1)
|
661 |
+
|
662 |
+
# this_encoder_outputs = copy.deepcopy(encoder_outputs)
|
663 |
+
this_encoder_outputs = BaseModelOutput(
|
664 |
+
last_hidden_state=encoder_outputs[0].clone(),
|
665 |
+
)
|
666 |
+
|
667 |
+
this_encoder_outputs['last_hidden_state'] = this_encoder_outputs[0].repeat_interleave(seg_len, dim=0)
|
668 |
+
this_encoder_atts = encoder_atts.repeat_interleave(seg_len, dim=0)
|
669 |
+
|
670 |
+
start_i = n * (n_cands // n_segments)
|
671 |
+
end_i = start_i + seg_len
|
672 |
+
this_output_tokens_ids = output_tokens.input_ids[start_i:end_i].repeat(bs, 1)
|
673 |
+
this_output_tokens_atts = output_tokens.attention_mask[start_i:end_i].repeat(bs, 1)
|
674 |
+
|
675 |
+
this_targets = this_output_tokens_ids.masked_fill(this_output_tokens_ids == self.t5_tokenizer.pad_token_id, -100)
|
676 |
+
|
677 |
+
outputs = self.t5_model(
|
678 |
+
encoder_outputs=this_encoder_outputs,
|
679 |
+
attention_mask=this_encoder_atts,
|
680 |
+
decoder_attention_mask=this_output_tokens_atts,
|
681 |
+
return_dict=True,
|
682 |
+
labels=this_targets,
|
683 |
+
reduction="none",
|
684 |
+
)
|
685 |
+
loss = outputs.loss
|
686 |
+
|
687 |
+
loss = loss.reshape(bs, seg_len)
|
688 |
+
# output_class_ranks = torch.argsort(loss, dim=-1)
|
689 |
+
all_losses.append(loss)
|
690 |
+
|
691 |
+
all_losses = torch.cat(all_losses, dim=-1)
|
692 |
+
output_class_ranks = torch.argsort(all_losses, dim=-1)
|
693 |
+
|
694 |
+
# encoder_outputs['last_hidden_state'] = encoder_outputs[0].repeat_interleave(n_cands, dim=0)
|
695 |
+
# encoder_atts = encoder_atts.repeat_interleave(n_cands, dim=0)
|
696 |
+
# output_tokens.input_ids = output_tokens.input_ids.repeat(bs, 1)
|
697 |
+
# output_tokens.attention_mask = output_tokens.attention_mask.repeat(bs, 1)
|
698 |
+
|
699 |
+
# # compute the LM loss for each candidate (sum logprob across all tokens) and select the highest
|
700 |
+
# targets = output_tokens.input_ids.masked_fill(output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100)
|
701 |
+
|
702 |
+
# outputs = self.t5_model(
|
703 |
+
# encoder_outputs=encoder_outputs,
|
704 |
+
# attention_mask=encoder_atts,
|
705 |
+
# decoder_attention_mask=output_tokens.attention_mask,
|
706 |
+
# return_dict=True,
|
707 |
+
# labels=targets,
|
708 |
+
# reduction="none",
|
709 |
+
# )
|
710 |
+
# loss = outputs.loss
|
711 |
+
|
712 |
+
# loss = loss.reshape(bs, n_cands)
|
713 |
+
# output_class_ranks = torch.argsort(loss, dim=-1) # (bs, num_candidates)
|
714 |
+
|
715 |
+
return output_class_ranks
|
716 |
+
|
717 |
+
def _lemmatize(self, answers):
|
718 |
+
def apply(answer):
|
719 |
+
doc = self.lemmatizer(answer)
|
720 |
+
|
721 |
+
words = []
|
722 |
+
for token in doc:
|
723 |
+
if token.pos_ in ["NOUN", "VERB"]:
|
724 |
+
words.append(token.lemma_)
|
725 |
+
else:
|
726 |
+
words.append(token.text)
|
727 |
+
answer = " ".join(words)
|
728 |
+
|
729 |
+
return answer
|
730 |
+
|
731 |
+
return [apply(answer) for answer in answers]
|
732 |
+
|
733 |
+
@property
|
734 |
+
def lemmatizer(self):
|
735 |
+
if self._lemmatizer is None:
|
736 |
+
try:
|
737 |
+
import spacy
|
738 |
+
|
739 |
+
self._lemmatizer = spacy.load("en_core_web_sm")
|
740 |
+
except ImportError:
|
741 |
+
logging.error(
|
742 |
+
"""
|
743 |
+
Please install spacy and en_core_web_sm model to apply lemmatization.
|
744 |
+
python -m spacy download en_core_web_sm
|
745 |
+
OR
|
746 |
+
import spacy.cli
|
747 |
+
spacy.cli.download("en_core_web_sm")
|
748 |
+
"""
|
749 |
+
)
|
750 |
+
exit(1)
|
751 |
+
|
752 |
+
return self._lemmatizer
|
753 |
+
|
754 |
+
@classmethod
|
755 |
+
def from_config(cls, cfg):
|
756 |
+
vit_model = cfg.get("vit_model", "eva_clip_g")
|
757 |
+
img_size = cfg.get("image_size")
|
758 |
+
num_query_token = cfg.get("num_query_token")
|
759 |
+
t5_model = cfg.get("t5_model")
|
760 |
+
|
761 |
+
drop_path_rate = cfg.get("drop_path_rate", 0)
|
762 |
+
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
|
763 |
+
vit_precision = cfg.get("vit_precision", "fp16")
|
764 |
+
freeze_vit = cfg.get("freeze_vit", True)
|
765 |
+
|
766 |
+
prompt = cfg.get("prompt", "")
|
767 |
+
max_txt_len = cfg.get("max_txt_len", 128)
|
768 |
+
max_output_txt_len = cfg.get("max_output_txt_len", 256)
|
769 |
+
|
770 |
+
apply_lemmatizer = cfg.get("apply_lemmatizer", False)
|
771 |
+
|
772 |
+
num_few_shot_examples = cfg.get("num_few_shot_examples", 0)
|
773 |
+
few_shot_prob = cfg.get("few_shot_prob", 0.0)
|
774 |
+
|
775 |
+
qformer_text_input = cfg.get("qformer_text_input", True)
|
776 |
+
|
777 |
+
model = cls(
|
778 |
+
vit_model=vit_model,
|
779 |
+
img_size=img_size,
|
780 |
+
drop_path_rate=drop_path_rate,
|
781 |
+
use_grad_checkpoint=use_grad_checkpoint,
|
782 |
+
vit_precision=vit_precision,
|
783 |
+
freeze_vit=freeze_vit,
|
784 |
+
num_query_token=num_query_token,
|
785 |
+
t5_model=t5_model,
|
786 |
+
prompt=prompt,
|
787 |
+
max_txt_len=max_txt_len,
|
788 |
+
max_output_txt_len=max_output_txt_len,
|
789 |
+
apply_lemmatizer=apply_lemmatizer,
|
790 |
+
num_few_shot_examples=num_few_shot_examples,
|
791 |
+
few_shot_prob=few_shot_prob,
|
792 |
+
qformer_text_input=qformer_text_input,
|
793 |
+
)
|
794 |
+
|
795 |
+
# if qformer_text_input:
|
796 |
+
# # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal)
|
797 |
+
# model.load_from_pretrained(
|
798 |
+
# url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth"
|
799 |
+
# )
|
800 |
+
|
801 |
+
model.load_checkpoint_from_config(cfg)
|
802 |
+
|
803 |
+
return model
|
bliva/models/bliva_vicuna7b.py
ADDED
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import string
|
3 |
+
from packaging import version
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch.cuda.amp import autocast as autocast
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
import transformers
|
10 |
+
|
11 |
+
from bliva.common.registry import registry
|
12 |
+
from bliva.models.blip2 import Blip2Base, disabled_train
|
13 |
+
|
14 |
+
@registry.register_model("bliva_vicuna")
|
15 |
+
class BLIVAVicuna(Blip2Base):
|
16 |
+
|
17 |
+
PRETRAINED_MODEL_CONFIG_DICT = {
|
18 |
+
"vicuna7b": "configs/models/bliva_vicuna7b.yaml",
|
19 |
+
}
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
vit_model="eva_clip_g",
|
24 |
+
img_size=224,
|
25 |
+
drop_path_rate=0,
|
26 |
+
use_grad_checkpoint=False,
|
27 |
+
vit_precision="fp16",
|
28 |
+
freeze_vit=True,
|
29 |
+
num_query_token=32,
|
30 |
+
llm_model="",
|
31 |
+
prompt="",
|
32 |
+
max_txt_len=128,
|
33 |
+
max_output_txt_len=256,
|
34 |
+
apply_lemmatizer=False,
|
35 |
+
qformer_text_input=True,
|
36 |
+
):
|
37 |
+
super().__init__()
|
38 |
+
transformers_version = version.parse(transformers.__version__)
|
39 |
+
assert transformers_version >= version.parse("4.28"), "BLIP-2 Vicuna requires transformers>=4.28"
|
40 |
+
from transformers import LlamaTokenizer
|
41 |
+
from bliva.models.modeling_llama import LlamaForCausalLM
|
42 |
+
|
43 |
+
self.tokenizer = self.init_tokenizer(truncation_side="left")
|
44 |
+
|
45 |
+
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
46 |
+
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
47 |
+
)
|
48 |
+
|
49 |
+
if freeze_vit:
|
50 |
+
for name, param in self.visual_encoder.named_parameters():
|
51 |
+
param.requires_grad = False
|
52 |
+
self.visual_encoder = self.visual_encoder.eval()
|
53 |
+
self.visual_encoder.train = disabled_train
|
54 |
+
logging.info("freeze vision encoder")
|
55 |
+
|
56 |
+
self.Qformer, self.query_tokens = self.init_Qformer(
|
57 |
+
num_query_token, self.visual_encoder.num_features
|
58 |
+
)
|
59 |
+
|
60 |
+
if not qformer_text_input:
|
61 |
+
self.Qformer.bert.embeddings.word_embeddings = None
|
62 |
+
self.Qformer.bert.embeddings.position_embeddings = None
|
63 |
+
for layer in self.Qformer.bert.encoder.layer:
|
64 |
+
layer.output = None
|
65 |
+
layer.intermediate = None
|
66 |
+
else:
|
67 |
+
self.Qformer.resize_token_embeddings(len(self.tokenizer))
|
68 |
+
self.Qformer.cls = None
|
69 |
+
|
70 |
+
self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_model, use_fast=False, truncation_side="left")
|
71 |
+
self.llm_model = LlamaForCausalLM.from_pretrained(
|
72 |
+
llm_model, low_cpu_mem_usage=True, torch_dtype=torch.float16
|
73 |
+
).to('cuda:0') #load_in_8bit=True
|
74 |
+
self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
75 |
+
self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
|
76 |
+
self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
|
77 |
+
self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
|
78 |
+
# self.llm_tokenizer.pad_token = self.llm_tokenizer.unk_token
|
79 |
+
|
80 |
+
self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
|
81 |
+
|
82 |
+
# self.eos_token_id = self.llm_tokenizer(
|
83 |
+
# self.llm_tokenizer.eos_token, add_special_tokens=False
|
84 |
+
# ).input_ids[0]
|
85 |
+
|
86 |
+
for name, param in self.llm_model.named_parameters():
|
87 |
+
param.requires_grad = False
|
88 |
+
|
89 |
+
self.llm_proj = nn.Linear(
|
90 |
+
self.Qformer.config.hidden_size, self.llm_model.config.hidden_size
|
91 |
+
)
|
92 |
+
|
93 |
+
self.max_txt_len = max_txt_len
|
94 |
+
self.max_output_txt_len = max_output_txt_len
|
95 |
+
self.prompt = prompt
|
96 |
+
prompt_tokens = self.llm_tokenizer(self.prompt, return_tensors="pt")
|
97 |
+
self.prompt_length = prompt_tokens.attention_mask.sum(1)
|
98 |
+
|
99 |
+
self._lemmatizer = None
|
100 |
+
|
101 |
+
self.qformer_text_input = qformer_text_input
|
102 |
+
|
103 |
+
self.vision_project = nn.Linear(self.visual_encoder.num_features, self.llm_model.config.hidden_size)
|
104 |
+
|
105 |
+
def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts):
|
106 |
+
input_part_targets_len = []
|
107 |
+
llm_tokens = {"input_ids": [], "attention_mask": []}
|
108 |
+
for i in range(input_ids.size(0)):
|
109 |
+
this_input_ones = input_atts[i].sum()
|
110 |
+
input_part_targets_len.append(this_input_ones)
|
111 |
+
llm_tokens['input_ids'].append(
|
112 |
+
torch.cat([
|
113 |
+
input_ids[i][:this_input_ones],
|
114 |
+
output_ids[i][1:],
|
115 |
+
input_ids[i][this_input_ones:]
|
116 |
+
])
|
117 |
+
)
|
118 |
+
llm_tokens['attention_mask'].append(
|
119 |
+
torch.cat([
|
120 |
+
input_atts[i][:this_input_ones],
|
121 |
+
output_atts[i][1:],
|
122 |
+
input_atts[i][this_input_ones:]
|
123 |
+
])
|
124 |
+
)
|
125 |
+
llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
|
126 |
+
llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
|
127 |
+
return llm_tokens, input_part_targets_len
|
128 |
+
|
129 |
+
def forward(self, samples):
|
130 |
+
# print('-----------------')
|
131 |
+
# print(samples["text_input"])
|
132 |
+
# print(samples["text_output"])
|
133 |
+
# print(samples["image"].shape)
|
134 |
+
# print('-----------------')
|
135 |
+
|
136 |
+
image = samples["image"]
|
137 |
+
|
138 |
+
image_features= self.visual_encoder.get_intermediate_layers(image)[-2] # [batch_size, 257, 1408]
|
139 |
+
image_features = image_features[:, 1:]
|
140 |
+
add_feature_llm = self.vision_project(image_features)
|
141 |
+
atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
|
142 |
+
|
143 |
+
with self.maybe_autocast():
|
144 |
+
image_embeds = self.ln_vision(self.visual_encoder(image))
|
145 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
146 |
+
|
147 |
+
bs = image.size(0)
|
148 |
+
|
149 |
+
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
150 |
+
if self.qformer_text_input:
|
151 |
+
text_Qformer = self.tokenizer(
|
152 |
+
samples["text_input"],
|
153 |
+
padding='longest',
|
154 |
+
truncation=True,
|
155 |
+
max_length=self.max_txt_len,
|
156 |
+
return_tensors="pt",
|
157 |
+
).to(image.device)
|
158 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
|
159 |
+
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],dim=1)
|
160 |
+
|
161 |
+
query_output = self.Qformer.bert(
|
162 |
+
text_Qformer.input_ids,
|
163 |
+
attention_mask=Qformer_atts,
|
164 |
+
query_embeds=query_tokens,
|
165 |
+
encoder_hidden_states=image_embeds,
|
166 |
+
encoder_attention_mask=image_atts,
|
167 |
+
return_dict=True,
|
168 |
+
)
|
169 |
+
else:
|
170 |
+
query_output = self.Qformer.bert(
|
171 |
+
query_embeds=query_tokens,
|
172 |
+
encoder_hidden_states=image_embeds,
|
173 |
+
encoder_attention_mask=image_atts,
|
174 |
+
return_dict=True,
|
175 |
+
)
|
176 |
+
|
177 |
+
inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
178 |
+
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
|
179 |
+
|
180 |
+
self.llm_tokenizer.padding_side = "right"
|
181 |
+
self.llm_tokenizer.truncation_side = 'left'
|
182 |
+
text_input_tokens = self.llm_tokenizer(
|
183 |
+
samples['text_input'],
|
184 |
+
return_tensors="pt",
|
185 |
+
padding="longest",
|
186 |
+
truncation=True,
|
187 |
+
max_length=self.max_txt_len,
|
188 |
+
).to(image.device)
|
189 |
+
|
190 |
+
self.llm_tokenizer.truncation_side = 'right'
|
191 |
+
text_output_tokens = self.llm_tokenizer(
|
192 |
+
[t + self.llm_tokenizer.eos_token for t in samples['text_output']],
|
193 |
+
return_tensors="pt",
|
194 |
+
padding="longest",
|
195 |
+
truncation=True,
|
196 |
+
max_length=self.max_output_txt_len,
|
197 |
+
).to(image.device)
|
198 |
+
|
199 |
+
llm_tokens, input_part_targets_len = self.concat_text_input_output(
|
200 |
+
text_input_tokens.input_ids,
|
201 |
+
text_input_tokens.attention_mask,
|
202 |
+
text_output_tokens.input_ids,
|
203 |
+
text_output_tokens.attention_mask,
|
204 |
+
)
|
205 |
+
|
206 |
+
# do not apply loss to the padding
|
207 |
+
targets = llm_tokens['input_ids'].masked_fill(
|
208 |
+
llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100
|
209 |
+
)
|
210 |
+
|
211 |
+
# do not apply loss to the text input (i.e., instruction)
|
212 |
+
for i, l in enumerate(input_part_targets_len):
|
213 |
+
targets[i][:l] = -100
|
214 |
+
|
215 |
+
# do not apply loss to the query tokens
|
216 |
+
empty_targets = (
|
217 |
+
torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
|
218 |
+
)
|
219 |
+
# do not apply loss to the additional image features
|
220 |
+
empty_add_targets = (
|
221 |
+
torch.ones(atts_add_feature_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
|
222 |
+
)
|
223 |
+
#targets = torch.cat([empty_targets, targets], dim=1)
|
224 |
+
targets = torch.cat([empty_targets, empty_add_targets, targets], dim=1)
|
225 |
+
|
226 |
+
inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids'])
|
227 |
+
#inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
|
228 |
+
#attention_mask = torch.cat([atts_llm, llm_tokens['attention_mask']], dim=1)
|
229 |
+
inputs_embeds = torch.cat([inputs_llm, add_feature_llm, inputs_embeds], dim=1)
|
230 |
+
attention_mask = torch.cat([atts_llm, atts_add_feature_llm, llm_tokens['attention_mask']], dim=1)
|
231 |
+
|
232 |
+
with self.maybe_autocast():
|
233 |
+
outputs = self.llm_model(
|
234 |
+
inputs_embeds=inputs_embeds,
|
235 |
+
attention_mask=attention_mask,
|
236 |
+
return_dict=True,
|
237 |
+
labels=targets,
|
238 |
+
)
|
239 |
+
|
240 |
+
loss = outputs.loss
|
241 |
+
|
242 |
+
return {"loss": loss}
|
243 |
+
|
244 |
+
@torch.no_grad()
|
245 |
+
def generate(
|
246 |
+
self,
|
247 |
+
samples,
|
248 |
+
use_nucleus_sampling=False,
|
249 |
+
num_beams=5,
|
250 |
+
max_length=256,
|
251 |
+
min_length=1,
|
252 |
+
top_p=0.9,
|
253 |
+
repetition_penalty=1.5,
|
254 |
+
length_penalty=1,
|
255 |
+
num_captions=1,
|
256 |
+
temperature=1,
|
257 |
+
):
|
258 |
+
self.llm_tokenizer.padding_side = "left"
|
259 |
+
|
260 |
+
if "prompt" in samples.keys():
|
261 |
+
prompt = samples["prompt"]
|
262 |
+
else:
|
263 |
+
prompt = samples["text_input"]
|
264 |
+
|
265 |
+
image = samples["image"]
|
266 |
+
|
267 |
+
bs = image.size(0)
|
268 |
+
|
269 |
+
# if isinstance(prompt, str):
|
270 |
+
# prompt = [prompt] * bs
|
271 |
+
# else:
|
272 |
+
# assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
|
273 |
+
|
274 |
+
# For TextCaps
|
275 |
+
if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
|
276 |
+
prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]
|
277 |
+
|
278 |
+
if 'context' in samples.keys() and samples['context'] != '':
|
279 |
+
prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
|
280 |
+
print('using context')
|
281 |
+
query_tokens = self.query_tokens.expand(bs, -1, -1)
|
282 |
+
if self.qformer_text_input:
|
283 |
+
# remove ocr tokens in q_former (for eval textvqa)
|
284 |
+
# qformer_prompt = prompt
|
285 |
+
# qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt]
|
286 |
+
|
287 |
+
text_Qformer = self.tokenizer(
|
288 |
+
prompt,
|
289 |
+
padding='longest',
|
290 |
+
truncation=True,
|
291 |
+
max_length=self.max_txt_len,
|
292 |
+
return_tensors="pt",
|
293 |
+
).to(image.device)
|
294 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
|
295 |
+
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
|
296 |
+
|
297 |
+
# For video data
|
298 |
+
if image.dim() == 5:
|
299 |
+
inputs_llm, atts_llm = [], []
|
300 |
+
add_inputs_llm, add_atts_llm = [], []
|
301 |
+
for j in range(image.size(2)):
|
302 |
+
this_frame = image[:,:,j,:,:]
|
303 |
+
with self.maybe_autocast():
|
304 |
+
frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
|
305 |
+
frame_features =self.visual_encoder.get_intermediate_layers(this_frame)[-2]
|
306 |
+
|
307 |
+
frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
308 |
+
frame_features = frame_features[:, 1:]
|
309 |
+
|
310 |
+
add_feature_llm = self.vision_project(frame_features)
|
311 |
+
atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
|
312 |
+
|
313 |
+
if self.qformer_text_input:
|
314 |
+
frame_query_output = self.Qformer.bert(
|
315 |
+
text_Qformer.input_ids,
|
316 |
+
attention_mask=Qformer_atts,
|
317 |
+
query_embeds=query_tokens,
|
318 |
+
encoder_hidden_states=frame_embeds,
|
319 |
+
encoder_attention_mask=frame_atts,
|
320 |
+
return_dict=True,
|
321 |
+
)
|
322 |
+
else:
|
323 |
+
frame_query_output = self.Qformer.bert(
|
324 |
+
query_embeds=query_tokens,
|
325 |
+
encoder_hidden_states=frame_embeds,
|
326 |
+
encoder_attention_mask=frame_atts,
|
327 |
+
return_dict=True,
|
328 |
+
)
|
329 |
+
frame_inputs_llm = self.llm_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
330 |
+
frame_atts_llm = torch.ones(frame_inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
|
331 |
+
inputs_llm.append(frame_inputs_llm)
|
332 |
+
atts_llm.append(frame_atts_llm)
|
333 |
+
add_inputs_llm.append(add_feature_llm)
|
334 |
+
add_atts_llm.append(atts_add_feature_llm)
|
335 |
+
inputs_llm = torch.cat(inputs_llm, dim=1)
|
336 |
+
atts_llm = torch.cat(atts_llm, dim=1)
|
337 |
+
add_feature_llm = torch.cat(add_inputs_llm, dim=1)
|
338 |
+
atts_add_feature_llm = torch.cat(add_atts_llm, dim=1)
|
339 |
+
else:
|
340 |
+
with self.maybe_autocast():
|
341 |
+
image_embeds = self.ln_vision(self.visual_encoder(image))
|
342 |
+
|
343 |
+
image_features= self.visual_encoder.get_intermediate_layers(image)[-2] # [batch_size, 257, 1408]
|
344 |
+
|
345 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
346 |
+
|
347 |
+
image_features = image_features[:, 1:]
|
348 |
+
add_feature_llm = self.vision_project(image_features)
|
349 |
+
atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
|
350 |
+
|
351 |
+
if self.qformer_text_input:
|
352 |
+
query_output = self.Qformer.bert(
|
353 |
+
text_Qformer.input_ids,
|
354 |
+
attention_mask=Qformer_atts,
|
355 |
+
query_embeds=query_tokens,
|
356 |
+
encoder_hidden_states=image_embeds,
|
357 |
+
encoder_attention_mask=image_atts,
|
358 |
+
return_dict=True,
|
359 |
+
)
|
360 |
+
else:
|
361 |
+
query_output = self.Qformer.bert(
|
362 |
+
query_embeds=query_tokens,
|
363 |
+
encoder_hidden_states=image_embeds,
|
364 |
+
encoder_attention_mask=image_atts,
|
365 |
+
return_dict=True,
|
366 |
+
)
|
367 |
+
|
368 |
+
inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
369 |
+
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
|
370 |
+
|
371 |
+
llm_tokens = self.llm_tokenizer(
|
372 |
+
prompt,
|
373 |
+
padding="longest",
|
374 |
+
return_tensors="pt"
|
375 |
+
).to(image.device)
|
376 |
+
|
377 |
+
with self.maybe_autocast():
|
378 |
+
inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
|
379 |
+
# inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
|
380 |
+
# attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1)
|
381 |
+
inputs_embeds = torch.cat([inputs_llm, add_feature_llm, inputs_embeds], dim=1)
|
382 |
+
attention_mask = torch.cat([atts_llm, atts_add_feature_llm, llm_tokens['attention_mask']], dim=1)
|
383 |
+
|
384 |
+
outputs = self.llm_model.generate(
|
385 |
+
inputs_embeds=inputs_embeds,
|
386 |
+
attention_mask=attention_mask,
|
387 |
+
do_sample=use_nucleus_sampling,
|
388 |
+
top_p=top_p,
|
389 |
+
temperature=temperature,
|
390 |
+
num_beams=num_beams,
|
391 |
+
max_length=max_length,
|
392 |
+
min_length=min_length,
|
393 |
+
# eos_token_id=self.eos_token_id,
|
394 |
+
repetition_penalty=repetition_penalty,
|
395 |
+
length_penalty=length_penalty,
|
396 |
+
num_return_sequences=num_captions,
|
397 |
+
)
|
398 |
+
|
399 |
+
outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
|
400 |
+
output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
401 |
+
output_text = [text.strip() for text in output_text]
|
402 |
+
|
403 |
+
return output_text
|
404 |
+
|
405 |
+
def predict_answers(
|
406 |
+
self,
|
407 |
+
samples,
|
408 |
+
num_beams=5,
|
409 |
+
inference_method="generate",
|
410 |
+
max_len=10,
|
411 |
+
min_len=1,
|
412 |
+
num_ans_candidates=128,
|
413 |
+
answer_list=None,
|
414 |
+
prompt="",
|
415 |
+
length_penalty=0,
|
416 |
+
**kwargs
|
417 |
+
):
|
418 |
+
if isinstance(samples["text_input"], str):
|
419 |
+
samples["text_input"] = [samples["text_input"]]
|
420 |
+
|
421 |
+
if prompt:
|
422 |
+
if prompt.count("{}") == 2:
|
423 |
+
if 'ocr_tokens' in samples:
|
424 |
+
text_input = [
|
425 |
+
prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["text_input"][i])
|
426 |
+
for i in range(len(samples["text_input"]))]
|
427 |
+
elif 'choices' in samples:
|
428 |
+
text_input = []
|
429 |
+
for i in range(len(samples["text_input"])):
|
430 |
+
this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])]
|
431 |
+
this_choices = " ".join(this_choices)
|
432 |
+
text_input.append(prompt.format(samples["text_input"][i], this_choices))
|
433 |
+
else:
|
434 |
+
text_input = [prompt.format(question) for question in samples["text_input"]]
|
435 |
+
else:
|
436 |
+
text_input = samples["text_input"]
|
437 |
+
|
438 |
+
samples["prompt"] = text_input
|
439 |
+
|
440 |
+
output_text = self.generate(
|
441 |
+
samples,
|
442 |
+
num_beams=num_beams,
|
443 |
+
max_length=max_len,
|
444 |
+
min_length=min_len,
|
445 |
+
length_penalty=length_penalty
|
446 |
+
)
|
447 |
+
|
448 |
+
if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
|
449 |
+
output_text = self._lemmatize(output_text)
|
450 |
+
|
451 |
+
return output_text
|
452 |
+
|
453 |
+
def predict_class(
|
454 |
+
self,
|
455 |
+
samples,
|
456 |
+
candidates,
|
457 |
+
n_segments=1,
|
458 |
+
):
|
459 |
+
self.llm_tokenizer.padding_side = "left"
|
460 |
+
|
461 |
+
# If candidates is a list of lists, each sample has its candidates, then we need to iterate one by one
|
462 |
+
if type(candidates[0]) == list:
|
463 |
+
results = []
|
464 |
+
|
465 |
+
for i in range(samples["image"].size(0)):
|
466 |
+
this_sample = {
|
467 |
+
"image": samples["image"][i].unsqueeze(0),
|
468 |
+
"prompt": samples["prompt"][i],
|
469 |
+
}
|
470 |
+
|
471 |
+
if "text_input" in samples.keys():
|
472 |
+
this_sample["text_input"] = [samples["text_input"][i]]
|
473 |
+
|
474 |
+
if 'context' in samples.keys():
|
475 |
+
this_sample['context'] = [samples["context"][i]]
|
476 |
+
|
477 |
+
if 'history' in samples.keys():
|
478 |
+
this_sample['history'] = [samples["history"][i]]
|
479 |
+
|
480 |
+
if 'caption' in samples.keys():
|
481 |
+
this_sample['caption'] = [samples["caption"][i]]
|
482 |
+
|
483 |
+
this_result = self._predict_class(this_sample, candidates[i], n_segments)
|
484 |
+
results.append(this_result)
|
485 |
+
|
486 |
+
try:
|
487 |
+
results = torch.cat(results, dim=0)
|
488 |
+
except:
|
489 |
+
results = [res.tolist()[0] for res in results]
|
490 |
+
|
491 |
+
return results
|
492 |
+
|
493 |
+
return self._predict_class(samples, candidates, n_segments)
|
494 |
+
|
495 |
+
def _predict_class(
|
496 |
+
self,
|
497 |
+
samples,
|
498 |
+
candidates,
|
499 |
+
n_segments=1,
|
500 |
+
):
|
501 |
+
image = samples["image"]
|
502 |
+
prompt = samples["prompt"]
|
503 |
+
|
504 |
+
bs = image.size(0)
|
505 |
+
|
506 |
+
if isinstance(prompt, str):
|
507 |
+
prompt = [prompt] * bs
|
508 |
+
else:
|
509 |
+
assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
|
510 |
+
|
511 |
+
if "text_input" in samples.keys():
|
512 |
+
if type(samples["text_input"][0]) == list:
|
513 |
+
prompt = [prompt[i].format(*samples["text_input"][i]) for i in range(len(prompt))]
|
514 |
+
else:
|
515 |
+
prompt = [prompt[i].format(samples["text_input"][i]) for i in range(len(prompt))]
|
516 |
+
|
517 |
+
# scienceqa
|
518 |
+
if 'context' in samples.keys() and samples['context'] != '':
|
519 |
+
prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
|
520 |
+
|
521 |
+
# visual dialog
|
522 |
+
if 'history' in samples.keys() and samples['history'][0] != '':
|
523 |
+
prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(len(prompt))]
|
524 |
+
|
525 |
+
if 'caption' in samples.keys() and samples['caption'][0] != '':
|
526 |
+
prompt = [f'This image has the caption "{samples["caption"][i]}". {prompt[i]}' for i in range(len(prompt))]
|
527 |
+
|
528 |
+
query_tokens = self.query_tokens.expand(bs, -1, -1)
|
529 |
+
if self.qformer_text_input:
|
530 |
+
text_Qformer = self.tokenizer(
|
531 |
+
prompt,
|
532 |
+
padding='longest',
|
533 |
+
truncation=True,
|
534 |
+
max_length=self.max_txt_len,
|
535 |
+
return_tensors="pt"
|
536 |
+
).to(image.device)
|
537 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
|
538 |
+
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
|
539 |
+
|
540 |
+
# For video data
|
541 |
+
if image.dim() == 5:
|
542 |
+
inputs_llm, atts_llm = [], []
|
543 |
+
add_inputs_llm, add_atts_llm = [], []
|
544 |
+
for j in range(image.size(2)):
|
545 |
+
this_frame = image[:,:,j,:,:]
|
546 |
+
with self.maybe_autocast():
|
547 |
+
|
548 |
+
frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
|
549 |
+
frame_features =self.visual_encoder.get_intermediate_layers(this_frame)[-2]
|
550 |
+
|
551 |
+
frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
552 |
+
frame_features = frame_features[:, 1:]
|
553 |
+
|
554 |
+
add_feature_llm = self.vision_project(frame_features)
|
555 |
+
atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
|
556 |
+
|
557 |
+
if self.qformer_text_input:
|
558 |
+
frame_query_output = self.Qformer.bert(
|
559 |
+
text_Qformer.input_ids,
|
560 |
+
attention_mask=Qformer_atts,
|
561 |
+
query_embeds=query_tokens,
|
562 |
+
encoder_hidden_states=frame_embeds,
|
563 |
+
encoder_attention_mask=frame_atts,
|
564 |
+
return_dict=True,
|
565 |
+
)
|
566 |
+
else:
|
567 |
+
frame_query_output = self.Qformer.bert(
|
568 |
+
query_embeds=query_tokens,
|
569 |
+
encoder_hidden_states=frame_embeds,
|
570 |
+
encoder_attention_mask=frame_atts,
|
571 |
+
return_dict=True,
|
572 |
+
)
|
573 |
+
frame_inputs_llm = self.llm_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
574 |
+
frame_atts_llm = torch.ones(frame_inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
|
575 |
+
inputs_llm.append(frame_inputs_llm)
|
576 |
+
atts_llm.append(frame_atts_llm)
|
577 |
+
add_inputs_llm.append(add_feature_llm)
|
578 |
+
add_atts_llm.append(atts_add_feature_llm)
|
579 |
+
inputs_llm = torch.cat(inputs_llm, dim=1)
|
580 |
+
atts_llm = torch.cat(atts_llm, dim=1)
|
581 |
+
add_feature_llm = torch.cat(add_inputs_llm, dim=1)
|
582 |
+
atts_add_feature_llm = torch.cat(add_atts_llm, dim=1)
|
583 |
+
else:
|
584 |
+
with self.maybe_autocast():
|
585 |
+
image_embeds = self.ln_vision(self.visual_encoder(image))
|
586 |
+
|
587 |
+
image_features= self.visual_encoder.get_intermediate_layers(image)[-2] # [batch_size, 257, 1408]
|
588 |
+
|
589 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
590 |
+
|
591 |
+
image_features = image_features[:, 1:]
|
592 |
+
add_feature_llm = self.vision_project(image_features)
|
593 |
+
atts_add_feature_llm = torch.ones(add_feature_llm.size()[:-1], dtype=torch.long).to(image.device)
|
594 |
+
|
595 |
+
if self.qformer_text_input:
|
596 |
+
query_output = self.Qformer.bert(
|
597 |
+
text_Qformer.input_ids,
|
598 |
+
attention_mask=Qformer_atts,
|
599 |
+
query_embeds=query_tokens,
|
600 |
+
encoder_hidden_states=image_embeds,
|
601 |
+
encoder_attention_mask=image_atts,
|
602 |
+
return_dict=True,
|
603 |
+
)
|
604 |
+
else:
|
605 |
+
query_output = self.Qformer.bert(
|
606 |
+
query_embeds=query_tokens,
|
607 |
+
encoder_hidden_states=image_embeds,
|
608 |
+
encoder_attention_mask=image_atts,
|
609 |
+
return_dict=True,
|
610 |
+
)
|
611 |
+
|
612 |
+
inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
|
613 |
+
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
|
614 |
+
|
615 |
+
self.llm_tokenizer.padding_side = "right"
|
616 |
+
self.llm_tokenizer.truncation_side = 'left'
|
617 |
+
text_input_tokens = self.llm_tokenizer(
|
618 |
+
prompt,
|
619 |
+
return_tensors="pt",
|
620 |
+
padding="longest",
|
621 |
+
# truncation=True,
|
622 |
+
# max_length=self.max_txt_len,
|
623 |
+
).to(image.device)
|
624 |
+
|
625 |
+
empty_targets = torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
|
626 |
+
empty_add_targets = (
|
627 |
+
torch.ones(atts_add_feature_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
|
628 |
+
)
|
629 |
+
|
630 |
+
# self.llm_tokenizer.padding_side = "right"
|
631 |
+
self.llm_tokenizer.truncation_side = 'right'
|
632 |
+
n_cands = len(candidates)
|
633 |
+
with self.maybe_autocast(dtype=torch.bfloat16):
|
634 |
+
all_losses = []
|
635 |
+
for n in range(n_segments):
|
636 |
+
seg_len = n_cands // n_segments
|
637 |
+
if n == (n_segments - 1):
|
638 |
+
seg_len = n_cands - seg_len * (n_segments - 1)
|
639 |
+
|
640 |
+
start_i = n * (n_cands // n_segments)
|
641 |
+
end_i = start_i + seg_len
|
642 |
+
|
643 |
+
this_output_tokens = self.llm_tokenizer(
|
644 |
+
candidates[start_i:end_i],
|
645 |
+
return_tensors="pt",
|
646 |
+
padding="longest",
|
647 |
+
# truncation=True,
|
648 |
+
# max_length=self.max_output_txt_len,
|
649 |
+
).to(image.device)
|
650 |
+
|
651 |
+
this_input_tokens_ids = text_input_tokens.input_ids.repeat_interleave(seg_len, dim=0)
|
652 |
+
this_input_tokens_atts = text_input_tokens.attention_mask.repeat_interleave(seg_len, dim=0)
|
653 |
+
|
654 |
+
this_output_tokens_ids = this_output_tokens.input_ids.repeat(bs, 1)
|
655 |
+
this_output_tokens_atts = this_output_tokens.attention_mask.repeat(bs, 1)
|
656 |
+
|
657 |
+
this_llm_tokens, this_input_targets_len = self.concat_text_input_output(
|
658 |
+
this_input_tokens_ids,
|
659 |
+
this_input_tokens_atts,
|
660 |
+
this_output_tokens_ids,
|
661 |
+
this_output_tokens_atts
|
662 |
+
)
|
663 |
+
|
664 |
+
this_llm_input_ids = this_llm_tokens['input_ids']
|
665 |
+
this_llm_atts = this_llm_tokens['attention_mask']
|
666 |
+
# this_llm_input_ids = torch.cat([this_input_tokens_ids, this_output_tokens_ids], dim=1)
|
667 |
+
# this_llm_atts = torch.cat([this_input_tokens_atts, this_output_tokens_atts], dim=1)
|
668 |
+
|
669 |
+
inputs_embeds = self.llm_model.get_input_embeddings()(this_llm_input_ids)
|
670 |
+
inputs_embeds = torch.cat([inputs_llm.repeat_interleave(seg_len, dim=0), \
|
671 |
+
add_feature_llm.repeat_interleave(seg_len, dim=0), inputs_embeds], dim=1)
|
672 |
+
attention_mask = torch.cat([atts_llm.repeat_interleave(seg_len, dim=0), \
|
673 |
+
atts_add_feature_llm.repeat_interleave(seg_len, dim=0) ,this_llm_atts], dim=1)
|
674 |
+
|
675 |
+
this_targets = this_llm_input_ids.masked_fill(this_llm_input_ids == self.llm_tokenizer.pad_token_id, -100)
|
676 |
+
# this_targets[:, :this_input_tokens_ids.size(1)] = -100
|
677 |
+
for i, l in enumerate(this_input_targets_len):
|
678 |
+
this_targets[i][:l] = -100
|
679 |
+
|
680 |
+
this_targets = torch.cat([empty_targets.repeat_interleave(seg_len, dim=0), \
|
681 |
+
empty_add_targets.repeat_interleave(seg_len, dim=0) ,this_targets], dim=1)
|
682 |
+
|
683 |
+
outputs = self.llm_model(
|
684 |
+
inputs_embeds=inputs_embeds,
|
685 |
+
attention_mask=attention_mask,
|
686 |
+
return_dict=True,
|
687 |
+
labels=this_targets,
|
688 |
+
reduction="none",
|
689 |
+
)
|
690 |
+
|
691 |
+
loss = outputs.loss
|
692 |
+
|
693 |
+
loss = loss.reshape(bs, seg_len)
|
694 |
+
# output_class_ranks = torch.argsort(loss, dim=-1)
|
695 |
+
all_losses.append(loss)
|
696 |
+
|
697 |
+
all_losses = torch.cat(all_losses, dim=-1)
|
698 |
+
output_class_ranks = torch.argsort(all_losses, dim=-1)
|
699 |
+
|
700 |
+
return output_class_ranks
|
701 |
+
|
702 |
+
def _lemmatize(self, answers):
|
703 |
+
def apply(answer):
|
704 |
+
doc = self.lemmatizer(answer)
|
705 |
+
|
706 |
+
words = []
|
707 |
+
for token in doc:
|
708 |
+
if token.pos_ in ["NOUN", "VERB"]:
|
709 |
+
words.append(token.lemma_)
|
710 |
+
else:
|
711 |
+
words.append(token.text)
|
712 |
+
answer = " ".join(words)
|
713 |
+
|
714 |
+
return answer
|
715 |
+
|
716 |
+
return [apply(answer) for answer in answers]
|
717 |
+
|
718 |
+
@property
|
719 |
+
def lemmatizer(self):
|
720 |
+
if self._lemmatizer is None:
|
721 |
+
try:
|
722 |
+
import spacy
|
723 |
+
|
724 |
+
self._lemmatizer = spacy.load("en_core_web_sm")
|
725 |
+
except ImportError:
|
726 |
+
logging.error(
|
727 |
+
"""
|
728 |
+
Please install spacy and en_core_web_sm model to apply lemmatization.
|
729 |
+
python -m spacy download en_core_web_sm
|
730 |
+
OR
|
731 |
+
import spacy.cli
|
732 |
+
spacy.cli.download("en_core_web_sm")
|
733 |
+
"""
|
734 |
+
)
|
735 |
+
exit(1)
|
736 |
+
|
737 |
+
return self._lemmatizer
|
738 |
+
|
739 |
+
@classmethod
|
740 |
+
def from_config(cls, cfg):
|
741 |
+
vit_model = cfg.get("vit_model", "eva_clip_g")
|
742 |
+
img_size = cfg.get("image_size")
|
743 |
+
num_query_token = cfg.get("num_query_token")
|
744 |
+
llm_model = cfg.get("llm_model")
|
745 |
+
|
746 |
+
drop_path_rate = cfg.get("drop_path_rate", 0)
|
747 |
+
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
|
748 |
+
vit_precision = cfg.get("vit_precision", "fp16")
|
749 |
+
freeze_vit = cfg.get("freeze_vit", True)
|
750 |
+
|
751 |
+
prompt = cfg.get("prompt", "")
|
752 |
+
max_txt_len = cfg.get("max_txt_len", 128)
|
753 |
+
max_output_txt_len = cfg.get("max_output_txt_len", 256)
|
754 |
+
|
755 |
+
apply_lemmatizer = cfg.get("apply_lemmatizer", False)
|
756 |
+
|
757 |
+
qformer_text_input = cfg.get("qformer_text_input", True)
|
758 |
+
|
759 |
+
model = cls(
|
760 |
+
vit_model=vit_model,
|
761 |
+
img_size=img_size,
|
762 |
+
drop_path_rate=drop_path_rate,
|
763 |
+
use_grad_checkpoint=use_grad_checkpoint,
|
764 |
+
vit_precision=vit_precision,
|
765 |
+
freeze_vit=freeze_vit,
|
766 |
+
num_query_token=num_query_token,
|
767 |
+
llm_model=llm_model,
|
768 |
+
prompt=prompt,
|
769 |
+
max_txt_len=max_txt_len,
|
770 |
+
max_output_txt_len=max_output_txt_len,
|
771 |
+
apply_lemmatizer=apply_lemmatizer,
|
772 |
+
qformer_text_input=qformer_text_input,
|
773 |
+
)
|
774 |
+
|
775 |
+
# if qformer_text_input:
|
776 |
+
# # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal)
|
777 |
+
# model.load_from_pretrained(
|
778 |
+
# url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth"
|
779 |
+
# )
|
780 |
+
|
781 |
+
model.load_checkpoint_from_config(cfg)
|
782 |
+
|
783 |
+
return model
|
bliva/models/clip_vit.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from itertools import repeat
|
3 |
+
import collections.abc
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
11 |
+
|
12 |
+
from bliva.models.eva_vit import convert_weights_to_fp16
|
13 |
+
from bliva.common.dist_utils import download_cached_file
|
14 |
+
|
15 |
+
class Bottleneck(nn.Module):
|
16 |
+
expansion = 4
|
17 |
+
|
18 |
+
def __init__(self, inplanes, planes, stride=1):
|
19 |
+
super().__init__()
|
20 |
+
|
21 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
22 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
23 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
24 |
+
self.relu1 = nn.ReLU(inplace=True)
|
25 |
+
|
26 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
27 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
28 |
+
self.relu2 = nn.ReLU(inplace=True)
|
29 |
+
|
30 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
31 |
+
|
32 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
33 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
34 |
+
self.relu3 = nn.ReLU(inplace=True)
|
35 |
+
|
36 |
+
self.downsample = None
|
37 |
+
self.stride = stride
|
38 |
+
|
39 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
40 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
41 |
+
self.downsample = nn.Sequential(OrderedDict([
|
42 |
+
("-1", nn.AvgPool2d(stride)),
|
43 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
44 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
45 |
+
]))
|
46 |
+
|
47 |
+
def forward(self, x: torch.Tensor):
|
48 |
+
identity = x
|
49 |
+
|
50 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
51 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
52 |
+
out = self.avgpool(out)
|
53 |
+
out = self.bn3(self.conv3(out))
|
54 |
+
|
55 |
+
if self.downsample is not None:
|
56 |
+
identity = self.downsample(x)
|
57 |
+
|
58 |
+
out += identity
|
59 |
+
out = self.relu3(out)
|
60 |
+
return out
|
61 |
+
|
62 |
+
|
63 |
+
class AttentionPool2d(nn.Module):
|
64 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
65 |
+
super().__init__()
|
66 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
67 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
68 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
69 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
70 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
71 |
+
self.num_heads = num_heads
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
75 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
76 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
77 |
+
x, _ = F.multi_head_attention_forward(
|
78 |
+
query=x, key=x, value=x,
|
79 |
+
embed_dim_to_check=x.shape[-1],
|
80 |
+
num_heads=self.num_heads,
|
81 |
+
q_proj_weight=self.q_proj.weight,
|
82 |
+
k_proj_weight=self.k_proj.weight,
|
83 |
+
v_proj_weight=self.v_proj.weight,
|
84 |
+
in_proj_weight=None,
|
85 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
86 |
+
bias_k=None,
|
87 |
+
bias_v=None,
|
88 |
+
add_zero_attn=False,
|
89 |
+
dropout_p=0,
|
90 |
+
out_proj_weight=self.c_proj.weight,
|
91 |
+
out_proj_bias=self.c_proj.bias,
|
92 |
+
use_separate_proj_weight=True,
|
93 |
+
training=self.training,
|
94 |
+
need_weights=False
|
95 |
+
)
|
96 |
+
|
97 |
+
return x[0]
|
98 |
+
|
99 |
+
|
100 |
+
class LayerNorm(nn.LayerNorm):
|
101 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
102 |
+
|
103 |
+
def forward(self, x: torch.Tensor):
|
104 |
+
orig_type = x.dtype
|
105 |
+
ret = super().forward(x.type(torch.float32))
|
106 |
+
return ret.type(orig_type)
|
107 |
+
|
108 |
+
|
109 |
+
class QuickGELU(nn.Module):
|
110 |
+
def forward(self, x: torch.Tensor):
|
111 |
+
return x * torch.sigmoid(1.702 * x)
|
112 |
+
|
113 |
+
|
114 |
+
class ResidualAttentionBlock(nn.Module):
|
115 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
|
116 |
+
super().__init__()
|
117 |
+
|
118 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
119 |
+
self.ln_1 = LayerNorm(d_model)
|
120 |
+
self.mlp = nn.Sequential(OrderedDict([
|
121 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
122 |
+
("gelu", QuickGELU()),
|
123 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
124 |
+
]))
|
125 |
+
self.ln_2 = LayerNorm(d_model)
|
126 |
+
self.attn_mask = attn_mask
|
127 |
+
|
128 |
+
if use_grad_checkpointing:
|
129 |
+
self.attn = checkpoint_wrapper(self.attn)
|
130 |
+
self.mlp = checkpoint_wrapper(self.mlp)
|
131 |
+
|
132 |
+
def attention(self, x: torch.Tensor):
|
133 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
134 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
135 |
+
|
136 |
+
def forward(self, x: torch.Tensor):
|
137 |
+
x = x + self.attention(self.ln_1(x))
|
138 |
+
x = x + self.mlp(self.ln_2(x))
|
139 |
+
return x
|
140 |
+
|
141 |
+
|
142 |
+
class Transformer(nn.Module):
|
143 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
|
144 |
+
super().__init__()
|
145 |
+
self.width = width
|
146 |
+
self.layers = layers
|
147 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)])
|
148 |
+
|
149 |
+
def forward(self, x: torch.Tensor):
|
150 |
+
return self.resblocks(x)
|
151 |
+
|
152 |
+
def get_second_last_feature(self, x: torch.Tensor):
|
153 |
+
for i in range(len(self.resblocks) - 2):
|
154 |
+
x = self.resblocks[i](x)
|
155 |
+
return x
|
156 |
+
|
157 |
+
|
158 |
+
class VisionTransformer(nn.Module):
|
159 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool):
|
160 |
+
super().__init__()
|
161 |
+
self.input_resolution = input_resolution
|
162 |
+
self.num_features = width
|
163 |
+
self.num_heads = heads
|
164 |
+
self.num_patches = (input_resolution // patch_size) ** 2
|
165 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
166 |
+
|
167 |
+
scale = width ** -0.5
|
168 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
169 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
|
170 |
+
self.ln_pre = LayerNorm(width)
|
171 |
+
|
172 |
+
self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
|
173 |
+
|
174 |
+
# self.ln_final = LayerNorm(width)
|
175 |
+
|
176 |
+
def forward(self, x: torch.Tensor):
|
177 |
+
|
178 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
179 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
180 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
181 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
182 |
+
x = x + self.positional_embedding.to(x.dtype)
|
183 |
+
x = self.ln_pre(x)
|
184 |
+
|
185 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
186 |
+
x = self.transformer(x)
|
187 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
188 |
+
|
189 |
+
# x = self.ln_final(x)
|
190 |
+
return x
|
191 |
+
|
192 |
+
def get_last_second_feature(self, x):
|
193 |
+
|
194 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
195 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
196 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
197 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
198 |
+
x = x + self.positional_embedding.to(x.dtype)
|
199 |
+
x = self.ln_pre(x)
|
200 |
+
|
201 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
202 |
+
x = self.transformer.get_second_last_feature(x)
|
203 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
204 |
+
|
205 |
+
return x
|
206 |
+
|
207 |
+
# From PyTorch internals
|
208 |
+
def _ntuple(n):
|
209 |
+
def parse(x):
|
210 |
+
if isinstance(x, collections.abc.Iterable):
|
211 |
+
return x
|
212 |
+
return tuple(repeat(x, n))
|
213 |
+
return parse
|
214 |
+
to_2tuple = _ntuple(2)
|
215 |
+
def interpolate_pos_embed(model, state_dict, interpolation: str = 'bicubic', seq_dim=1):
|
216 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
217 |
+
old_pos_embed = state_dict.get('positional_embedding', None)
|
218 |
+
|
219 |
+
grid_size = round((model.positional_embedding.shape[0] - 1) ** 0.5)
|
220 |
+
if old_pos_embed is None:
|
221 |
+
return
|
222 |
+
grid_size = to_2tuple(grid_size)
|
223 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
224 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
225 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
226 |
+
return
|
227 |
+
|
228 |
+
if extra_tokens:
|
229 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
230 |
+
else:
|
231 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
232 |
+
|
233 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
234 |
+
|
235 |
+
print('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
236 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
237 |
+
pos_emb_img = F.interpolate(
|
238 |
+
pos_emb_img,
|
239 |
+
size=grid_size,
|
240 |
+
mode=interpolation,
|
241 |
+
align_corners=True,
|
242 |
+
)
|
243 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
244 |
+
if pos_emb_tok is not None:
|
245 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
246 |
+
else:
|
247 |
+
new_pos_embed = pos_emb_img
|
248 |
+
state_dict['positional_embedding'] = new_pos_embed
|
249 |
+
|
250 |
+
|
251 |
+
def create_clip_vit_L(img_size=224,use_checkpoint=False,precision="fp16"):
|
252 |
+
model = VisionTransformer(
|
253 |
+
input_resolution=img_size,
|
254 |
+
patch_size=14,
|
255 |
+
width=1024,
|
256 |
+
layers=23,
|
257 |
+
heads=16,
|
258 |
+
use_grad_checkpointing=use_checkpoint,
|
259 |
+
)
|
260 |
+
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/clip_vit_L.pth"
|
261 |
+
cached_file = download_cached_file(
|
262 |
+
url, check_hash=False, progress=True
|
263 |
+
)
|
264 |
+
state_dict = torch.load(cached_file, map_location="cpu")
|
265 |
+
interpolate_pos_embed(model,state_dict)
|
266 |
+
|
267 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
268 |
+
# print(incompatible_keys)
|
269 |
+
|
270 |
+
if precision == "fp16":
|
271 |
+
convert_weights_to_fp16(model)
|
272 |
+
return model
|
bliva/models/eva_vit.py
ADDED
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on EVA, BEIT, timm and DeiT code bases
|
2 |
+
# https://github.com/baaivision/EVA
|
3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
4 |
+
# https://github.com/microsoft/unilm/tree/master/beit
|
5 |
+
# https://github.com/facebookresearch/deit/
|
6 |
+
# https://github.com/facebookresearch/dino
|
7 |
+
# --------------------------------------------------------'
|
8 |
+
import math
|
9 |
+
from functools import partial
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import torch.utils.checkpoint as checkpoint
|
15 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
16 |
+
from timm.models.registry import register_model
|
17 |
+
|
18 |
+
from bliva.common.dist_utils import download_cached_file
|
19 |
+
|
20 |
+
def _cfg(url='', **kwargs):
|
21 |
+
return {
|
22 |
+
'url': url,
|
23 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
24 |
+
'crop_pct': .9, 'interpolation': 'bicubic',
|
25 |
+
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
26 |
+
**kwargs
|
27 |
+
}
|
28 |
+
|
29 |
+
|
30 |
+
class DropPath(nn.Module):
|
31 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
32 |
+
"""
|
33 |
+
def __init__(self, drop_prob=None):
|
34 |
+
super(DropPath, self).__init__()
|
35 |
+
self.drop_prob = drop_prob
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return drop_path(x, self.drop_prob, self.training)
|
39 |
+
|
40 |
+
def extra_repr(self) -> str:
|
41 |
+
return 'p={}'.format(self.drop_prob)
|
42 |
+
|
43 |
+
|
44 |
+
class Mlp(nn.Module):
|
45 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
46 |
+
super().__init__()
|
47 |
+
out_features = out_features or in_features
|
48 |
+
hidden_features = hidden_features or in_features
|
49 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
50 |
+
self.act = act_layer()
|
51 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
52 |
+
self.drop = nn.Dropout(drop)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
x = self.fc1(x)
|
56 |
+
x = self.act(x)
|
57 |
+
# x = self.drop(x)
|
58 |
+
# commit this for the orignal BERT implement
|
59 |
+
x = self.fc2(x)
|
60 |
+
x = self.drop(x)
|
61 |
+
return x
|
62 |
+
|
63 |
+
|
64 |
+
class Attention(nn.Module):
|
65 |
+
def __init__(
|
66 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
67 |
+
proj_drop=0., window_size=None, attn_head_dim=None):
|
68 |
+
super().__init__()
|
69 |
+
self.num_heads = num_heads
|
70 |
+
head_dim = dim // num_heads
|
71 |
+
if attn_head_dim is not None:
|
72 |
+
head_dim = attn_head_dim
|
73 |
+
all_head_dim = head_dim * self.num_heads
|
74 |
+
self.scale = qk_scale or head_dim ** -0.5
|
75 |
+
|
76 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
77 |
+
if qkv_bias:
|
78 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
79 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
80 |
+
else:
|
81 |
+
self.q_bias = None
|
82 |
+
self.v_bias = None
|
83 |
+
|
84 |
+
if window_size:
|
85 |
+
self.window_size = window_size
|
86 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
87 |
+
self.relative_position_bias_table = nn.Parameter(
|
88 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
89 |
+
# cls to token & token 2 cls & cls to cls
|
90 |
+
|
91 |
+
# get pair-wise relative position index for each token inside the window
|
92 |
+
coords_h = torch.arange(window_size[0])
|
93 |
+
coords_w = torch.arange(window_size[1])
|
94 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
95 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
96 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
97 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
98 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
99 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
100 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
101 |
+
relative_position_index = \
|
102 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
|
103 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
104 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
105 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
106 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
107 |
+
|
108 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
109 |
+
else:
|
110 |
+
self.window_size = None
|
111 |
+
self.relative_position_bias_table = None
|
112 |
+
self.relative_position_index = None
|
113 |
+
|
114 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
115 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
116 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
117 |
+
|
118 |
+
def forward(self, x, rel_pos_bias=None):
|
119 |
+
B, N, C = x.shape
|
120 |
+
qkv_bias = None
|
121 |
+
if self.q_bias is not None:
|
122 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
123 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
124 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
125 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
126 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
127 |
+
|
128 |
+
q = q * self.scale
|
129 |
+
attn = (q @ k.transpose(-2, -1))
|
130 |
+
|
131 |
+
if self.relative_position_bias_table is not None:
|
132 |
+
relative_position_bias = \
|
133 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
134 |
+
self.window_size[0] * self.window_size[1] + 1,
|
135 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
136 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
137 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
138 |
+
|
139 |
+
if rel_pos_bias is not None:
|
140 |
+
attn = attn + rel_pos_bias
|
141 |
+
|
142 |
+
attn = attn.softmax(dim=-1)
|
143 |
+
attn = self.attn_drop(attn)
|
144 |
+
|
145 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
146 |
+
x = self.proj(x)
|
147 |
+
x = self.proj_drop(x)
|
148 |
+
return x
|
149 |
+
|
150 |
+
|
151 |
+
class Block(nn.Module):
|
152 |
+
|
153 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
154 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
155 |
+
window_size=None, attn_head_dim=None):
|
156 |
+
super().__init__()
|
157 |
+
self.norm1 = norm_layer(dim)
|
158 |
+
self.attn = Attention(
|
159 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
160 |
+
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
|
161 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
162 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
163 |
+
self.norm2 = norm_layer(dim)
|
164 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
165 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
166 |
+
|
167 |
+
if init_values is not None and init_values > 0:
|
168 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
169 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
170 |
+
else:
|
171 |
+
self.gamma_1, self.gamma_2 = None, None
|
172 |
+
|
173 |
+
def forward(self, x, rel_pos_bias=None):
|
174 |
+
if self.gamma_1 is None:
|
175 |
+
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
176 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
177 |
+
else:
|
178 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
179 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
180 |
+
return x
|
181 |
+
|
182 |
+
|
183 |
+
class PatchEmbed(nn.Module):
|
184 |
+
""" Image to Patch Embedding
|
185 |
+
"""
|
186 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
187 |
+
super().__init__()
|
188 |
+
img_size = to_2tuple(img_size)
|
189 |
+
patch_size = to_2tuple(patch_size)
|
190 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
191 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
192 |
+
self.img_size = img_size
|
193 |
+
self.patch_size = patch_size
|
194 |
+
self.num_patches = num_patches
|
195 |
+
|
196 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
197 |
+
|
198 |
+
def forward(self, x, **kwargs):
|
199 |
+
B, C, H, W = x.shape
|
200 |
+
# FIXME look at relaxing size constraints
|
201 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
202 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
203 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
204 |
+
return x
|
205 |
+
|
206 |
+
|
207 |
+
class RelativePositionBias(nn.Module):
|
208 |
+
|
209 |
+
def __init__(self, window_size, num_heads):
|
210 |
+
super().__init__()
|
211 |
+
self.window_size = window_size
|
212 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
213 |
+
self.relative_position_bias_table = nn.Parameter(
|
214 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
215 |
+
# cls to token & token 2 cls & cls to cls
|
216 |
+
|
217 |
+
# get pair-wise relative position index for each token inside the window
|
218 |
+
coords_h = torch.arange(window_size[0])
|
219 |
+
coords_w = torch.arange(window_size[1])
|
220 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
221 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
222 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
223 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
224 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
225 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
226 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
227 |
+
relative_position_index = \
|
228 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
229 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
230 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
231 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
232 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
233 |
+
|
234 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
235 |
+
|
236 |
+
# trunc_normal_(self.relative_position_bias_table, std=.02)
|
237 |
+
|
238 |
+
def forward(self):
|
239 |
+
relative_position_bias = \
|
240 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
241 |
+
self.window_size[0] * self.window_size[1] + 1,
|
242 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
243 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
244 |
+
|
245 |
+
|
246 |
+
class VisionTransformer(nn.Module):
|
247 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
248 |
+
"""
|
249 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
250 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
251 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
|
252 |
+
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
|
253 |
+
use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
|
254 |
+
super().__init__()
|
255 |
+
self.image_size = img_size
|
256 |
+
self.num_classes = num_classes
|
257 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
258 |
+
|
259 |
+
self.patch_embed = PatchEmbed(
|
260 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
261 |
+
num_patches = self.patch_embed.num_patches
|
262 |
+
|
263 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
264 |
+
if use_abs_pos_emb:
|
265 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
266 |
+
else:
|
267 |
+
self.pos_embed = None
|
268 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
269 |
+
|
270 |
+
if use_shared_rel_pos_bias:
|
271 |
+
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
272 |
+
else:
|
273 |
+
self.rel_pos_bias = None
|
274 |
+
self.use_checkpoint = use_checkpoint
|
275 |
+
|
276 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
277 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
278 |
+
self.blocks = nn.ModuleList([
|
279 |
+
Block(
|
280 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
281 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
282 |
+
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
|
283 |
+
for i in range(depth)])
|
284 |
+
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
285 |
+
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
286 |
+
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
287 |
+
|
288 |
+
if self.pos_embed is not None:
|
289 |
+
trunc_normal_(self.pos_embed, std=.02)
|
290 |
+
trunc_normal_(self.cls_token, std=.02)
|
291 |
+
# trunc_normal_(self.mask_token, std=.02)
|
292 |
+
# if isinstance(self.head, nn.Linear):
|
293 |
+
# trunc_normal_(self.head.weight, std=.02)
|
294 |
+
self.apply(self._init_weights)
|
295 |
+
self.fix_init_weight()
|
296 |
+
# if isinstance(self.head, nn.Linear):
|
297 |
+
# self.head.weight.data.mul_(init_scale)
|
298 |
+
# self.head.bias.data.mul_(init_scale)
|
299 |
+
|
300 |
+
def fix_init_weight(self):
|
301 |
+
def rescale(param, layer_id):
|
302 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
303 |
+
|
304 |
+
for layer_id, layer in enumerate(self.blocks):
|
305 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
306 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
307 |
+
|
308 |
+
def _init_weights(self, m):
|
309 |
+
if isinstance(m, nn.Linear):
|
310 |
+
trunc_normal_(m.weight, std=.02)
|
311 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
312 |
+
nn.init.constant_(m.bias, 0)
|
313 |
+
elif isinstance(m, nn.LayerNorm):
|
314 |
+
nn.init.constant_(m.bias, 0)
|
315 |
+
nn.init.constant_(m.weight, 1.0)
|
316 |
+
|
317 |
+
def get_classifier(self):
|
318 |
+
return self.head
|
319 |
+
|
320 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
321 |
+
self.num_classes = num_classes
|
322 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
323 |
+
|
324 |
+
def forward_features(self, x):
|
325 |
+
x = self.patch_embed(x)
|
326 |
+
batch_size, seq_len, _ = x.size()
|
327 |
+
|
328 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
329 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
330 |
+
if self.pos_embed is not None:
|
331 |
+
x = x + self.pos_embed
|
332 |
+
x = self.pos_drop(x)
|
333 |
+
|
334 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
335 |
+
for blk in self.blocks:
|
336 |
+
if self.use_checkpoint:
|
337 |
+
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
|
338 |
+
else:
|
339 |
+
x = blk(x, rel_pos_bias)
|
340 |
+
return x
|
341 |
+
# x = self.norm(x)
|
342 |
+
|
343 |
+
# if self.fc_norm is not None:
|
344 |
+
# t = x[:, 1:, :]
|
345 |
+
# return self.fc_norm(t.mean(1))
|
346 |
+
# else:
|
347 |
+
# return x[:, 0]
|
348 |
+
|
349 |
+
def forward(self, x):
|
350 |
+
x = self.forward_features(x)
|
351 |
+
# x = self.head(x)
|
352 |
+
return x
|
353 |
+
|
354 |
+
def get_intermediate_layers(self, x):
|
355 |
+
x = self.patch_embed(x)
|
356 |
+
batch_size, seq_len, _ = x.size()
|
357 |
+
|
358 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
359 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
360 |
+
if self.pos_embed is not None:
|
361 |
+
x = x + self.pos_embed
|
362 |
+
x = self.pos_drop(x)
|
363 |
+
|
364 |
+
features = []
|
365 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
366 |
+
for blk in self.blocks:
|
367 |
+
x = blk(x, rel_pos_bias)
|
368 |
+
features.append(x)
|
369 |
+
|
370 |
+
return features
|
371 |
+
|
372 |
+
|
373 |
+
def interpolate_pos_embed(model, checkpoint_model):
|
374 |
+
if 'pos_embed' in checkpoint_model:
|
375 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
|
376 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
377 |
+
num_patches = model.patch_embed.num_patches
|
378 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
379 |
+
# height (== width) for the checkpoint position embedding
|
380 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
381 |
+
# height (== width) for the new position embedding
|
382 |
+
new_size = int(num_patches ** 0.5)
|
383 |
+
# class_token and dist_token are kept unchanged
|
384 |
+
if orig_size != new_size:
|
385 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
386 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
387 |
+
# only the position tokens are interpolated
|
388 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
389 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
390 |
+
pos_tokens = torch.nn.functional.interpolate(
|
391 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
392 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
393 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
394 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
395 |
+
|
396 |
+
|
397 |
+
def convert_weights_to_fp16(model: nn.Module):
|
398 |
+
"""Convert applicable model parameters to fp16"""
|
399 |
+
|
400 |
+
def _convert_weights_to_fp16(l):
|
401 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
402 |
+
l.weight.data = l.weight.data.half()
|
403 |
+
if l.bias is not None:
|
404 |
+
l.bias.data = l.bias.data.half()
|
405 |
+
|
406 |
+
# if isinstance(l, (nn.MultiheadAttention, Attention)):
|
407 |
+
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
408 |
+
# tensor = getattr(l, attr)
|
409 |
+
# if tensor is not None:
|
410 |
+
# tensor.data = tensor.data.half()
|
411 |
+
|
412 |
+
model.apply(_convert_weights_to_fp16)
|
413 |
+
|
414 |
+
|
415 |
+
def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
|
416 |
+
model = VisionTransformer(
|
417 |
+
img_size=img_size,
|
418 |
+
patch_size=14,
|
419 |
+
use_mean_pooling=False,
|
420 |
+
embed_dim=1408,
|
421 |
+
depth=39,
|
422 |
+
num_heads=1408//88,
|
423 |
+
mlp_ratio=4.3637,
|
424 |
+
qkv_bias=True,
|
425 |
+
drop_path_rate=drop_path_rate,
|
426 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
427 |
+
use_checkpoint=use_checkpoint,
|
428 |
+
)
|
429 |
+
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
|
430 |
+
cached_file = download_cached_file(
|
431 |
+
url, check_hash=False, progress=True
|
432 |
+
)
|
433 |
+
state_dict = torch.load(cached_file, map_location="cpu")
|
434 |
+
interpolate_pos_embed(model,state_dict)
|
435 |
+
|
436 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
437 |
+
# print(incompatible_keys)
|
438 |
+
|
439 |
+
if precision == "fp16":
|
440 |
+
# model.to("cuda")
|
441 |
+
convert_weights_to_fp16(model)
|
442 |
+
return model
|
bliva/models/modeling_llama.py
ADDED
@@ -0,0 +1,888 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
""" PyTorch LLaMA model."""
|
21 |
+
import math
|
22 |
+
from typing import List, Optional, Tuple, Union
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torch.utils.checkpoint
|
26 |
+
from torch import nn
|
27 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
28 |
+
|
29 |
+
from transformers.activations import ACT2FN
|
30 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
31 |
+
from transformers.modeling_utils import PreTrainedModel
|
32 |
+
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
33 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
34 |
+
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__)
|
37 |
+
|
38 |
+
_CONFIG_FOR_DOC = "LlamaConfig"
|
39 |
+
|
40 |
+
|
41 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
42 |
+
def _make_causal_mask(
|
43 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
Make causal mask used for bi-directional self-attention.
|
47 |
+
"""
|
48 |
+
bsz, tgt_len = input_ids_shape
|
49 |
+
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
|
50 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
51 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
52 |
+
mask = mask.to(dtype)
|
53 |
+
|
54 |
+
if past_key_values_length > 0:
|
55 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
56 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
57 |
+
|
58 |
+
|
59 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
60 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
61 |
+
"""
|
62 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
63 |
+
"""
|
64 |
+
bsz, src_len = mask.size()
|
65 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
66 |
+
|
67 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
68 |
+
|
69 |
+
inverted_mask = 1.0 - expanded_mask
|
70 |
+
|
71 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
72 |
+
|
73 |
+
|
74 |
+
class LlamaRMSNorm(nn.Module):
|
75 |
+
def __init__(self, hidden_size, eps=1e-6):
|
76 |
+
"""
|
77 |
+
LlamaRMSNorm is equivalent to T5LayerNorm
|
78 |
+
"""
|
79 |
+
super().__init__()
|
80 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
81 |
+
self.variance_epsilon = eps
|
82 |
+
|
83 |
+
def forward(self, hidden_states):
|
84 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
85 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
86 |
+
|
87 |
+
# convert into half-precision if necessary
|
88 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
89 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
90 |
+
|
91 |
+
return self.weight * hidden_states
|
92 |
+
|
93 |
+
|
94 |
+
class LlamaRotaryEmbedding(torch.nn.Module):
|
95 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
96 |
+
super().__init__()
|
97 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
98 |
+
self.register_buffer("inv_freq", inv_freq)
|
99 |
+
|
100 |
+
# Build here to make `torch.jit.trace` work.
|
101 |
+
self.max_seq_len_cached = max_position_embeddings
|
102 |
+
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
103 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
104 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
105 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
106 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
107 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
108 |
+
|
109 |
+
def forward(self, x, seq_len=None):
|
110 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
111 |
+
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
112 |
+
if seq_len > self.max_seq_len_cached:
|
113 |
+
self.max_seq_len_cached = seq_len
|
114 |
+
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
115 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
116 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
117 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
118 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
119 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
120 |
+
return (
|
121 |
+
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
122 |
+
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
123 |
+
)
|
124 |
+
|
125 |
+
|
126 |
+
def rotate_half(x):
|
127 |
+
"""Rotates half the hidden dims of the input."""
|
128 |
+
x1 = x[..., : x.shape[-1] // 2]
|
129 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
130 |
+
return torch.cat((-x2, x1), dim=-1)
|
131 |
+
|
132 |
+
|
133 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
134 |
+
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
135 |
+
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
136 |
+
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
137 |
+
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
138 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
139 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
140 |
+
return q_embed, k_embed
|
141 |
+
|
142 |
+
|
143 |
+
class LlamaMLP(nn.Module):
|
144 |
+
def __init__(
|
145 |
+
self,
|
146 |
+
hidden_size: int,
|
147 |
+
intermediate_size: int,
|
148 |
+
hidden_act: str,
|
149 |
+
):
|
150 |
+
super().__init__()
|
151 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
152 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
153 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
154 |
+
self.act_fn = ACT2FN[hidden_act]
|
155 |
+
|
156 |
+
def forward(self, x):
|
157 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
158 |
+
|
159 |
+
|
160 |
+
class LlamaAttention(nn.Module):
|
161 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
162 |
+
|
163 |
+
def __init__(self, config: LlamaConfig):
|
164 |
+
super().__init__()
|
165 |
+
self.config = config
|
166 |
+
self.hidden_size = config.hidden_size
|
167 |
+
self.num_heads = config.num_attention_heads
|
168 |
+
self.head_dim = self.hidden_size // self.num_heads
|
169 |
+
self.max_position_embeddings = config.max_position_embeddings
|
170 |
+
|
171 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
172 |
+
raise ValueError(
|
173 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
174 |
+
f" and `num_heads`: {self.num_heads})."
|
175 |
+
)
|
176 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
177 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
178 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
179 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
180 |
+
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
181 |
+
|
182 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
183 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
184 |
+
|
185 |
+
def forward(
|
186 |
+
self,
|
187 |
+
hidden_states: torch.Tensor,
|
188 |
+
attention_mask: Optional[torch.Tensor] = None,
|
189 |
+
position_ids: Optional[torch.LongTensor] = None,
|
190 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
191 |
+
output_attentions: bool = False,
|
192 |
+
use_cache: bool = False,
|
193 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
194 |
+
bsz, q_len, _ = hidden_states.size()
|
195 |
+
|
196 |
+
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
197 |
+
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
198 |
+
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
199 |
+
|
200 |
+
kv_seq_len = key_states.shape[-2]
|
201 |
+
if past_key_value is not None:
|
202 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
203 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
204 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
205 |
+
# [bsz, nh, t, hd]
|
206 |
+
|
207 |
+
if past_key_value is not None:
|
208 |
+
# reuse k, v, self_attention
|
209 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
210 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
211 |
+
|
212 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
213 |
+
|
214 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
215 |
+
|
216 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
217 |
+
raise ValueError(
|
218 |
+
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
219 |
+
f" {attn_weights.size()}"
|
220 |
+
)
|
221 |
+
|
222 |
+
if attention_mask is not None:
|
223 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
224 |
+
raise ValueError(
|
225 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
226 |
+
)
|
227 |
+
attn_weights = attn_weights + attention_mask
|
228 |
+
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
229 |
+
|
230 |
+
# upcast attention to fp32
|
231 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
232 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
233 |
+
|
234 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
235 |
+
raise ValueError(
|
236 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
237 |
+
f" {attn_output.size()}"
|
238 |
+
)
|
239 |
+
|
240 |
+
attn_output = attn_output.transpose(1, 2)
|
241 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
242 |
+
|
243 |
+
attn_output = self.o_proj(attn_output)
|
244 |
+
|
245 |
+
if not output_attentions:
|
246 |
+
attn_weights = None
|
247 |
+
|
248 |
+
return attn_output, attn_weights, past_key_value
|
249 |
+
|
250 |
+
|
251 |
+
class LlamaDecoderLayer(nn.Module):
|
252 |
+
def __init__(self, config: LlamaConfig):
|
253 |
+
super().__init__()
|
254 |
+
self.hidden_size = config.hidden_size
|
255 |
+
self.self_attn = LlamaAttention(config=config)
|
256 |
+
self.mlp = LlamaMLP(
|
257 |
+
hidden_size=self.hidden_size,
|
258 |
+
intermediate_size=config.intermediate_size,
|
259 |
+
hidden_act=config.hidden_act,
|
260 |
+
)
|
261 |
+
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
262 |
+
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
263 |
+
|
264 |
+
def forward(
|
265 |
+
self,
|
266 |
+
hidden_states: torch.Tensor,
|
267 |
+
attention_mask: Optional[torch.Tensor] = None,
|
268 |
+
position_ids: Optional[torch.LongTensor] = None,
|
269 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
270 |
+
output_attentions: Optional[bool] = False,
|
271 |
+
use_cache: Optional[bool] = False,
|
272 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
273 |
+
"""
|
274 |
+
Args:
|
275 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
276 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
277 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
278 |
+
output_attentions (`bool`, *optional*):
|
279 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
280 |
+
returned tensors for more detail.
|
281 |
+
use_cache (`bool`, *optional*):
|
282 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
283 |
+
(see `past_key_values`).
|
284 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
285 |
+
"""
|
286 |
+
|
287 |
+
residual = hidden_states
|
288 |
+
|
289 |
+
hidden_states = self.input_layernorm(hidden_states)
|
290 |
+
|
291 |
+
# Self Attention
|
292 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
293 |
+
hidden_states=hidden_states,
|
294 |
+
attention_mask=attention_mask,
|
295 |
+
position_ids=position_ids,
|
296 |
+
past_key_value=past_key_value,
|
297 |
+
output_attentions=output_attentions,
|
298 |
+
use_cache=use_cache,
|
299 |
+
)
|
300 |
+
hidden_states = residual + hidden_states
|
301 |
+
|
302 |
+
# Fully Connected
|
303 |
+
residual = hidden_states
|
304 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
305 |
+
hidden_states = self.mlp(hidden_states)
|
306 |
+
hidden_states = residual + hidden_states
|
307 |
+
|
308 |
+
outputs = (hidden_states,)
|
309 |
+
|
310 |
+
if output_attentions:
|
311 |
+
outputs += (self_attn_weights,)
|
312 |
+
|
313 |
+
if use_cache:
|
314 |
+
outputs += (present_key_value,)
|
315 |
+
|
316 |
+
return outputs
|
317 |
+
|
318 |
+
|
319 |
+
LLAMA_START_DOCSTRING = r"""
|
320 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
321 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
322 |
+
etc.)
|
323 |
+
|
324 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
325 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
326 |
+
and behavior.
|
327 |
+
|
328 |
+
Parameters:
|
329 |
+
config ([`LlamaConfig`]):
|
330 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
331 |
+
load the weights associated with the model, only the configuration. Check out the
|
332 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
333 |
+
"""
|
334 |
+
|
335 |
+
|
336 |
+
@add_start_docstrings(
|
337 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
338 |
+
LLAMA_START_DOCSTRING,
|
339 |
+
)
|
340 |
+
class LlamaPreTrainedModel(PreTrainedModel):
|
341 |
+
config_class = LlamaConfig
|
342 |
+
base_model_prefix = "model"
|
343 |
+
supports_gradient_checkpointing = True
|
344 |
+
_no_split_modules = ["LlamaDecoderLayer"]
|
345 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
346 |
+
|
347 |
+
def _init_weights(self, module):
|
348 |
+
std = self.config.initializer_range
|
349 |
+
if isinstance(module, nn.Linear):
|
350 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
351 |
+
if module.bias is not None:
|
352 |
+
module.bias.data.zero_()
|
353 |
+
elif isinstance(module, nn.Embedding):
|
354 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
355 |
+
if module.padding_idx is not None:
|
356 |
+
module.weight.data[module.padding_idx].zero_()
|
357 |
+
|
358 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
359 |
+
if isinstance(module, LlamaModel):
|
360 |
+
module.gradient_checkpointing = value
|
361 |
+
|
362 |
+
|
363 |
+
LLAMA_INPUTS_DOCSTRING = r"""
|
364 |
+
Args:
|
365 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
366 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
367 |
+
it.
|
368 |
+
|
369 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
370 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
371 |
+
|
372 |
+
[What are input IDs?](../glossary#input-ids)
|
373 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
374 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
375 |
+
|
376 |
+
- 1 for tokens that are **not masked**,
|
377 |
+
- 0 for tokens that are **masked**.
|
378 |
+
|
379 |
+
[What are attention masks?](../glossary#attention-mask)
|
380 |
+
|
381 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
382 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
383 |
+
|
384 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
385 |
+
`past_key_values`).
|
386 |
+
|
387 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
388 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
389 |
+
information on the default strategy.
|
390 |
+
|
391 |
+
- 1 indicates the head is **not masked**,
|
392 |
+
- 0 indicates the head is **masked**.
|
393 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
394 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
395 |
+
config.n_positions - 1]`.
|
396 |
+
|
397 |
+
[What are position IDs?](../glossary#position-ids)
|
398 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
399 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
400 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
401 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
402 |
+
|
403 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
404 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
405 |
+
|
406 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
407 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
408 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
409 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
410 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
411 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
412 |
+
model's internal embedding lookup matrix.
|
413 |
+
use_cache (`bool`, *optional*):
|
414 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
415 |
+
`past_key_values`).
|
416 |
+
output_attentions (`bool`, *optional*):
|
417 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
418 |
+
tensors for more detail.
|
419 |
+
output_hidden_states (`bool`, *optional*):
|
420 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
421 |
+
more detail.
|
422 |
+
return_dict (`bool`, *optional*):
|
423 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
424 |
+
"""
|
425 |
+
|
426 |
+
|
427 |
+
@add_start_docstrings(
|
428 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
429 |
+
LLAMA_START_DOCSTRING,
|
430 |
+
)
|
431 |
+
class LlamaModel(LlamaPreTrainedModel):
|
432 |
+
"""
|
433 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
434 |
+
|
435 |
+
Args:
|
436 |
+
config: LlamaConfig
|
437 |
+
"""
|
438 |
+
|
439 |
+
def __init__(self, config: LlamaConfig):
|
440 |
+
super().__init__(config)
|
441 |
+
self.padding_idx = config.pad_token_id
|
442 |
+
self.vocab_size = config.vocab_size
|
443 |
+
|
444 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
445 |
+
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
446 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
447 |
+
|
448 |
+
self.gradient_checkpointing = False
|
449 |
+
# Initialize weights and apply final processing
|
450 |
+
self.post_init()
|
451 |
+
|
452 |
+
def get_input_embeddings(self):
|
453 |
+
return self.embed_tokens
|
454 |
+
|
455 |
+
def set_input_embeddings(self, value):
|
456 |
+
self.embed_tokens = value
|
457 |
+
|
458 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
459 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
460 |
+
# create causal mask
|
461 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
462 |
+
combined_attention_mask = None
|
463 |
+
if input_shape[-1] > 1:
|
464 |
+
combined_attention_mask = _make_causal_mask(
|
465 |
+
input_shape,
|
466 |
+
inputs_embeds.dtype,
|
467 |
+
device=inputs_embeds.device,
|
468 |
+
past_key_values_length=past_key_values_length,
|
469 |
+
)
|
470 |
+
|
471 |
+
if attention_mask is not None:
|
472 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
473 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
474 |
+
inputs_embeds.device
|
475 |
+
)
|
476 |
+
combined_attention_mask = (
|
477 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
478 |
+
)
|
479 |
+
|
480 |
+
return combined_attention_mask
|
481 |
+
|
482 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
483 |
+
def forward(
|
484 |
+
self,
|
485 |
+
input_ids: torch.LongTensor = None,
|
486 |
+
attention_mask: Optional[torch.Tensor] = None,
|
487 |
+
position_ids: Optional[torch.LongTensor] = None,
|
488 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
489 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
490 |
+
use_cache: Optional[bool] = None,
|
491 |
+
output_attentions: Optional[bool] = None,
|
492 |
+
output_hidden_states: Optional[bool] = None,
|
493 |
+
return_dict: Optional[bool] = None,
|
494 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
495 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
496 |
+
output_hidden_states = (
|
497 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
498 |
+
)
|
499 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
500 |
+
|
501 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
502 |
+
|
503 |
+
# retrieve input_ids and inputs_embeds
|
504 |
+
if input_ids is not None and inputs_embeds is not None:
|
505 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
506 |
+
elif input_ids is not None:
|
507 |
+
batch_size, seq_length = input_ids.shape
|
508 |
+
elif inputs_embeds is not None:
|
509 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
510 |
+
else:
|
511 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
512 |
+
|
513 |
+
seq_length_with_past = seq_length
|
514 |
+
past_key_values_length = 0
|
515 |
+
|
516 |
+
if past_key_values is not None:
|
517 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
518 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
519 |
+
|
520 |
+
if position_ids is None:
|
521 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
522 |
+
position_ids = torch.arange(
|
523 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
524 |
+
)
|
525 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
526 |
+
else:
|
527 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
528 |
+
|
529 |
+
if inputs_embeds is None:
|
530 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
531 |
+
# embed positions
|
532 |
+
if attention_mask is None:
|
533 |
+
attention_mask = torch.ones(
|
534 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
535 |
+
)
|
536 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
537 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
538 |
+
)
|
539 |
+
|
540 |
+
hidden_states = inputs_embeds
|
541 |
+
|
542 |
+
if self.gradient_checkpointing and self.training:
|
543 |
+
if use_cache:
|
544 |
+
logger.warning_once(
|
545 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
546 |
+
)
|
547 |
+
use_cache = False
|
548 |
+
|
549 |
+
# decoder layers
|
550 |
+
all_hidden_states = () if output_hidden_states else None
|
551 |
+
all_self_attns = () if output_attentions else None
|
552 |
+
next_decoder_cache = () if use_cache else None
|
553 |
+
|
554 |
+
for idx, decoder_layer in enumerate(self.layers):
|
555 |
+
if output_hidden_states:
|
556 |
+
all_hidden_states += (hidden_states,)
|
557 |
+
|
558 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
559 |
+
|
560 |
+
if self.gradient_checkpointing and self.training:
|
561 |
+
|
562 |
+
def create_custom_forward(module):
|
563 |
+
def custom_forward(*inputs):
|
564 |
+
# None for past_key_value
|
565 |
+
return module(*inputs, output_attentions, None)
|
566 |
+
|
567 |
+
return custom_forward
|
568 |
+
|
569 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
570 |
+
create_custom_forward(decoder_layer),
|
571 |
+
hidden_states,
|
572 |
+
attention_mask,
|
573 |
+
position_ids,
|
574 |
+
None,
|
575 |
+
)
|
576 |
+
else:
|
577 |
+
layer_outputs = decoder_layer(
|
578 |
+
hidden_states,
|
579 |
+
attention_mask=attention_mask,
|
580 |
+
position_ids=position_ids,
|
581 |
+
past_key_value=past_key_value,
|
582 |
+
output_attentions=output_attentions,
|
583 |
+
use_cache=use_cache,
|
584 |
+
)
|
585 |
+
|
586 |
+
hidden_states = layer_outputs[0]
|
587 |
+
|
588 |
+
if use_cache:
|
589 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
590 |
+
|
591 |
+
if output_attentions:
|
592 |
+
all_self_attns += (layer_outputs[1],)
|
593 |
+
|
594 |
+
hidden_states = self.norm(hidden_states)
|
595 |
+
|
596 |
+
# add hidden states from the last decoder layer
|
597 |
+
if output_hidden_states:
|
598 |
+
all_hidden_states += (hidden_states,)
|
599 |
+
|
600 |
+
next_cache = next_decoder_cache if use_cache else None
|
601 |
+
if not return_dict:
|
602 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
603 |
+
return BaseModelOutputWithPast(
|
604 |
+
last_hidden_state=hidden_states,
|
605 |
+
past_key_values=next_cache,
|
606 |
+
hidden_states=all_hidden_states,
|
607 |
+
attentions=all_self_attns,
|
608 |
+
)
|
609 |
+
|
610 |
+
|
611 |
+
class LlamaForCausalLM(LlamaPreTrainedModel):
|
612 |
+
def __init__(self, config):
|
613 |
+
super().__init__(config)
|
614 |
+
self.model = LlamaModel(config)
|
615 |
+
|
616 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
617 |
+
|
618 |
+
# Initialize weights and apply final processing
|
619 |
+
self.post_init()
|
620 |
+
|
621 |
+
def get_input_embeddings(self):
|
622 |
+
return self.model.embed_tokens
|
623 |
+
|
624 |
+
def set_input_embeddings(self, value):
|
625 |
+
self.model.embed_tokens = value
|
626 |
+
|
627 |
+
def get_output_embeddings(self):
|
628 |
+
return self.lm_head
|
629 |
+
|
630 |
+
def set_output_embeddings(self, new_embeddings):
|
631 |
+
self.lm_head = new_embeddings
|
632 |
+
|
633 |
+
def set_decoder(self, decoder):
|
634 |
+
self.model = decoder
|
635 |
+
|
636 |
+
def get_decoder(self):
|
637 |
+
return self.model
|
638 |
+
|
639 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
640 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
641 |
+
def forward(
|
642 |
+
self,
|
643 |
+
input_ids: torch.LongTensor = None,
|
644 |
+
attention_mask: Optional[torch.Tensor] = None,
|
645 |
+
position_ids: Optional[torch.LongTensor] = None,
|
646 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
647 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
648 |
+
labels: Optional[torch.LongTensor] = None,
|
649 |
+
use_cache: Optional[bool] = None,
|
650 |
+
output_attentions: Optional[bool] = None,
|
651 |
+
output_hidden_states: Optional[bool] = None,
|
652 |
+
return_dict: Optional[bool] = None,
|
653 |
+
reduction: Optional[str] = "mean",
|
654 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
655 |
+
r"""
|
656 |
+
Args:
|
657 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
658 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
659 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
660 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
661 |
+
|
662 |
+
Returns:
|
663 |
+
|
664 |
+
Example:
|
665 |
+
|
666 |
+
```python
|
667 |
+
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
668 |
+
|
669 |
+
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
670 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
671 |
+
|
672 |
+
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
673 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
674 |
+
|
675 |
+
>>> # Generate
|
676 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
677 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
678 |
+
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
679 |
+
```"""
|
680 |
+
|
681 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
682 |
+
output_hidden_states = (
|
683 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
684 |
+
)
|
685 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
686 |
+
|
687 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
688 |
+
outputs = self.model(
|
689 |
+
input_ids=input_ids,
|
690 |
+
attention_mask=attention_mask,
|
691 |
+
position_ids=position_ids,
|
692 |
+
past_key_values=past_key_values,
|
693 |
+
inputs_embeds=inputs_embeds,
|
694 |
+
use_cache=use_cache,
|
695 |
+
output_attentions=output_attentions,
|
696 |
+
output_hidden_states=output_hidden_states,
|
697 |
+
return_dict=return_dict,
|
698 |
+
)
|
699 |
+
|
700 |
+
hidden_states = outputs[0]
|
701 |
+
logits = self.lm_head(hidden_states)
|
702 |
+
|
703 |
+
loss = None
|
704 |
+
if labels is not None:
|
705 |
+
# Shift so that tokens < n predict n
|
706 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
707 |
+
shift_labels = labels[..., 1:].contiguous()
|
708 |
+
# Flatten the tokens
|
709 |
+
loss_fct = CrossEntropyLoss(reduction=reduction)
|
710 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
711 |
+
shift_labels = shift_labels.view(-1)
|
712 |
+
# Enable model parallelism
|
713 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
714 |
+
loss = loss_fct(shift_logits, shift_labels)
|
715 |
+
if reduction == "none":
|
716 |
+
# loss = loss.view(logits.size(0), -1).sum(1)
|
717 |
+
loss = loss.view(logits.size(0), -1).mean(1)
|
718 |
+
|
719 |
+
if not return_dict:
|
720 |
+
output = (logits,) + outputs[1:]
|
721 |
+
return (loss,) + output if loss is not None else output
|
722 |
+
|
723 |
+
return CausalLMOutputWithPast(
|
724 |
+
loss=loss,
|
725 |
+
logits=logits,
|
726 |
+
past_key_values=outputs.past_key_values,
|
727 |
+
hidden_states=outputs.hidden_states,
|
728 |
+
attentions=outputs.attentions,
|
729 |
+
)
|
730 |
+
|
731 |
+
def prepare_inputs_for_generation(
|
732 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
733 |
+
):
|
734 |
+
if past_key_values:
|
735 |
+
input_ids = input_ids[:, -1:]
|
736 |
+
|
737 |
+
position_ids = kwargs.get("position_ids", None)
|
738 |
+
if attention_mask is not None and position_ids is None:
|
739 |
+
# create position_ids on the fly for batch generation
|
740 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
741 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
742 |
+
if past_key_values:
|
743 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
744 |
+
|
745 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
746 |
+
if inputs_embeds is not None and past_key_values is None:
|
747 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
748 |
+
else:
|
749 |
+
model_inputs = {"input_ids": input_ids}
|
750 |
+
|
751 |
+
model_inputs.update(
|
752 |
+
{
|
753 |
+
"position_ids": position_ids,
|
754 |
+
"past_key_values": past_key_values,
|
755 |
+
"use_cache": kwargs.get("use_cache"),
|
756 |
+
"attention_mask": attention_mask,
|
757 |
+
}
|
758 |
+
)
|
759 |
+
return model_inputs
|
760 |
+
|
761 |
+
@staticmethod
|
762 |
+
def _reorder_cache(past_key_values, beam_idx):
|
763 |
+
reordered_past = ()
|
764 |
+
for layer_past in past_key_values:
|
765 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
766 |
+
return reordered_past
|
767 |
+
|
768 |
+
|
769 |
+
@add_start_docstrings(
|
770 |
+
"""
|
771 |
+
The LLaMa Model transformer with a sequence classification head on top (linear layer).
|
772 |
+
|
773 |
+
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
774 |
+
(e.g. GPT-2) do.
|
775 |
+
|
776 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
777 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
778 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
779 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
780 |
+
each row of the batch).
|
781 |
+
""",
|
782 |
+
LLAMA_START_DOCSTRING,
|
783 |
+
)
|
784 |
+
class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
785 |
+
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
786 |
+
|
787 |
+
def __init__(self, config):
|
788 |
+
super().__init__(config)
|
789 |
+
self.num_labels = config.num_labels
|
790 |
+
self.model = LlamaModel(config)
|
791 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
792 |
+
|
793 |
+
# Initialize weights and apply final processing
|
794 |
+
self.post_init()
|
795 |
+
|
796 |
+
def get_input_embeddings(self):
|
797 |
+
return self.model.embed_tokens
|
798 |
+
|
799 |
+
def set_input_embeddings(self, value):
|
800 |
+
self.model.embed_tokens = value
|
801 |
+
|
802 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
803 |
+
def forward(
|
804 |
+
self,
|
805 |
+
input_ids: torch.LongTensor = None,
|
806 |
+
attention_mask: Optional[torch.Tensor] = None,
|
807 |
+
position_ids: Optional[torch.LongTensor] = None,
|
808 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
809 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
810 |
+
labels: Optional[torch.LongTensor] = None,
|
811 |
+
use_cache: Optional[bool] = None,
|
812 |
+
output_attentions: Optional[bool] = None,
|
813 |
+
output_hidden_states: Optional[bool] = None,
|
814 |
+
return_dict: Optional[bool] = None,
|
815 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
816 |
+
r"""
|
817 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
818 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
819 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
820 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
821 |
+
"""
|
822 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
823 |
+
|
824 |
+
transformer_outputs = self.model(
|
825 |
+
input_ids,
|
826 |
+
attention_mask=attention_mask,
|
827 |
+
position_ids=position_ids,
|
828 |
+
past_key_values=past_key_values,
|
829 |
+
inputs_embeds=inputs_embeds,
|
830 |
+
use_cache=use_cache,
|
831 |
+
output_attentions=output_attentions,
|
832 |
+
output_hidden_states=output_hidden_states,
|
833 |
+
return_dict=return_dict,
|
834 |
+
)
|
835 |
+
hidden_states = transformer_outputs[0]
|
836 |
+
logits = self.score(hidden_states)
|
837 |
+
|
838 |
+
if input_ids is not None:
|
839 |
+
batch_size = input_ids.shape[0]
|
840 |
+
else:
|
841 |
+
batch_size = inputs_embeds.shape[0]
|
842 |
+
|
843 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
844 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
845 |
+
if self.config.pad_token_id is None:
|
846 |
+
sequence_lengths = -1
|
847 |
+
else:
|
848 |
+
if input_ids is not None:
|
849 |
+
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
850 |
+
else:
|
851 |
+
sequence_lengths = -1
|
852 |
+
|
853 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
854 |
+
|
855 |
+
loss = None
|
856 |
+
if labels is not None:
|
857 |
+
labels = labels.to(logits.device)
|
858 |
+
if self.config.problem_type is None:
|
859 |
+
if self.num_labels == 1:
|
860 |
+
self.config.problem_type = "regression"
|
861 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
862 |
+
self.config.problem_type = "single_label_classification"
|
863 |
+
else:
|
864 |
+
self.config.problem_type = "multi_label_classification"
|
865 |
+
|
866 |
+
if self.config.problem_type == "regression":
|
867 |
+
loss_fct = MSELoss()
|
868 |
+
if self.num_labels == 1:
|
869 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
870 |
+
else:
|
871 |
+
loss = loss_fct(pooled_logits, labels)
|
872 |
+
elif self.config.problem_type == "single_label_classification":
|
873 |
+
loss_fct = CrossEntropyLoss()
|
874 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
875 |
+
elif self.config.problem_type == "multi_label_classification":
|
876 |
+
loss_fct = BCEWithLogitsLoss()
|
877 |
+
loss = loss_fct(pooled_logits, labels)
|
878 |
+
if not return_dict:
|
879 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
880 |
+
return ((loss,) + output) if loss is not None else output
|
881 |
+
|
882 |
+
return SequenceClassifierOutputWithPast(
|
883 |
+
loss=loss,
|
884 |
+
logits=pooled_logits,
|
885 |
+
past_key_values=transformer_outputs.past_key_values,
|
886 |
+
hidden_states=transformer_outputs.hidden_states,
|
887 |
+
attentions=transformer_outputs.attentions,
|
888 |
+
)
|
bliva/models/modeling_t5.py
ADDED
@@ -0,0 +1,2063 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch T5 model."""
|
16 |
+
|
17 |
+
|
18 |
+
import copy
|
19 |
+
import math
|
20 |
+
import os
|
21 |
+
import warnings
|
22 |
+
from typing import Optional, Tuple, Union
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import CrossEntropyLoss
|
27 |
+
from torch.utils.checkpoint import checkpoint
|
28 |
+
|
29 |
+
from transformers.activations import ACT2FN
|
30 |
+
from transformers.modeling_outputs import (
|
31 |
+
BaseModelOutput,
|
32 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
33 |
+
Seq2SeqLMOutput,
|
34 |
+
Seq2SeqModelOutput,
|
35 |
+
)
|
36 |
+
from transformers.modeling_utils import PreTrainedModel
|
37 |
+
from transformers.pytorch_utils import (
|
38 |
+
ALL_LAYERNORM_LAYERS,
|
39 |
+
find_pruneable_heads_and_indices,
|
40 |
+
prune_linear_layer,
|
41 |
+
)
|
42 |
+
from transformers.utils import (
|
43 |
+
DUMMY_INPUTS,
|
44 |
+
DUMMY_MASK,
|
45 |
+
add_start_docstrings,
|
46 |
+
add_start_docstrings_to_model_forward,
|
47 |
+
is_torch_fx_proxy,
|
48 |
+
logging,
|
49 |
+
replace_return_docstrings,
|
50 |
+
)
|
51 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
52 |
+
from transformers.models.t5.configuration_t5 import T5Config
|
53 |
+
|
54 |
+
|
55 |
+
logger = logging.get_logger(__name__)
|
56 |
+
|
57 |
+
_CONFIG_FOR_DOC = "T5Config"
|
58 |
+
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
59 |
+
_CHECKPOINT_FOR_DOC = "t5-small"
|
60 |
+
|
61 |
+
####################################################
|
62 |
+
# This dict contains ids and associated url
|
63 |
+
# for the pretrained weights provided with the models
|
64 |
+
####################################################
|
65 |
+
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
66 |
+
"t5-small",
|
67 |
+
"t5-base",
|
68 |
+
"t5-large",
|
69 |
+
"t5-3b",
|
70 |
+
"t5-11b",
|
71 |
+
# See all T5 models at https://huggingface.co/models?filter=t5
|
72 |
+
]
|
73 |
+
|
74 |
+
|
75 |
+
####################################################
|
76 |
+
# This is a conversion method from TF 1.0 to PyTorch
|
77 |
+
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
|
78 |
+
####################################################
|
79 |
+
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
|
80 |
+
"""Load tf checkpoints in a pytorch model."""
|
81 |
+
try:
|
82 |
+
import re
|
83 |
+
|
84 |
+
import numpy as np
|
85 |
+
import tensorflow as tf
|
86 |
+
except ImportError:
|
87 |
+
logger.error(
|
88 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
89 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
90 |
+
)
|
91 |
+
raise
|
92 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
93 |
+
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
94 |
+
# Load weights from TF model
|
95 |
+
init_vars = tf.train.list_variables(tf_path)
|
96 |
+
names = []
|
97 |
+
tf_weights = {}
|
98 |
+
for name, shape in init_vars:
|
99 |
+
logger.info(f"Loading TF weight {name} with shape {shape}")
|
100 |
+
array = tf.train.load_variable(tf_path, name)
|
101 |
+
names.append(name)
|
102 |
+
tf_weights[name] = array
|
103 |
+
|
104 |
+
for txt_name in names:
|
105 |
+
name = txt_name.split("/")
|
106 |
+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
107 |
+
# which are not required for using pretrained model
|
108 |
+
if any(
|
109 |
+
n
|
110 |
+
in [
|
111 |
+
"adam_v",
|
112 |
+
"adam_m",
|
113 |
+
"AdamWeightDecayOptimizer",
|
114 |
+
"AdamWeightDecayOptimizer_1",
|
115 |
+
"global_step",
|
116 |
+
]
|
117 |
+
for n in name
|
118 |
+
):
|
119 |
+
logger.info(f"Skipping {'/'.join(name)}")
|
120 |
+
tf_weights.pop(txt_name, None)
|
121 |
+
continue
|
122 |
+
if "_slot_" in name[-1]:
|
123 |
+
logger.info(f"Skipping {'/'.join(name)}")
|
124 |
+
tf_weights.pop(txt_name, None)
|
125 |
+
continue
|
126 |
+
pointer = model
|
127 |
+
array = tf_weights[txt_name]
|
128 |
+
|
129 |
+
for m_name in name:
|
130 |
+
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
131 |
+
scope_names = re.split(r"_(\d+)", m_name)
|
132 |
+
else:
|
133 |
+
scope_names = [m_name]
|
134 |
+
if scope_names[0] in ["kernel", "scale", "embedding"]:
|
135 |
+
pointer = getattr(pointer, "weight")
|
136 |
+
elif scope_names[0] == "self_attention":
|
137 |
+
pointer = getattr(pointer, "layer")
|
138 |
+
pointer = pointer[0]
|
139 |
+
elif scope_names[0] == "enc_dec_attention":
|
140 |
+
pointer = getattr(pointer, "layer")
|
141 |
+
pointer = pointer[1]
|
142 |
+
elif scope_names[0] == "dense_relu_dense":
|
143 |
+
pointer = getattr(pointer, "layer")
|
144 |
+
pointer = pointer[2]
|
145 |
+
elif scope_names[0] == "rms_norm":
|
146 |
+
if hasattr(pointer, "layer_norm"):
|
147 |
+
pointer = getattr(pointer, "layer_norm")
|
148 |
+
elif hasattr(pointer, "final_layer_norm"):
|
149 |
+
pointer = getattr(pointer, "final_layer_norm")
|
150 |
+
elif scope_names[0] == "scale":
|
151 |
+
pointer = getattr(pointer, "weight")
|
152 |
+
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
153 |
+
pointer = getattr(pointer, "bias")
|
154 |
+
elif scope_names[0] == "squad":
|
155 |
+
pointer = getattr(pointer, "classifier")
|
156 |
+
elif scope_names[0] == "decoder" and name[1] == "logits":
|
157 |
+
continue
|
158 |
+
elif scope_names[0] == "logits":
|
159 |
+
pointer = getattr(pointer, "lm_head")
|
160 |
+
elif (
|
161 |
+
scope_names[0] == "wi"
|
162 |
+
and len(scope_names) > 1
|
163 |
+
and scope_names[1].isdigit()
|
164 |
+
):
|
165 |
+
pointer = getattr(pointer, f"wi_{scope_names[1]}")
|
166 |
+
continue
|
167 |
+
else:
|
168 |
+
try:
|
169 |
+
pointer = getattr(pointer, scope_names[0])
|
170 |
+
except AttributeError:
|
171 |
+
logger.info(f"Skipping {'/'.join(name)}")
|
172 |
+
continue
|
173 |
+
if len(scope_names) >= 2:
|
174 |
+
num = int(scope_names[1])
|
175 |
+
pointer = pointer[num]
|
176 |
+
if scope_names[0] not in ["kernel", "scale", "embedding"]:
|
177 |
+
pointer = getattr(pointer, "weight")
|
178 |
+
if scope_names[0] != "embedding":
|
179 |
+
logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
|
180 |
+
array = np.transpose(array)
|
181 |
+
try:
|
182 |
+
assert (
|
183 |
+
pointer.shape == array.shape
|
184 |
+
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
185 |
+
except AssertionError as e:
|
186 |
+
e.args += (pointer.shape, array.shape)
|
187 |
+
raise
|
188 |
+
logger.info(f"Initialize PyTorch weight {name}")
|
189 |
+
pointer.data = torch.from_numpy(array.astype(np.float32))
|
190 |
+
tf_weights.pop(txt_name, None)
|
191 |
+
|
192 |
+
logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
|
193 |
+
return model
|
194 |
+
|
195 |
+
|
196 |
+
####################################################
|
197 |
+
# PyTorch Models are constructed by sub-classing
|
198 |
+
# - torch.nn.Module for the layers and
|
199 |
+
# - PreTrainedModel for the models (it-self a sub-class of nn.Module)
|
200 |
+
####################################################
|
201 |
+
PARALLELIZE_DOCSTRING = r"""
|
202 |
+
This is an experimental feature and is a subject to change at a moment's notice.
|
203 |
+
|
204 |
+
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
|
205 |
+
it will evenly distribute blocks across all devices.
|
206 |
+
|
207 |
+
Args:
|
208 |
+
device_map (`Dict[int, list]`, optional, defaults to None):
|
209 |
+
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
|
210 |
+
automatically mapped to the first device (for esoteric reasons). That means that the first device should
|
211 |
+
have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
|
212 |
+
following number of attention modules:
|
213 |
+
|
214 |
+
- t5-small: 6
|
215 |
+
- t5-base: 12
|
216 |
+
- t5-large: 24
|
217 |
+
- t5-3b: 24
|
218 |
+
- t5-11b: 24
|
219 |
+
|
220 |
+
Example:
|
221 |
+
|
222 |
+
```python
|
223 |
+
# Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
|
224 |
+
model = T5ForConditionalGeneration.from_pretrained("t5-3b")
|
225 |
+
device_map = {
|
226 |
+
0: [0, 1, 2],
|
227 |
+
1: [3, 4, 5, 6, 7, 8, 9],
|
228 |
+
2: [10, 11, 12, 13, 14, 15, 16],
|
229 |
+
3: [17, 18, 19, 20, 21, 22, 23],
|
230 |
+
}
|
231 |
+
model.parallelize(device_map)
|
232 |
+
```
|
233 |
+
"""
|
234 |
+
DEPARALLELIZE_DOCSTRING = r"""
|
235 |
+
Moves the model to cpu from a model parallel state.
|
236 |
+
|
237 |
+
Example:
|
238 |
+
|
239 |
+
```python
|
240 |
+
# On a 4 GPU machine with t5-3b:
|
241 |
+
model = T5ForConditionalGeneration.from_pretrained("t5-3b")
|
242 |
+
device_map = {
|
243 |
+
0: [0, 1, 2],
|
244 |
+
1: [3, 4, 5, 6, 7, 8, 9],
|
245 |
+
2: [10, 11, 12, 13, 14, 15, 16],
|
246 |
+
3: [17, 18, 19, 20, 21, 22, 23],
|
247 |
+
}
|
248 |
+
model.parallelize(device_map) # Splits the model across several devices
|
249 |
+
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
|
250 |
+
```
|
251 |
+
"""
|
252 |
+
|
253 |
+
|
254 |
+
class T5LayerNorm(nn.Module):
|
255 |
+
def __init__(self, hidden_size, eps=1e-6):
|
256 |
+
"""
|
257 |
+
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
258 |
+
"""
|
259 |
+
super().__init__()
|
260 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
261 |
+
self.variance_epsilon = eps
|
262 |
+
|
263 |
+
def forward(self, hidden_states):
|
264 |
+
|
265 |
+
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
266 |
+
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
267 |
+
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
268 |
+
# half-precision inputs is done in fp32
|
269 |
+
|
270 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
271 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
272 |
+
|
273 |
+
# convert into half-precision if necessary
|
274 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
275 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
276 |
+
|
277 |
+
return self.weight * hidden_states
|
278 |
+
|
279 |
+
|
280 |
+
try:
|
281 |
+
from apex.normalization import FusedRMSNorm
|
282 |
+
|
283 |
+
T5LayerNorm = FusedRMSNorm # noqa
|
284 |
+
|
285 |
+
logger.info(
|
286 |
+
"Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm"
|
287 |
+
)
|
288 |
+
except ImportError:
|
289 |
+
# using the normal T5LayerNorm
|
290 |
+
pass
|
291 |
+
except Exception:
|
292 |
+
logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
|
293 |
+
pass
|
294 |
+
|
295 |
+
ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
|
296 |
+
|
297 |
+
|
298 |
+
class T5DenseActDense(nn.Module):
|
299 |
+
def __init__(self, config: T5Config):
|
300 |
+
super().__init__()
|
301 |
+
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
|
302 |
+
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
|
303 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
304 |
+
self.act = ACT2FN[config.dense_act_fn]
|
305 |
+
|
306 |
+
def forward(self, hidden_states):
|
307 |
+
hidden_states = self.wi(hidden_states)
|
308 |
+
hidden_states = self.act(hidden_states)
|
309 |
+
hidden_states = self.dropout(hidden_states)
|
310 |
+
hidden_states = self.wo(hidden_states)
|
311 |
+
return hidden_states
|
312 |
+
|
313 |
+
|
314 |
+
class T5DenseGatedActDense(nn.Module):
|
315 |
+
def __init__(self, config: T5Config):
|
316 |
+
super().__init__()
|
317 |
+
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
|
318 |
+
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
|
319 |
+
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
|
320 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
321 |
+
self.act = ACT2FN[config.dense_act_fn]
|
322 |
+
|
323 |
+
def forward(self, hidden_states):
|
324 |
+
hidden_gelu = self.act(self.wi_0(hidden_states))
|
325 |
+
hidden_linear = self.wi_1(hidden_states)
|
326 |
+
hidden_states = hidden_gelu * hidden_linear
|
327 |
+
hidden_states = self.dropout(hidden_states)
|
328 |
+
hidden_states = self.wo(hidden_states)
|
329 |
+
return hidden_states
|
330 |
+
|
331 |
+
|
332 |
+
class T5LayerFF(nn.Module):
|
333 |
+
def __init__(self, config: T5Config):
|
334 |
+
super().__init__()
|
335 |
+
if config.is_gated_act:
|
336 |
+
self.DenseReluDense = T5DenseGatedActDense(config)
|
337 |
+
else:
|
338 |
+
self.DenseReluDense = T5DenseActDense(config)
|
339 |
+
|
340 |
+
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
341 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
342 |
+
|
343 |
+
def forward(self, hidden_states):
|
344 |
+
forwarded_states = self.layer_norm(hidden_states)
|
345 |
+
forwarded_states = self.DenseReluDense(forwarded_states)
|
346 |
+
hidden_states = hidden_states + self.dropout(forwarded_states)
|
347 |
+
return hidden_states
|
348 |
+
|
349 |
+
|
350 |
+
class T5Attention(nn.Module):
|
351 |
+
def __init__(self, config: T5Config, has_relative_attention_bias=False):
|
352 |
+
super().__init__()
|
353 |
+
self.is_decoder = config.is_decoder
|
354 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
355 |
+
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
356 |
+
self.relative_attention_max_distance = config.relative_attention_max_distance
|
357 |
+
self.d_model = config.d_model
|
358 |
+
self.key_value_proj_dim = config.d_kv
|
359 |
+
self.n_heads = config.num_heads
|
360 |
+
self.dropout = config.dropout_rate
|
361 |
+
self.inner_dim = self.n_heads * self.key_value_proj_dim
|
362 |
+
|
363 |
+
# Mesh TensorFlow initialization to avoid scaling before softmax
|
364 |
+
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
365 |
+
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
366 |
+
self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
367 |
+
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
|
368 |
+
|
369 |
+
if self.has_relative_attention_bias:
|
370 |
+
self.relative_attention_bias = nn.Embedding(
|
371 |
+
self.relative_attention_num_buckets, self.n_heads
|
372 |
+
)
|
373 |
+
self.pruned_heads = set()
|
374 |
+
self.gradient_checkpointing = False
|
375 |
+
|
376 |
+
def prune_heads(self, heads):
|
377 |
+
if len(heads) == 0:
|
378 |
+
return
|
379 |
+
heads, index = find_pruneable_heads_and_indices(
|
380 |
+
heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
|
381 |
+
)
|
382 |
+
# Prune linear layers
|
383 |
+
self.q = prune_linear_layer(self.q, index)
|
384 |
+
self.k = prune_linear_layer(self.k, index)
|
385 |
+
self.v = prune_linear_layer(self.v, index)
|
386 |
+
self.o = prune_linear_layer(self.o, index, dim=1)
|
387 |
+
# Update hyper params
|
388 |
+
self.n_heads = self.n_heads - len(heads)
|
389 |
+
self.inner_dim = self.key_value_proj_dim * self.n_heads
|
390 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
391 |
+
|
392 |
+
@staticmethod
|
393 |
+
def _relative_position_bucket(
|
394 |
+
relative_position, bidirectional=True, num_buckets=32, max_distance=128
|
395 |
+
):
|
396 |
+
"""
|
397 |
+
Adapted from Mesh Tensorflow:
|
398 |
+
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
399 |
+
|
400 |
+
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
401 |
+
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
402 |
+
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
403 |
+
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
404 |
+
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
405 |
+
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
406 |
+
|
407 |
+
Args:
|
408 |
+
relative_position: an int32 Tensor
|
409 |
+
bidirectional: a boolean - whether the attention is bidirectional
|
410 |
+
num_buckets: an integer
|
411 |
+
max_distance: an integer
|
412 |
+
|
413 |
+
Returns:
|
414 |
+
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
415 |
+
"""
|
416 |
+
relative_buckets = 0
|
417 |
+
if bidirectional:
|
418 |
+
num_buckets //= 2
|
419 |
+
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
420 |
+
relative_position = torch.abs(relative_position)
|
421 |
+
else:
|
422 |
+
relative_position = -torch.min(
|
423 |
+
relative_position, torch.zeros_like(relative_position)
|
424 |
+
)
|
425 |
+
# now relative_position is in the range [0, inf)
|
426 |
+
|
427 |
+
# half of the buckets are for exact increments in positions
|
428 |
+
max_exact = num_buckets // 2
|
429 |
+
is_small = relative_position < max_exact
|
430 |
+
|
431 |
+
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
432 |
+
relative_position_if_large = max_exact + (
|
433 |
+
torch.log(relative_position.float() / max_exact)
|
434 |
+
/ math.log(max_distance / max_exact)
|
435 |
+
* (num_buckets - max_exact)
|
436 |
+
).to(torch.long)
|
437 |
+
relative_position_if_large = torch.min(
|
438 |
+
relative_position_if_large,
|
439 |
+
torch.full_like(relative_position_if_large, num_buckets - 1),
|
440 |
+
)
|
441 |
+
|
442 |
+
relative_buckets += torch.where(
|
443 |
+
is_small, relative_position, relative_position_if_large
|
444 |
+
)
|
445 |
+
return relative_buckets
|
446 |
+
|
447 |
+
def compute_bias(self, query_length, key_length, device=None):
|
448 |
+
"""Compute binned relative position bias"""
|
449 |
+
if device is None:
|
450 |
+
device = self.relative_attention_bias.weight.device
|
451 |
+
context_position = torch.arange(query_length, dtype=torch.long, device=device)[
|
452 |
+
:, None
|
453 |
+
]
|
454 |
+
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
|
455 |
+
None, :
|
456 |
+
]
|
457 |
+
relative_position = (
|
458 |
+
memory_position - context_position
|
459 |
+
) # shape (query_length, key_length)
|
460 |
+
relative_position_bucket = self._relative_position_bucket(
|
461 |
+
relative_position, # shape (query_length, key_length)
|
462 |
+
bidirectional=(not self.is_decoder),
|
463 |
+
num_buckets=self.relative_attention_num_buckets,
|
464 |
+
max_distance=self.relative_attention_max_distance,
|
465 |
+
)
|
466 |
+
values = self.relative_attention_bias(
|
467 |
+
relative_position_bucket
|
468 |
+
) # shape (query_length, key_length, num_heads)
|
469 |
+
values = values.permute([2, 0, 1]).unsqueeze(
|
470 |
+
0
|
471 |
+
) # shape (1, num_heads, query_length, key_length)
|
472 |
+
return values
|
473 |
+
|
474 |
+
def forward(
|
475 |
+
self,
|
476 |
+
hidden_states,
|
477 |
+
mask=None,
|
478 |
+
key_value_states=None,
|
479 |
+
position_bias=None,
|
480 |
+
past_key_value=None,
|
481 |
+
layer_head_mask=None,
|
482 |
+
query_length=None,
|
483 |
+
use_cache=False,
|
484 |
+
output_attentions=False,
|
485 |
+
):
|
486 |
+
"""
|
487 |
+
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
|
488 |
+
"""
|
489 |
+
# Input is (batch_size, seq_length, dim)
|
490 |
+
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
|
491 |
+
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
|
492 |
+
batch_size, seq_length = hidden_states.shape[:2]
|
493 |
+
|
494 |
+
real_seq_length = seq_length
|
495 |
+
|
496 |
+
if past_key_value is not None:
|
497 |
+
assert (
|
498 |
+
len(past_key_value) == 2
|
499 |
+
), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
|
500 |
+
real_seq_length += (
|
501 |
+
past_key_value[0].shape[2] if query_length is None else query_length
|
502 |
+
)
|
503 |
+
|
504 |
+
key_length = (
|
505 |
+
real_seq_length if key_value_states is None else key_value_states.shape[1]
|
506 |
+
)
|
507 |
+
|
508 |
+
def shape(states):
|
509 |
+
"""projection"""
|
510 |
+
return states.view(
|
511 |
+
batch_size, -1, self.n_heads, self.key_value_proj_dim
|
512 |
+
).transpose(1, 2)
|
513 |
+
|
514 |
+
def unshape(states):
|
515 |
+
"""reshape"""
|
516 |
+
return (
|
517 |
+
states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
|
518 |
+
)
|
519 |
+
|
520 |
+
def project(hidden_states, proj_layer, key_value_states, past_key_value):
|
521 |
+
"""projects hidden states correctly to key/query states"""
|
522 |
+
if key_value_states is None:
|
523 |
+
# self-attn
|
524 |
+
# (batch_size, n_heads, seq_length, dim_per_head)
|
525 |
+
hidden_states = shape(proj_layer(hidden_states))
|
526 |
+
elif past_key_value is None:
|
527 |
+
# cross-attn
|
528 |
+
# (batch_size, n_heads, seq_length, dim_per_head)
|
529 |
+
hidden_states = shape(proj_layer(key_value_states))
|
530 |
+
|
531 |
+
if past_key_value is not None:
|
532 |
+
if key_value_states is None:
|
533 |
+
# self-attn
|
534 |
+
# (batch_size, n_heads, key_length, dim_per_head)
|
535 |
+
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
|
536 |
+
else:
|
537 |
+
# cross-attn
|
538 |
+
hidden_states = past_key_value
|
539 |
+
return hidden_states
|
540 |
+
|
541 |
+
# get query states
|
542 |
+
query_states = shape(
|
543 |
+
self.q(hidden_states)
|
544 |
+
) # (batch_size, n_heads, seq_length, dim_per_head)
|
545 |
+
|
546 |
+
# get key/value states
|
547 |
+
key_states = project(
|
548 |
+
hidden_states,
|
549 |
+
self.k,
|
550 |
+
key_value_states,
|
551 |
+
past_key_value[0] if past_key_value is not None else None,
|
552 |
+
)
|
553 |
+
value_states = project(
|
554 |
+
hidden_states,
|
555 |
+
self.v,
|
556 |
+
key_value_states,
|
557 |
+
past_key_value[1] if past_key_value is not None else None,
|
558 |
+
)
|
559 |
+
|
560 |
+
# compute scores
|
561 |
+
scores = torch.matmul(
|
562 |
+
query_states, key_states.transpose(3, 2)
|
563 |
+
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
|
564 |
+
|
565 |
+
if position_bias is None:
|
566 |
+
if not self.has_relative_attention_bias:
|
567 |
+
position_bias = torch.zeros(
|
568 |
+
(1, self.n_heads, real_seq_length, key_length),
|
569 |
+
device=scores.device,
|
570 |
+
dtype=scores.dtype,
|
571 |
+
)
|
572 |
+
if self.gradient_checkpointing and self.training:
|
573 |
+
position_bias.requires_grad = True
|
574 |
+
else:
|
575 |
+
position_bias = self.compute_bias(
|
576 |
+
real_seq_length, key_length, device=scores.device
|
577 |
+
)
|
578 |
+
|
579 |
+
# if key and values are already calculated
|
580 |
+
# we want only the last query position bias
|
581 |
+
if past_key_value is not None:
|
582 |
+
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
|
583 |
+
|
584 |
+
if mask is not None:
|
585 |
+
position_bias = (
|
586 |
+
position_bias + mask
|
587 |
+
) # (batch_size, n_heads, seq_length, key_length)
|
588 |
+
|
589 |
+
if self.pruned_heads:
|
590 |
+
mask = torch.ones(position_bias.shape[1])
|
591 |
+
mask[list(self.pruned_heads)] = 0
|
592 |
+
position_bias_masked = position_bias[:, mask.bool()]
|
593 |
+
else:
|
594 |
+
position_bias_masked = position_bias
|
595 |
+
|
596 |
+
scores += position_bias_masked
|
597 |
+
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
|
598 |
+
scores
|
599 |
+
) # (batch_size, n_heads, seq_length, key_length)
|
600 |
+
attn_weights = nn.functional.dropout(
|
601 |
+
attn_weights, p=self.dropout, training=self.training
|
602 |
+
) # (batch_size, n_heads, seq_length, key_length)
|
603 |
+
|
604 |
+
# Mask heads if we want to
|
605 |
+
if layer_head_mask is not None:
|
606 |
+
attn_weights = attn_weights * layer_head_mask
|
607 |
+
|
608 |
+
attn_output = unshape(
|
609 |
+
torch.matmul(attn_weights, value_states)
|
610 |
+
) # (batch_size, seq_length, dim)
|
611 |
+
attn_output = self.o(attn_output)
|
612 |
+
|
613 |
+
present_key_value_state = (
|
614 |
+
(key_states, value_states) if (self.is_decoder and use_cache) else None
|
615 |
+
)
|
616 |
+
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
|
617 |
+
|
618 |
+
if output_attentions:
|
619 |
+
outputs = outputs + (attn_weights,)
|
620 |
+
return outputs
|
621 |
+
|
622 |
+
|
623 |
+
class T5LayerSelfAttention(nn.Module):
|
624 |
+
def __init__(self, config, has_relative_attention_bias=False):
|
625 |
+
super().__init__()
|
626 |
+
self.SelfAttention = T5Attention(
|
627 |
+
config, has_relative_attention_bias=has_relative_attention_bias
|
628 |
+
)
|
629 |
+
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
630 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
631 |
+
|
632 |
+
def forward(
|
633 |
+
self,
|
634 |
+
hidden_states,
|
635 |
+
attention_mask=None,
|
636 |
+
position_bias=None,
|
637 |
+
layer_head_mask=None,
|
638 |
+
past_key_value=None,
|
639 |
+
use_cache=False,
|
640 |
+
output_attentions=False,
|
641 |
+
):
|
642 |
+
normed_hidden_states = self.layer_norm(hidden_states)
|
643 |
+
attention_output = self.SelfAttention(
|
644 |
+
normed_hidden_states,
|
645 |
+
mask=attention_mask,
|
646 |
+
position_bias=position_bias,
|
647 |
+
layer_head_mask=layer_head_mask,
|
648 |
+
past_key_value=past_key_value,
|
649 |
+
use_cache=use_cache,
|
650 |
+
output_attentions=output_attentions,
|
651 |
+
)
|
652 |
+
hidden_states = hidden_states + self.dropout(attention_output[0])
|
653 |
+
outputs = (hidden_states,) + attention_output[
|
654 |
+
1:
|
655 |
+
] # add attentions if we output them
|
656 |
+
return outputs
|
657 |
+
|
658 |
+
|
659 |
+
class T5LayerCrossAttention(nn.Module):
|
660 |
+
def __init__(self, config):
|
661 |
+
super().__init__()
|
662 |
+
self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
|
663 |
+
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
664 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
665 |
+
|
666 |
+
def forward(
|
667 |
+
self,
|
668 |
+
hidden_states,
|
669 |
+
key_value_states,
|
670 |
+
attention_mask=None,
|
671 |
+
position_bias=None,
|
672 |
+
layer_head_mask=None,
|
673 |
+
past_key_value=None,
|
674 |
+
use_cache=False,
|
675 |
+
query_length=None,
|
676 |
+
output_attentions=False,
|
677 |
+
):
|
678 |
+
normed_hidden_states = self.layer_norm(hidden_states)
|
679 |
+
attention_output = self.EncDecAttention(
|
680 |
+
normed_hidden_states,
|
681 |
+
mask=attention_mask,
|
682 |
+
key_value_states=key_value_states,
|
683 |
+
position_bias=position_bias,
|
684 |
+
layer_head_mask=layer_head_mask,
|
685 |
+
past_key_value=past_key_value,
|
686 |
+
use_cache=use_cache,
|
687 |
+
query_length=query_length,
|
688 |
+
output_attentions=output_attentions,
|
689 |
+
)
|
690 |
+
layer_output = hidden_states + self.dropout(attention_output[0])
|
691 |
+
outputs = (layer_output,) + attention_output[
|
692 |
+
1:
|
693 |
+
] # add attentions if we output them
|
694 |
+
return outputs
|
695 |
+
|
696 |
+
|
697 |
+
class T5Block(nn.Module):
|
698 |
+
def __init__(self, config, has_relative_attention_bias=False):
|
699 |
+
super().__init__()
|
700 |
+
self.is_decoder = config.is_decoder
|
701 |
+
self.layer = nn.ModuleList()
|
702 |
+
self.layer.append(
|
703 |
+
T5LayerSelfAttention(
|
704 |
+
config, has_relative_attention_bias=has_relative_attention_bias
|
705 |
+
)
|
706 |
+
)
|
707 |
+
if self.is_decoder:
|
708 |
+
self.layer.append(T5LayerCrossAttention(config))
|
709 |
+
|
710 |
+
self.layer.append(T5LayerFF(config))
|
711 |
+
|
712 |
+
def forward(
|
713 |
+
self,
|
714 |
+
hidden_states,
|
715 |
+
attention_mask=None,
|
716 |
+
position_bias=None,
|
717 |
+
encoder_hidden_states=None,
|
718 |
+
encoder_attention_mask=None,
|
719 |
+
encoder_decoder_position_bias=None,
|
720 |
+
layer_head_mask=None,
|
721 |
+
cross_attn_layer_head_mask=None,
|
722 |
+
past_key_value=None,
|
723 |
+
use_cache=False,
|
724 |
+
output_attentions=False,
|
725 |
+
return_dict=True,
|
726 |
+
):
|
727 |
+
|
728 |
+
if past_key_value is not None:
|
729 |
+
if not self.is_decoder:
|
730 |
+
logger.warning(
|
731 |
+
"`past_key_values` is passed to the encoder. Please make sure this is intended."
|
732 |
+
)
|
733 |
+
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
|
734 |
+
|
735 |
+
if len(past_key_value) != expected_num_past_key_values:
|
736 |
+
raise ValueError(
|
737 |
+
f"There should be {expected_num_past_key_values} past states. "
|
738 |
+
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
|
739 |
+
f"Got {len(past_key_value)} past key / value states"
|
740 |
+
)
|
741 |
+
|
742 |
+
self_attn_past_key_value = past_key_value[:2]
|
743 |
+
cross_attn_past_key_value = past_key_value[2:]
|
744 |
+
else:
|
745 |
+
self_attn_past_key_value, cross_attn_past_key_value = None, None
|
746 |
+
|
747 |
+
self_attention_outputs = self.layer[0](
|
748 |
+
hidden_states,
|
749 |
+
attention_mask=attention_mask,
|
750 |
+
position_bias=position_bias,
|
751 |
+
layer_head_mask=layer_head_mask,
|
752 |
+
past_key_value=self_attn_past_key_value,
|
753 |
+
use_cache=use_cache,
|
754 |
+
output_attentions=output_attentions,
|
755 |
+
)
|
756 |
+
hidden_states, present_key_value_state = self_attention_outputs[:2]
|
757 |
+
attention_outputs = self_attention_outputs[
|
758 |
+
2:
|
759 |
+
] # Keep self-attention outputs and relative position weights
|
760 |
+
|
761 |
+
# clamp inf values to enable fp16 training
|
762 |
+
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
|
763 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
764 |
+
hidden_states = torch.clamp(
|
765 |
+
hidden_states, min=-clamp_value, max=clamp_value
|
766 |
+
)
|
767 |
+
|
768 |
+
do_cross_attention = self.is_decoder and encoder_hidden_states is not None
|
769 |
+
if do_cross_attention:
|
770 |
+
# the actual query length is unknown for cross attention
|
771 |
+
# if using past key value states. Need to inject it here
|
772 |
+
if present_key_value_state is not None:
|
773 |
+
query_length = present_key_value_state[0].shape[2]
|
774 |
+
else:
|
775 |
+
query_length = None
|
776 |
+
|
777 |
+
cross_attention_outputs = self.layer[1](
|
778 |
+
hidden_states,
|
779 |
+
key_value_states=encoder_hidden_states,
|
780 |
+
attention_mask=encoder_attention_mask,
|
781 |
+
position_bias=encoder_decoder_position_bias,
|
782 |
+
layer_head_mask=cross_attn_layer_head_mask,
|
783 |
+
past_key_value=cross_attn_past_key_value,
|
784 |
+
query_length=query_length,
|
785 |
+
use_cache=use_cache,
|
786 |
+
output_attentions=output_attentions,
|
787 |
+
)
|
788 |
+
hidden_states = cross_attention_outputs[0]
|
789 |
+
|
790 |
+
# clamp inf values to enable fp16 training
|
791 |
+
if (
|
792 |
+
hidden_states.dtype == torch.float16
|
793 |
+
and torch.isinf(hidden_states).any()
|
794 |
+
):
|
795 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
796 |
+
hidden_states = torch.clamp(
|
797 |
+
hidden_states, min=-clamp_value, max=clamp_value
|
798 |
+
)
|
799 |
+
|
800 |
+
# Combine self attn and cross attn key value states
|
801 |
+
if present_key_value_state is not None:
|
802 |
+
present_key_value_state = (
|
803 |
+
present_key_value_state + cross_attention_outputs[1]
|
804 |
+
)
|
805 |
+
|
806 |
+
# Keep cross-attention outputs and relative position weights
|
807 |
+
attention_outputs = attention_outputs + cross_attention_outputs[2:]
|
808 |
+
|
809 |
+
# Apply Feed Forward layer
|
810 |
+
hidden_states = self.layer[-1](hidden_states)
|
811 |
+
|
812 |
+
# clamp inf values to enable fp16 training
|
813 |
+
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
|
814 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
815 |
+
hidden_states = torch.clamp(
|
816 |
+
hidden_states, min=-clamp_value, max=clamp_value
|
817 |
+
)
|
818 |
+
|
819 |
+
outputs = (hidden_states,)
|
820 |
+
|
821 |
+
if use_cache:
|
822 |
+
outputs = outputs + (present_key_value_state,) + attention_outputs
|
823 |
+
else:
|
824 |
+
outputs = outputs + attention_outputs
|
825 |
+
|
826 |
+
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
|
827 |
+
|
828 |
+
|
829 |
+
class T5PreTrainedModel(PreTrainedModel):
|
830 |
+
"""
|
831 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
832 |
+
models.
|
833 |
+
"""
|
834 |
+
|
835 |
+
config_class = T5Config
|
836 |
+
load_tf_weights = load_tf_weights_in_t5
|
837 |
+
base_model_prefix = "transformer"
|
838 |
+
is_parallelizable = True
|
839 |
+
supports_gradient_checkpointing = True
|
840 |
+
_no_split_modules = ["T5Block"]
|
841 |
+
|
842 |
+
@property
|
843 |
+
def dummy_inputs(self):
|
844 |
+
input_ids = torch.tensor(DUMMY_INPUTS)
|
845 |
+
input_mask = torch.tensor(DUMMY_MASK)
|
846 |
+
dummy_inputs = {
|
847 |
+
"decoder_input_ids": input_ids,
|
848 |
+
"input_ids": input_ids,
|
849 |
+
"decoder_attention_mask": input_mask,
|
850 |
+
}
|
851 |
+
return dummy_inputs
|
852 |
+
|
853 |
+
def _init_weights(self, module):
|
854 |
+
"""Initialize the weights"""
|
855 |
+
factor = (
|
856 |
+
self.config.initializer_factor
|
857 |
+
) # Used for testing weights initialization
|
858 |
+
if isinstance(module, T5LayerNorm):
|
859 |
+
module.weight.data.fill_(factor * 1.0)
|
860 |
+
elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
|
861 |
+
# Mesh TensorFlow embeddings initialization
|
862 |
+
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
|
863 |
+
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
864 |
+
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
|
865 |
+
module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
866 |
+
elif isinstance(module, T5DenseActDense):
|
867 |
+
# Mesh TensorFlow FF initialization
|
868 |
+
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
|
869 |
+
# and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
|
870 |
+
module.wi.weight.data.normal_(
|
871 |
+
mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
|
872 |
+
)
|
873 |
+
if hasattr(module.wi, "bias") and module.wi.bias is not None:
|
874 |
+
module.wi.bias.data.zero_()
|
875 |
+
module.wo.weight.data.normal_(
|
876 |
+
mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)
|
877 |
+
)
|
878 |
+
if hasattr(module.wo, "bias") and module.wo.bias is not None:
|
879 |
+
module.wo.bias.data.zero_()
|
880 |
+
elif isinstance(module, T5DenseGatedActDense):
|
881 |
+
module.wi_0.weight.data.normal_(
|
882 |
+
mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
|
883 |
+
)
|
884 |
+
if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
|
885 |
+
module.wi_0.bias.data.zero_()
|
886 |
+
module.wi_1.weight.data.normal_(
|
887 |
+
mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
|
888 |
+
)
|
889 |
+
if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
|
890 |
+
module.wi_1.bias.data.zero_()
|
891 |
+
module.wo.weight.data.normal_(
|
892 |
+
mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)
|
893 |
+
)
|
894 |
+
if hasattr(module.wo, "bias") and module.wo.bias is not None:
|
895 |
+
module.wo.bias.data.zero_()
|
896 |
+
elif isinstance(module, T5Attention):
|
897 |
+
# Mesh TensorFlow attention initialization to avoid scaling before softmax
|
898 |
+
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
|
899 |
+
d_model = self.config.d_model
|
900 |
+
key_value_proj_dim = self.config.d_kv
|
901 |
+
n_heads = self.config.num_heads
|
902 |
+
module.q.weight.data.normal_(
|
903 |
+
mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)
|
904 |
+
)
|
905 |
+
module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
|
906 |
+
module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
|
907 |
+
module.o.weight.data.normal_(
|
908 |
+
mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)
|
909 |
+
)
|
910 |
+
if module.has_relative_attention_bias:
|
911 |
+
module.relative_attention_bias.weight.data.normal_(
|
912 |
+
mean=0.0, std=factor * ((d_model) ** -0.5)
|
913 |
+
)
|
914 |
+
|
915 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
916 |
+
if isinstance(module, (T5Attention, T5Stack)):
|
917 |
+
module.gradient_checkpointing = value
|
918 |
+
|
919 |
+
def _shift_right(self, input_ids):
|
920 |
+
decoder_start_token_id = self.config.decoder_start_token_id
|
921 |
+
pad_token_id = self.config.pad_token_id
|
922 |
+
|
923 |
+
assert decoder_start_token_id is not None, (
|
924 |
+
"self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
|
925 |
+
" See T5 docs for more information"
|
926 |
+
)
|
927 |
+
|
928 |
+
# shift inputs to the right
|
929 |
+
if is_torch_fx_proxy(input_ids):
|
930 |
+
# Item assignment is not supported natively for proxies.
|
931 |
+
shifted_input_ids = torch.full(
|
932 |
+
input_ids.shape[:-1] + (1,), decoder_start_token_id
|
933 |
+
)
|
934 |
+
shifted_input_ids = torch.cat(
|
935 |
+
[shifted_input_ids, input_ids[..., :-1]], dim=-1
|
936 |
+
)
|
937 |
+
else:
|
938 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
939 |
+
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
940 |
+
shifted_input_ids[..., 0] = decoder_start_token_id
|
941 |
+
|
942 |
+
assert (
|
943 |
+
pad_token_id is not None
|
944 |
+
), "self.model.config.pad_token_id has to be defined."
|
945 |
+
# replace possible -100 values in labels by `pad_token_id`
|
946 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
947 |
+
|
948 |
+
return shifted_input_ids
|
949 |
+
|
950 |
+
|
951 |
+
class T5Stack(T5PreTrainedModel):
|
952 |
+
def __init__(self, config, embed_tokens=None):
|
953 |
+
super().__init__(config)
|
954 |
+
|
955 |
+
self.embed_tokens = embed_tokens
|
956 |
+
self.is_decoder = config.is_decoder
|
957 |
+
|
958 |
+
self.block = nn.ModuleList(
|
959 |
+
[
|
960 |
+
T5Block(config, has_relative_attention_bias=bool(i == 0))
|
961 |
+
for i in range(config.num_layers)
|
962 |
+
]
|
963 |
+
)
|
964 |
+
self.final_layer_norm = T5LayerNorm(
|
965 |
+
config.d_model, eps=config.layer_norm_epsilon
|
966 |
+
)
|
967 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
968 |
+
|
969 |
+
# Initialize weights and apply final processing
|
970 |
+
self.post_init()
|
971 |
+
# Model parallel
|
972 |
+
self.model_parallel = False
|
973 |
+
self.device_map = None
|
974 |
+
self.gradient_checkpointing = False
|
975 |
+
|
976 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
977 |
+
def parallelize(self, device_map=None):
|
978 |
+
# Check validity of device_map
|
979 |
+
self.device_map = (
|
980 |
+
get_device_map(len(self.block), range(torch.cuda.device_count()))
|
981 |
+
if device_map is None
|
982 |
+
else device_map
|
983 |
+
)
|
984 |
+
assert_device_map(self.device_map, len(self.block))
|
985 |
+
self.model_parallel = True
|
986 |
+
self.first_device = (
|
987 |
+
"cpu"
|
988 |
+
if "cpu" in self.device_map.keys()
|
989 |
+
else "cuda:" + str(min(self.device_map.keys()))
|
990 |
+
)
|
991 |
+
self.last_device = "cuda:" + str(max(self.device_map.keys()))
|
992 |
+
# Load onto devices
|
993 |
+
for k, v in self.device_map.items():
|
994 |
+
for layer in v:
|
995 |
+
cuda_device = "cuda:" + str(k)
|
996 |
+
self.block[layer] = self.block[layer].to(cuda_device)
|
997 |
+
|
998 |
+
# Set embed_tokens to first layer
|
999 |
+
self.embed_tokens = self.embed_tokens.to(self.first_device)
|
1000 |
+
# Set final layer norm to last device
|
1001 |
+
self.final_layer_norm = self.final_layer_norm.to(self.last_device)
|
1002 |
+
|
1003 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
1004 |
+
def deparallelize(self):
|
1005 |
+
self.model_parallel = False
|
1006 |
+
self.device_map = None
|
1007 |
+
self.first_device = "cpu"
|
1008 |
+
self.last_device = "cpu"
|
1009 |
+
for i in range(len(self.block)):
|
1010 |
+
self.block[i] = self.block[i].to("cpu")
|
1011 |
+
self.embed_tokens = self.embed_tokens.to("cpu")
|
1012 |
+
self.final_layer_norm = self.final_layer_norm.to("cpu")
|
1013 |
+
torch.cuda.empty_cache()
|
1014 |
+
|
1015 |
+
def get_input_embeddings(self):
|
1016 |
+
return self.embed_tokens
|
1017 |
+
|
1018 |
+
def set_input_embeddings(self, new_embeddings):
|
1019 |
+
self.embed_tokens = new_embeddings
|
1020 |
+
|
1021 |
+
def forward(
|
1022 |
+
self,
|
1023 |
+
input_ids=None,
|
1024 |
+
attention_mask=None,
|
1025 |
+
encoder_hidden_states=None,
|
1026 |
+
encoder_attention_mask=None,
|
1027 |
+
inputs_embeds=None,
|
1028 |
+
head_mask=None,
|
1029 |
+
cross_attn_head_mask=None,
|
1030 |
+
past_key_values=None,
|
1031 |
+
use_cache=None,
|
1032 |
+
output_attentions=None,
|
1033 |
+
output_hidden_states=None,
|
1034 |
+
return_dict=None,
|
1035 |
+
):
|
1036 |
+
# Model parallel
|
1037 |
+
if self.model_parallel:
|
1038 |
+
torch.cuda.set_device(self.first_device)
|
1039 |
+
self.embed_tokens = self.embed_tokens.to(self.first_device)
|
1040 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1041 |
+
output_attentions = (
|
1042 |
+
output_attentions
|
1043 |
+
if output_attentions is not None
|
1044 |
+
else self.config.output_attentions
|
1045 |
+
)
|
1046 |
+
output_hidden_states = (
|
1047 |
+
output_hidden_states
|
1048 |
+
if output_hidden_states is not None
|
1049 |
+
else self.config.output_hidden_states
|
1050 |
+
)
|
1051 |
+
return_dict = (
|
1052 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1053 |
+
)
|
1054 |
+
|
1055 |
+
if input_ids is not None and inputs_embeds is not None:
|
1056 |
+
err_msg_prefix = "decoder_" if self.is_decoder else ""
|
1057 |
+
raise ValueError(
|
1058 |
+
f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
|
1059 |
+
)
|
1060 |
+
elif input_ids is not None:
|
1061 |
+
input_shape = input_ids.size()
|
1062 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
1063 |
+
elif inputs_embeds is not None:
|
1064 |
+
input_shape = inputs_embeds.size()[:-1]
|
1065 |
+
else:
|
1066 |
+
err_msg_prefix = "decoder_" if self.is_decoder else ""
|
1067 |
+
raise ValueError(
|
1068 |
+
f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
|
1069 |
+
)
|
1070 |
+
|
1071 |
+
if inputs_embeds is None:
|
1072 |
+
assert (
|
1073 |
+
self.embed_tokens is not None
|
1074 |
+
), "You have to initialize the model with valid token embeddings"
|
1075 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
1076 |
+
|
1077 |
+
batch_size, seq_length = input_shape
|
1078 |
+
|
1079 |
+
# required mask seq length can be calculated via length of past
|
1080 |
+
mask_seq_length = (
|
1081 |
+
past_key_values[0][0].shape[2] + seq_length
|
1082 |
+
if past_key_values is not None
|
1083 |
+
else seq_length
|
1084 |
+
)
|
1085 |
+
|
1086 |
+
if use_cache is True:
|
1087 |
+
assert (
|
1088 |
+
self.is_decoder
|
1089 |
+
), f"`use_cache` can only be set to `True` if {self} is used as a decoder"
|
1090 |
+
|
1091 |
+
if attention_mask is None:
|
1092 |
+
attention_mask = torch.ones(
|
1093 |
+
batch_size, mask_seq_length, device=inputs_embeds.device
|
1094 |
+
)
|
1095 |
+
if (
|
1096 |
+
self.is_decoder
|
1097 |
+
and encoder_attention_mask is None
|
1098 |
+
and encoder_hidden_states is not None
|
1099 |
+
):
|
1100 |
+
encoder_seq_length = encoder_hidden_states.shape[1]
|
1101 |
+
encoder_attention_mask = torch.ones(
|
1102 |
+
batch_size,
|
1103 |
+
encoder_seq_length,
|
1104 |
+
device=inputs_embeds.device,
|
1105 |
+
dtype=torch.long,
|
1106 |
+
)
|
1107 |
+
|
1108 |
+
# initialize past_key_values with `None` if past does not exist
|
1109 |
+
if past_key_values is None:
|
1110 |
+
past_key_values = [None] * len(self.block)
|
1111 |
+
|
1112 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
1113 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
1114 |
+
extended_attention_mask = self.get_extended_attention_mask(
|
1115 |
+
attention_mask, input_shape
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
1119 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
1120 |
+
if self.is_decoder and encoder_hidden_states is not None:
|
1121 |
+
(
|
1122 |
+
encoder_batch_size,
|
1123 |
+
encoder_sequence_length,
|
1124 |
+
_,
|
1125 |
+
) = encoder_hidden_states.size()
|
1126 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
1127 |
+
if encoder_attention_mask is None:
|
1128 |
+
encoder_attention_mask = torch.ones(
|
1129 |
+
encoder_hidden_shape, device=inputs_embeds.device
|
1130 |
+
)
|
1131 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
1132 |
+
encoder_attention_mask
|
1133 |
+
)
|
1134 |
+
else:
|
1135 |
+
encoder_extended_attention_mask = None
|
1136 |
+
|
1137 |
+
# Prepare head mask if needed
|
1138 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
1139 |
+
cross_attn_head_mask = self.get_head_mask(
|
1140 |
+
cross_attn_head_mask, self.config.num_layers
|
1141 |
+
)
|
1142 |
+
present_key_value_states = () if use_cache else None
|
1143 |
+
all_hidden_states = () if output_hidden_states else None
|
1144 |
+
all_attentions = () if output_attentions else None
|
1145 |
+
all_cross_attentions = () if (output_attentions and self.is_decoder) else None
|
1146 |
+
position_bias = None
|
1147 |
+
encoder_decoder_position_bias = None
|
1148 |
+
|
1149 |
+
hidden_states = self.dropout(inputs_embeds)
|
1150 |
+
|
1151 |
+
for i, (layer_module, past_key_value) in enumerate(
|
1152 |
+
zip(self.block, past_key_values)
|
1153 |
+
):
|
1154 |
+
layer_head_mask = head_mask[i]
|
1155 |
+
cross_attn_layer_head_mask = cross_attn_head_mask[i]
|
1156 |
+
# Model parallel
|
1157 |
+
if self.model_parallel:
|
1158 |
+
torch.cuda.set_device(hidden_states.device)
|
1159 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
1160 |
+
if attention_mask is not None:
|
1161 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
1162 |
+
if position_bias is not None:
|
1163 |
+
position_bias = position_bias.to(hidden_states.device)
|
1164 |
+
if encoder_hidden_states is not None:
|
1165 |
+
encoder_hidden_states = encoder_hidden_states.to(
|
1166 |
+
hidden_states.device
|
1167 |
+
)
|
1168 |
+
if encoder_extended_attention_mask is not None:
|
1169 |
+
encoder_extended_attention_mask = (
|
1170 |
+
encoder_extended_attention_mask.to(hidden_states.device)
|
1171 |
+
)
|
1172 |
+
if encoder_decoder_position_bias is not None:
|
1173 |
+
encoder_decoder_position_bias = encoder_decoder_position_bias.to(
|
1174 |
+
hidden_states.device
|
1175 |
+
)
|
1176 |
+
if layer_head_mask is not None:
|
1177 |
+
layer_head_mask = layer_head_mask.to(hidden_states.device)
|
1178 |
+
if cross_attn_layer_head_mask is not None:
|
1179 |
+
cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
|
1180 |
+
hidden_states.device
|
1181 |
+
)
|
1182 |
+
if output_hidden_states:
|
1183 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
1184 |
+
|
1185 |
+
if self.gradient_checkpointing and self.training:
|
1186 |
+
if use_cache:
|
1187 |
+
logger.warning(
|
1188 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
1189 |
+
)
|
1190 |
+
use_cache = False
|
1191 |
+
|
1192 |
+
def create_custom_forward(module):
|
1193 |
+
def custom_forward(*inputs):
|
1194 |
+
return tuple(module(*inputs, use_cache, output_attentions))
|
1195 |
+
|
1196 |
+
return custom_forward
|
1197 |
+
|
1198 |
+
layer_outputs = checkpoint(
|
1199 |
+
create_custom_forward(layer_module),
|
1200 |
+
hidden_states,
|
1201 |
+
extended_attention_mask,
|
1202 |
+
position_bias,
|
1203 |
+
encoder_hidden_states,
|
1204 |
+
encoder_extended_attention_mask,
|
1205 |
+
encoder_decoder_position_bias,
|
1206 |
+
layer_head_mask,
|
1207 |
+
cross_attn_layer_head_mask,
|
1208 |
+
None, # past_key_value is always None with gradient checkpointing
|
1209 |
+
)
|
1210 |
+
else:
|
1211 |
+
layer_outputs = layer_module(
|
1212 |
+
hidden_states,
|
1213 |
+
attention_mask=extended_attention_mask,
|
1214 |
+
position_bias=position_bias,
|
1215 |
+
encoder_hidden_states=encoder_hidden_states,
|
1216 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
1217 |
+
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
1218 |
+
layer_head_mask=layer_head_mask,
|
1219 |
+
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
|
1220 |
+
past_key_value=past_key_value,
|
1221 |
+
use_cache=use_cache,
|
1222 |
+
output_attentions=output_attentions,
|
1223 |
+
)
|
1224 |
+
|
1225 |
+
# layer_outputs is a tuple with:
|
1226 |
+
# hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
|
1227 |
+
if use_cache is False:
|
1228 |
+
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
|
1229 |
+
|
1230 |
+
hidden_states, present_key_value_state = layer_outputs[:2]
|
1231 |
+
|
1232 |
+
# We share the position biases between the layers - the first layer store them
|
1233 |
+
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
|
1234 |
+
# (cross-attention position bias), (cross-attention weights)
|
1235 |
+
position_bias = layer_outputs[2]
|
1236 |
+
if self.is_decoder and encoder_hidden_states is not None:
|
1237 |
+
encoder_decoder_position_bias = layer_outputs[
|
1238 |
+
4 if output_attentions else 3
|
1239 |
+
]
|
1240 |
+
# append next layer key value states
|
1241 |
+
if use_cache:
|
1242 |
+
present_key_value_states = present_key_value_states + (
|
1243 |
+
present_key_value_state,
|
1244 |
+
)
|
1245 |
+
|
1246 |
+
if output_attentions:
|
1247 |
+
all_attentions = all_attentions + (layer_outputs[3],)
|
1248 |
+
if self.is_decoder:
|
1249 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
|
1250 |
+
|
1251 |
+
# Model Parallel: If it's the last layer for that device, put things on the next device
|
1252 |
+
if self.model_parallel:
|
1253 |
+
for k, v in self.device_map.items():
|
1254 |
+
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
1255 |
+
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
1256 |
+
|
1257 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
1258 |
+
hidden_states = self.dropout(hidden_states)
|
1259 |
+
|
1260 |
+
# Add last layer
|
1261 |
+
if output_hidden_states:
|
1262 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
1263 |
+
|
1264 |
+
if not return_dict:
|
1265 |
+
return tuple(
|
1266 |
+
v
|
1267 |
+
for v in [
|
1268 |
+
hidden_states,
|
1269 |
+
present_key_value_states,
|
1270 |
+
all_hidden_states,
|
1271 |
+
all_attentions,
|
1272 |
+
all_cross_attentions,
|
1273 |
+
]
|
1274 |
+
if v is not None
|
1275 |
+
)
|
1276 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
1277 |
+
last_hidden_state=hidden_states,
|
1278 |
+
past_key_values=present_key_value_states,
|
1279 |
+
hidden_states=all_hidden_states,
|
1280 |
+
attentions=all_attentions,
|
1281 |
+
cross_attentions=all_cross_attentions,
|
1282 |
+
)
|
1283 |
+
|
1284 |
+
|
1285 |
+
T5_START_DOCSTRING = r"""
|
1286 |
+
|
1287 |
+
The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
|
1288 |
+
Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
|
1289 |
+
Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a
|
1290 |
+
text-to-text denoising generative setting.
|
1291 |
+
|
1292 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
1293 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
1294 |
+
etc.)
|
1295 |
+
|
1296 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
1297 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
1298 |
+
and behavior.
|
1299 |
+
|
1300 |
+
Parameters:
|
1301 |
+
config ([`T5Config`]): Model configuration class with all the parameters of the model.
|
1302 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
1303 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
1304 |
+
"""
|
1305 |
+
|
1306 |
+
T5_INPUTS_DOCSTRING = r"""
|
1307 |
+
Args:
|
1308 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
1309 |
+
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
|
1310 |
+
should be able to pad the inputs on both the right and the left.
|
1311 |
+
|
1312 |
+
Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
1313 |
+
[`PreTrainedTokenizer.__call__`] for detail.
|
1314 |
+
|
1315 |
+
[What are input IDs?](../glossary#input-ids)
|
1316 |
+
|
1317 |
+
To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
|
1318 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1319 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
1320 |
+
|
1321 |
+
- 1 for tokens that are **not masked**,
|
1322 |
+
- 0 for tokens that are **masked**.
|
1323 |
+
|
1324 |
+
[What are attention masks?](../glossary#attention-mask)
|
1325 |
+
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
1326 |
+
Indices of decoder input sequence tokens in the vocabulary.
|
1327 |
+
|
1328 |
+
Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
1329 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
1330 |
+
|
1331 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
1332 |
+
|
1333 |
+
T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
|
1334 |
+
is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
1335 |
+
|
1336 |
+
To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
|
1337 |
+
Training](./t5#training).
|
1338 |
+
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
1339 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
1340 |
+
be used by default.
|
1341 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
1342 |
+
Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
|
1343 |
+
1]`:
|
1344 |
+
|
1345 |
+
- 1 indicates the head is **not masked**,
|
1346 |
+
- 0 indicates the head is **masked**.
|
1347 |
+
|
1348 |
+
decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
1349 |
+
Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
|
1350 |
+
1]`:
|
1351 |
+
|
1352 |
+
- 1 indicates the head is **not masked**,
|
1353 |
+
- 0 indicates the head is **masked**.
|
1354 |
+
|
1355 |
+
cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
1356 |
+
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
|
1357 |
+
`[0, 1]`:
|
1358 |
+
|
1359 |
+
- 1 indicates the head is **not masked**,
|
1360 |
+
- 0 indicates the head is **masked**.
|
1361 |
+
|
1362 |
+
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
|
1363 |
+
Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
|
1364 |
+
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
|
1365 |
+
the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
1366 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1367 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1368 |
+
|
1369 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
1370 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
1371 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
1372 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
1373 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
1374 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
1375 |
+
model's internal embedding lookup matrix.
|
1376 |
+
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
|
1377 |
+
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
|
1378 |
+
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
|
1379 |
+
input (see `past_key_values`). This is useful if you want more control over how to convert
|
1380 |
+
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
|
1381 |
+
|
1382 |
+
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
|
1383 |
+
of `inputs_embeds`.
|
1384 |
+
|
1385 |
+
use_cache (`bool`, *optional*):
|
1386 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
1387 |
+
`past_key_values`).
|
1388 |
+
|
1389 |
+
output_attentions (`bool`, *optional*):
|
1390 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
1391 |
+
tensors for more detail.
|
1392 |
+
output_hidden_states (`bool`, *optional*):
|
1393 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
1394 |
+
more detail.
|
1395 |
+
return_dict (`bool`, *optional*):
|
1396 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
1397 |
+
"""
|
1398 |
+
|
1399 |
+
T5_ENCODER_INPUTS_DOCSTRING = r"""
|
1400 |
+
Args:
|
1401 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
1402 |
+
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
|
1403 |
+
should be able to pad the inputs on both the right and the left.
|
1404 |
+
|
1405 |
+
Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
1406 |
+
[`PreTrainedTokenizer.__call__`] for detail.
|
1407 |
+
|
1408 |
+
To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
|
1409 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1410 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
1411 |
+
|
1412 |
+
- 1 for tokens that are **not masked**,
|
1413 |
+
- 0 for tokens that are **masked**.
|
1414 |
+
|
1415 |
+
[What are attention masks?](../glossary#attention-mask)
|
1416 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
1417 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
1418 |
+
|
1419 |
+
- 1 indicates the head is **not masked**,
|
1420 |
+
- 0 indicates the head is **masked**.
|
1421 |
+
|
1422 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
1423 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
1424 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
1425 |
+
model's internal embedding lookup matrix.
|
1426 |
+
output_attentions (`bool`, *optional*):
|
1427 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
1428 |
+
tensors for more detail.
|
1429 |
+
output_hidden_states (`bool`, *optional*):
|
1430 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
1431 |
+
more detail.
|
1432 |
+
return_dict (`bool`, *optional*):
|
1433 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
1434 |
+
"""
|
1435 |
+
|
1436 |
+
# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
1437 |
+
__HEAD_MASK_WARNING_MSG = """
|
1438 |
+
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
|
1439 |
+
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
|
1440 |
+
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
|
1441 |
+
num_heads)`.
|
1442 |
+
"""
|
1443 |
+
|
1444 |
+
|
1445 |
+
@add_start_docstrings(
|
1446 |
+
"The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
|
1447 |
+
T5_START_DOCSTRING,
|
1448 |
+
)
|
1449 |
+
class T5Model(T5PreTrainedModel):
|
1450 |
+
_keys_to_ignore_on_load_missing = [
|
1451 |
+
r"encoder.embed_tokens.weight",
|
1452 |
+
r"decoder.embed_tokens.weight",
|
1453 |
+
]
|
1454 |
+
_keys_to_ignore_on_load_unexpected = [
|
1455 |
+
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
|
1456 |
+
]
|
1457 |
+
|
1458 |
+
def __init__(self, config: T5Config):
|
1459 |
+
super().__init__(config)
|
1460 |
+
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
1461 |
+
|
1462 |
+
encoder_config = copy.deepcopy(config)
|
1463 |
+
encoder_config.is_decoder = False
|
1464 |
+
encoder_config.use_cache = False
|
1465 |
+
encoder_config.is_encoder_decoder = False
|
1466 |
+
self.encoder = T5Stack(encoder_config, self.shared)
|
1467 |
+
|
1468 |
+
decoder_config = copy.deepcopy(config)
|
1469 |
+
decoder_config.is_decoder = True
|
1470 |
+
decoder_config.is_encoder_decoder = False
|
1471 |
+
decoder_config.num_layers = config.num_decoder_layers
|
1472 |
+
self.decoder = T5Stack(decoder_config, self.shared)
|
1473 |
+
|
1474 |
+
# Initialize weights and apply final processing
|
1475 |
+
self.post_init()
|
1476 |
+
|
1477 |
+
# Model parallel
|
1478 |
+
self.model_parallel = False
|
1479 |
+
self.device_map = None
|
1480 |
+
|
1481 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
1482 |
+
def parallelize(self, device_map=None):
|
1483 |
+
self.device_map = (
|
1484 |
+
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
|
1485 |
+
if device_map is None
|
1486 |
+
else device_map
|
1487 |
+
)
|
1488 |
+
assert_device_map(self.device_map, len(self.encoder.block))
|
1489 |
+
self.encoder.parallelize(self.device_map)
|
1490 |
+
self.decoder.parallelize(self.device_map)
|
1491 |
+
self.model_parallel = True
|
1492 |
+
|
1493 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
1494 |
+
def deparallelize(self):
|
1495 |
+
self.encoder.deparallelize()
|
1496 |
+
self.decoder.deparallelize()
|
1497 |
+
self.encoder = self.encoder.to("cpu")
|
1498 |
+
self.decoder = self.decoder.to("cpu")
|
1499 |
+
self.model_parallel = False
|
1500 |
+
self.device_map = None
|
1501 |
+
torch.cuda.empty_cache()
|
1502 |
+
|
1503 |
+
def get_input_embeddings(self):
|
1504 |
+
return self.shared
|
1505 |
+
|
1506 |
+
def set_input_embeddings(self, new_embeddings):
|
1507 |
+
self.shared = new_embeddings
|
1508 |
+
self.encoder.set_input_embeddings(new_embeddings)
|
1509 |
+
self.decoder.set_input_embeddings(new_embeddings)
|
1510 |
+
|
1511 |
+
def get_encoder(self):
|
1512 |
+
return self.encoder
|
1513 |
+
|
1514 |
+
def get_decoder(self):
|
1515 |
+
return self.decoder
|
1516 |
+
|
1517 |
+
def _prune_heads(self, heads_to_prune):
|
1518 |
+
"""
|
1519 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
1520 |
+
class PreTrainedModel
|
1521 |
+
"""
|
1522 |
+
for layer, heads in heads_to_prune.items():
|
1523 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
1524 |
+
|
1525 |
+
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
1526 |
+
@replace_return_docstrings(
|
1527 |
+
output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC
|
1528 |
+
)
|
1529 |
+
def forward(
|
1530 |
+
self,
|
1531 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1532 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1533 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1534 |
+
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
1535 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1536 |
+
decoder_head_mask: Optional[torch.FloatTensor] = None,
|
1537 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
1538 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1539 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1540 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1541 |
+
decoder_inputs_embeds: Optional[torch.Tensor] = None,
|
1542 |
+
use_cache: Optional[bool] = None,
|
1543 |
+
output_attentions: Optional[bool] = None,
|
1544 |
+
output_hidden_states: Optional[bool] = None,
|
1545 |
+
return_dict: Optional[bool] = None,
|
1546 |
+
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
|
1547 |
+
r"""
|
1548 |
+
Returns:
|
1549 |
+
|
1550 |
+
Example:
|
1551 |
+
|
1552 |
+
```python
|
1553 |
+
>>> from transformers import T5Tokenizer, T5Model
|
1554 |
+
|
1555 |
+
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
1556 |
+
>>> model = T5Model.from_pretrained("t5-small")
|
1557 |
+
|
1558 |
+
>>> input_ids = tokenizer(
|
1559 |
+
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
|
1560 |
+
... ).input_ids # Batch size 1
|
1561 |
+
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
|
1562 |
+
|
1563 |
+
>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
|
1564 |
+
>>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
|
1565 |
+
>>> decoder_input_ids = model._shift_right(decoder_input_ids)
|
1566 |
+
|
1567 |
+
>>> # forward pass
|
1568 |
+
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
1569 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
1570 |
+
```"""
|
1571 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1572 |
+
return_dict = (
|
1573 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1574 |
+
)
|
1575 |
+
|
1576 |
+
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
1577 |
+
if head_mask is not None and decoder_head_mask is None:
|
1578 |
+
if self.config.num_layers == self.config.num_decoder_layers:
|
1579 |
+
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
|
1580 |
+
decoder_head_mask = head_mask
|
1581 |
+
|
1582 |
+
# Encode if needed (training, first prediction pass)
|
1583 |
+
if encoder_outputs is None:
|
1584 |
+
encoder_outputs = self.encoder(
|
1585 |
+
input_ids=input_ids,
|
1586 |
+
attention_mask=attention_mask,
|
1587 |
+
inputs_embeds=inputs_embeds,
|
1588 |
+
head_mask=head_mask,
|
1589 |
+
output_attentions=output_attentions,
|
1590 |
+
output_hidden_states=output_hidden_states,
|
1591 |
+
return_dict=return_dict,
|
1592 |
+
)
|
1593 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
1594 |
+
encoder_outputs = BaseModelOutput(
|
1595 |
+
last_hidden_state=encoder_outputs[0],
|
1596 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
1597 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
1598 |
+
)
|
1599 |
+
|
1600 |
+
hidden_states = encoder_outputs[0]
|
1601 |
+
|
1602 |
+
# Set device for model parallelism
|
1603 |
+
if self.model_parallel:
|
1604 |
+
torch.cuda.set_device(self.decoder.first_device)
|
1605 |
+
hidden_states = hidden_states.to(self.decoder.first_device)
|
1606 |
+
if decoder_input_ids is not None:
|
1607 |
+
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
|
1608 |
+
if attention_mask is not None:
|
1609 |
+
attention_mask = attention_mask.to(self.decoder.first_device)
|
1610 |
+
if decoder_attention_mask is not None:
|
1611 |
+
decoder_attention_mask = decoder_attention_mask.to(
|
1612 |
+
self.decoder.first_device
|
1613 |
+
)
|
1614 |
+
|
1615 |
+
# Decode
|
1616 |
+
decoder_outputs = self.decoder(
|
1617 |
+
input_ids=decoder_input_ids,
|
1618 |
+
attention_mask=decoder_attention_mask,
|
1619 |
+
inputs_embeds=decoder_inputs_embeds,
|
1620 |
+
past_key_values=past_key_values,
|
1621 |
+
encoder_hidden_states=hidden_states,
|
1622 |
+
encoder_attention_mask=attention_mask,
|
1623 |
+
head_mask=decoder_head_mask,
|
1624 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
1625 |
+
use_cache=use_cache,
|
1626 |
+
output_attentions=output_attentions,
|
1627 |
+
output_hidden_states=output_hidden_states,
|
1628 |
+
return_dict=return_dict,
|
1629 |
+
)
|
1630 |
+
|
1631 |
+
if not return_dict:
|
1632 |
+
return decoder_outputs + encoder_outputs
|
1633 |
+
|
1634 |
+
return Seq2SeqModelOutput(
|
1635 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
1636 |
+
past_key_values=decoder_outputs.past_key_values,
|
1637 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
1638 |
+
decoder_attentions=decoder_outputs.attentions,
|
1639 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
1640 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
1641 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
1642 |
+
encoder_attentions=encoder_outputs.attentions,
|
1643 |
+
)
|
1644 |
+
|
1645 |
+
|
1646 |
+
@add_start_docstrings(
|
1647 |
+
"""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING
|
1648 |
+
)
|
1649 |
+
class T5ForConditionalGeneration(T5PreTrainedModel):
|
1650 |
+
_keys_to_ignore_on_load_missing = [
|
1651 |
+
r"encoder.embed_tokens.weight",
|
1652 |
+
r"decoder.embed_tokens.weight",
|
1653 |
+
r"lm_head.weight",
|
1654 |
+
]
|
1655 |
+
_keys_to_ignore_on_load_unexpected = [
|
1656 |
+
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
|
1657 |
+
]
|
1658 |
+
|
1659 |
+
def __init__(self, config: T5Config):
|
1660 |
+
super().__init__(config)
|
1661 |
+
self.model_dim = config.d_model
|
1662 |
+
|
1663 |
+
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
1664 |
+
|
1665 |
+
encoder_config = copy.deepcopy(config)
|
1666 |
+
encoder_config.is_decoder = False
|
1667 |
+
encoder_config.use_cache = False
|
1668 |
+
encoder_config.is_encoder_decoder = False
|
1669 |
+
self.encoder = T5Stack(encoder_config, self.shared)
|
1670 |
+
|
1671 |
+
decoder_config = copy.deepcopy(config)
|
1672 |
+
decoder_config.is_decoder = True
|
1673 |
+
decoder_config.is_encoder_decoder = False
|
1674 |
+
decoder_config.num_layers = config.num_decoder_layers
|
1675 |
+
self.decoder = T5Stack(decoder_config, self.shared)
|
1676 |
+
|
1677 |
+
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
1678 |
+
|
1679 |
+
# Initialize weights and apply final processing
|
1680 |
+
self.post_init()
|
1681 |
+
|
1682 |
+
# Model parallel
|
1683 |
+
self.model_parallel = False
|
1684 |
+
self.device_map = None
|
1685 |
+
|
1686 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
1687 |
+
def parallelize(self, device_map=None):
|
1688 |
+
self.device_map = (
|
1689 |
+
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
|
1690 |
+
if device_map is None
|
1691 |
+
else device_map
|
1692 |
+
)
|
1693 |
+
assert_device_map(self.device_map, len(self.encoder.block))
|
1694 |
+
self.encoder.parallelize(self.device_map)
|
1695 |
+
self.decoder.parallelize(self.device_map)
|
1696 |
+
self.lm_head = self.lm_head.to(self.decoder.first_device)
|
1697 |
+
self.model_parallel = True
|
1698 |
+
|
1699 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
1700 |
+
def deparallelize(self):
|
1701 |
+
self.encoder.deparallelize()
|
1702 |
+
self.decoder.deparallelize()
|
1703 |
+
self.encoder = self.encoder.to("cpu")
|
1704 |
+
self.decoder = self.decoder.to("cpu")
|
1705 |
+
self.lm_head = self.lm_head.to("cpu")
|
1706 |
+
self.model_parallel = False
|
1707 |
+
self.device_map = None
|
1708 |
+
torch.cuda.empty_cache()
|
1709 |
+
|
1710 |
+
def get_input_embeddings(self):
|
1711 |
+
return self.shared
|
1712 |
+
|
1713 |
+
def set_input_embeddings(self, new_embeddings):
|
1714 |
+
self.shared = new_embeddings
|
1715 |
+
self.encoder.set_input_embeddings(new_embeddings)
|
1716 |
+
self.decoder.set_input_embeddings(new_embeddings)
|
1717 |
+
|
1718 |
+
def set_output_embeddings(self, new_embeddings):
|
1719 |
+
self.lm_head = new_embeddings
|
1720 |
+
|
1721 |
+
def get_output_embeddings(self):
|
1722 |
+
return self.lm_head
|
1723 |
+
|
1724 |
+
def get_encoder(self):
|
1725 |
+
return self.encoder
|
1726 |
+
|
1727 |
+
def get_decoder(self):
|
1728 |
+
return self.decoder
|
1729 |
+
|
1730 |
+
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
1731 |
+
@replace_return_docstrings(
|
1732 |
+
output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
|
1733 |
+
)
|
1734 |
+
def forward(
|
1735 |
+
self,
|
1736 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1737 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1738 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1739 |
+
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
1740 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1741 |
+
decoder_head_mask: Optional[torch.FloatTensor] = None,
|
1742 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
1743 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1744 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1745 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1746 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
1747 |
+
labels: Optional[torch.LongTensor] = None,
|
1748 |
+
use_cache: Optional[bool] = None,
|
1749 |
+
output_attentions: Optional[bool] = None,
|
1750 |
+
output_hidden_states: Optional[bool] = None,
|
1751 |
+
return_dict: Optional[bool] = None,
|
1752 |
+
reduction: Optional[str] = "mean",
|
1753 |
+
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
1754 |
+
r"""
|
1755 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1756 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
|
1757 |
+
config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
|
1758 |
+
labels in `[0, ..., config.vocab_size]`
|
1759 |
+
|
1760 |
+
Returns:
|
1761 |
+
|
1762 |
+
Examples:
|
1763 |
+
|
1764 |
+
```python
|
1765 |
+
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
1766 |
+
|
1767 |
+
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
1768 |
+
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
1769 |
+
|
1770 |
+
>>> # training
|
1771 |
+
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
|
1772 |
+
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
|
1773 |
+
>>> outputs = model(input_ids=input_ids, labels=labels)
|
1774 |
+
>>> loss = outputs.loss
|
1775 |
+
>>> logits = outputs.logits
|
1776 |
+
|
1777 |
+
>>> # inference
|
1778 |
+
>>> input_ids = tokenizer(
|
1779 |
+
... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
|
1780 |
+
... ).input_ids # Batch size 1
|
1781 |
+
>>> outputs = model.generate(input_ids)
|
1782 |
+
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
1783 |
+
>>> # studies have shown that owning a dog is good for you.
|
1784 |
+
```"""
|
1785 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1786 |
+
return_dict = (
|
1787 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1788 |
+
)
|
1789 |
+
|
1790 |
+
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
1791 |
+
if head_mask is not None and decoder_head_mask is None:
|
1792 |
+
if self.config.num_layers == self.config.num_decoder_layers:
|
1793 |
+
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
|
1794 |
+
decoder_head_mask = head_mask
|
1795 |
+
|
1796 |
+
# Encode if needed (training, first prediction pass)
|
1797 |
+
if encoder_outputs is None:
|
1798 |
+
# Convert encoder inputs in embeddings if needed
|
1799 |
+
encoder_outputs = self.encoder(
|
1800 |
+
input_ids=input_ids,
|
1801 |
+
attention_mask=attention_mask,
|
1802 |
+
inputs_embeds=inputs_embeds,
|
1803 |
+
head_mask=head_mask,
|
1804 |
+
output_attentions=output_attentions,
|
1805 |
+
output_hidden_states=output_hidden_states,
|
1806 |
+
return_dict=return_dict,
|
1807 |
+
)
|
1808 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
1809 |
+
encoder_outputs = BaseModelOutput(
|
1810 |
+
last_hidden_state=encoder_outputs[0],
|
1811 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
1812 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
1813 |
+
)
|
1814 |
+
|
1815 |
+
hidden_states = encoder_outputs[0]
|
1816 |
+
|
1817 |
+
if self.model_parallel:
|
1818 |
+
torch.cuda.set_device(self.decoder.first_device)
|
1819 |
+
|
1820 |
+
if (
|
1821 |
+
labels is not None
|
1822 |
+
and decoder_input_ids is None
|
1823 |
+
and decoder_inputs_embeds is None
|
1824 |
+
):
|
1825 |
+
# get decoder inputs from shifting lm labels to the right
|
1826 |
+
decoder_input_ids = self._shift_right(labels)
|
1827 |
+
|
1828 |
+
# Set device for model parallelism
|
1829 |
+
if self.model_parallel:
|
1830 |
+
torch.cuda.set_device(self.decoder.first_device)
|
1831 |
+
hidden_states = hidden_states.to(self.decoder.first_device)
|
1832 |
+
if decoder_input_ids is not None:
|
1833 |
+
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
|
1834 |
+
if attention_mask is not None:
|
1835 |
+
attention_mask = attention_mask.to(self.decoder.first_device)
|
1836 |
+
if decoder_attention_mask is not None:
|
1837 |
+
decoder_attention_mask = decoder_attention_mask.to(
|
1838 |
+
self.decoder.first_device
|
1839 |
+
)
|
1840 |
+
|
1841 |
+
# Decode
|
1842 |
+
decoder_outputs = self.decoder(
|
1843 |
+
input_ids=decoder_input_ids,
|
1844 |
+
attention_mask=decoder_attention_mask,
|
1845 |
+
inputs_embeds=decoder_inputs_embeds,
|
1846 |
+
past_key_values=past_key_values,
|
1847 |
+
encoder_hidden_states=hidden_states,
|
1848 |
+
encoder_attention_mask=attention_mask,
|
1849 |
+
head_mask=decoder_head_mask,
|
1850 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
1851 |
+
use_cache=use_cache,
|
1852 |
+
output_attentions=output_attentions,
|
1853 |
+
output_hidden_states=output_hidden_states,
|
1854 |
+
return_dict=return_dict,
|
1855 |
+
)
|
1856 |
+
|
1857 |
+
sequence_output = decoder_outputs[0]
|
1858 |
+
|
1859 |
+
# Set device for model parallelism
|
1860 |
+
if self.model_parallel:
|
1861 |
+
torch.cuda.set_device(self.encoder.first_device)
|
1862 |
+
self.lm_head = self.lm_head.to(self.encoder.first_device)
|
1863 |
+
sequence_output = sequence_output.to(self.lm_head.weight.device)
|
1864 |
+
|
1865 |
+
if self.config.tie_word_embeddings:
|
1866 |
+
# Rescale output before projecting on vocab
|
1867 |
+
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
1868 |
+
sequence_output = sequence_output * (self.model_dim**-0.5)
|
1869 |
+
|
1870 |
+
lm_logits = self.lm_head(sequence_output)
|
1871 |
+
|
1872 |
+
loss = None
|
1873 |
+
if labels is not None:
|
1874 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100, reduction=reduction)
|
1875 |
+
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
1876 |
+
if reduction == "none":
|
1877 |
+
loss = loss.view(lm_logits.size(0), -1).sum(1)
|
1878 |
+
|
1879 |
+
if not return_dict:
|
1880 |
+
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
1881 |
+
return ((loss,) + output) if loss is not None else output
|
1882 |
+
|
1883 |
+
return Seq2SeqLMOutput(
|
1884 |
+
loss=loss,
|
1885 |
+
logits=lm_logits,
|
1886 |
+
past_key_values=decoder_outputs.past_key_values,
|
1887 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
1888 |
+
decoder_attentions=decoder_outputs.attentions,
|
1889 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
1890 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
1891 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
1892 |
+
encoder_attentions=encoder_outputs.attentions,
|
1893 |
+
)
|
1894 |
+
|
1895 |
+
def prepare_inputs_for_generation(
|
1896 |
+
self,
|
1897 |
+
input_ids,
|
1898 |
+
past=None,
|
1899 |
+
attention_mask=None,
|
1900 |
+
head_mask=None,
|
1901 |
+
decoder_head_mask=None,
|
1902 |
+
cross_attn_head_mask=None,
|
1903 |
+
use_cache=None,
|
1904 |
+
encoder_outputs=None,
|
1905 |
+
**kwargs,
|
1906 |
+
):
|
1907 |
+
|
1908 |
+
# cut decoder_input_ids if past is used
|
1909 |
+
if past is not None:
|
1910 |
+
input_ids = input_ids[:, -1:]
|
1911 |
+
|
1912 |
+
return {
|
1913 |
+
"decoder_input_ids": input_ids,
|
1914 |
+
"past_key_values": past,
|
1915 |
+
"encoder_outputs": encoder_outputs,
|
1916 |
+
"attention_mask": attention_mask,
|
1917 |
+
"head_mask": head_mask,
|
1918 |
+
"decoder_head_mask": decoder_head_mask,
|
1919 |
+
"cross_attn_head_mask": cross_attn_head_mask,
|
1920 |
+
"use_cache": use_cache,
|
1921 |
+
}
|
1922 |
+
|
1923 |
+
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
1924 |
+
return self._shift_right(labels)
|
1925 |
+
|
1926 |
+
def _reorder_cache(self, past, beam_idx):
|
1927 |
+
# if decoder past is not included in output
|
1928 |
+
# speedy decoding is disabled and no need to reorder
|
1929 |
+
if past is None:
|
1930 |
+
logger.warning(
|
1931 |
+
"You might want to consider setting `use_cache=True` to speed up decoding"
|
1932 |
+
)
|
1933 |
+
return past
|
1934 |
+
|
1935 |
+
reordered_decoder_past = ()
|
1936 |
+
for layer_past_states in past:
|
1937 |
+
# get the correct batch idx from layer past batch dim
|
1938 |
+
# batch dim of `past` is at 2nd position
|
1939 |
+
reordered_layer_past_states = ()
|
1940 |
+
for layer_past_state in layer_past_states:
|
1941 |
+
# need to set correct `past` for each of the four key / value states
|
1942 |
+
reordered_layer_past_states = reordered_layer_past_states + (
|
1943 |
+
layer_past_state.index_select(
|
1944 |
+
0, beam_idx.to(layer_past_state.device)
|
1945 |
+
),
|
1946 |
+
)
|
1947 |
+
|
1948 |
+
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
|
1949 |
+
assert len(reordered_layer_past_states) == len(layer_past_states)
|
1950 |
+
|
1951 |
+
reordered_decoder_past = reordered_decoder_past + (
|
1952 |
+
reordered_layer_past_states,
|
1953 |
+
)
|
1954 |
+
return reordered_decoder_past
|
1955 |
+
|
1956 |
+
|
1957 |
+
@add_start_docstrings(
|
1958 |
+
"The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
|
1959 |
+
T5_START_DOCSTRING,
|
1960 |
+
)
|
1961 |
+
class T5EncoderModel(T5PreTrainedModel):
|
1962 |
+
authorized_missing_keys = [
|
1963 |
+
r"encoder.embed_tokens.weight",
|
1964 |
+
]
|
1965 |
+
|
1966 |
+
def __init__(self, config: T5Config):
|
1967 |
+
super().__init__(config)
|
1968 |
+
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
1969 |
+
|
1970 |
+
encoder_config = copy.deepcopy(config)
|
1971 |
+
encoder_config.use_cache = False
|
1972 |
+
encoder_config.is_encoder_decoder = False
|
1973 |
+
self.encoder = T5Stack(encoder_config, self.shared)
|
1974 |
+
|
1975 |
+
# Initialize weights and apply final processing
|
1976 |
+
self.post_init()
|
1977 |
+
|
1978 |
+
# Model parallel
|
1979 |
+
self.model_parallel = False
|
1980 |
+
self.device_map = None
|
1981 |
+
|
1982 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
1983 |
+
def parallelize(self, device_map=None):
|
1984 |
+
self.device_map = (
|
1985 |
+
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
|
1986 |
+
if device_map is None
|
1987 |
+
else device_map
|
1988 |
+
)
|
1989 |
+
assert_device_map(self.device_map, len(self.encoder.block))
|
1990 |
+
self.encoder.parallelize(self.device_map)
|
1991 |
+
self.model_parallel = True
|
1992 |
+
|
1993 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
1994 |
+
def deparallelize(self):
|
1995 |
+
self.encoder.deparallelize()
|
1996 |
+
self.encoder = self.encoder.to("cpu")
|
1997 |
+
self.model_parallel = False
|
1998 |
+
self.device_map = None
|
1999 |
+
torch.cuda.empty_cache()
|
2000 |
+
|
2001 |
+
def get_input_embeddings(self):
|
2002 |
+
return self.shared
|
2003 |
+
|
2004 |
+
def set_input_embeddings(self, new_embeddings):
|
2005 |
+
self.shared = new_embeddings
|
2006 |
+
self.encoder.set_input_embeddings(new_embeddings)
|
2007 |
+
|
2008 |
+
def get_encoder(self):
|
2009 |
+
return self.encoder
|
2010 |
+
|
2011 |
+
def _prune_heads(self, heads_to_prune):
|
2012 |
+
"""
|
2013 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
2014 |
+
class PreTrainedModel
|
2015 |
+
"""
|
2016 |
+
for layer, heads in heads_to_prune.items():
|
2017 |
+
self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
|
2018 |
+
|
2019 |
+
@add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
|
2020 |
+
@replace_return_docstrings(
|
2021 |
+
output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC
|
2022 |
+
)
|
2023 |
+
def forward(
|
2024 |
+
self,
|
2025 |
+
input_ids: Optional[torch.LongTensor] = None,
|
2026 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
2027 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
2028 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
2029 |
+
output_attentions: Optional[bool] = None,
|
2030 |
+
output_hidden_states: Optional[bool] = None,
|
2031 |
+
return_dict: Optional[bool] = None,
|
2032 |
+
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
|
2033 |
+
r"""
|
2034 |
+
Returns:
|
2035 |
+
|
2036 |
+
Example:
|
2037 |
+
|
2038 |
+
```python
|
2039 |
+
>>> from transformers import T5Tokenizer, T5EncoderModel
|
2040 |
+
|
2041 |
+
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
2042 |
+
>>> model = T5EncoderModel.from_pretrained("t5-small")
|
2043 |
+
>>> input_ids = tokenizer(
|
2044 |
+
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
|
2045 |
+
... ).input_ids # Batch size 1
|
2046 |
+
>>> outputs = model(input_ids=input_ids)
|
2047 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
2048 |
+
```"""
|
2049 |
+
return_dict = (
|
2050 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
2051 |
+
)
|
2052 |
+
|
2053 |
+
encoder_outputs = self.encoder(
|
2054 |
+
input_ids=input_ids,
|
2055 |
+
attention_mask=attention_mask,
|
2056 |
+
inputs_embeds=inputs_embeds,
|
2057 |
+
head_mask=head_mask,
|
2058 |
+
output_attentions=output_attentions,
|
2059 |
+
output_hidden_states=output_hidden_states,
|
2060 |
+
return_dict=return_dict,
|
2061 |
+
)
|
2062 |
+
|
2063 |
+
return encoder_outputs
|
bliva/models/vit.py
ADDED
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
|
7 |
+
Based on timm code base
|
8 |
+
https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
9 |
+
"""
|
10 |
+
|
11 |
+
import math
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from functools import partial
|
16 |
+
|
17 |
+
from timm.models.vision_transformer import _cfg, PatchEmbed
|
18 |
+
from timm.models.registry import register_model
|
19 |
+
from timm.models.layers import trunc_normal_, DropPath
|
20 |
+
from timm.models.helpers import named_apply, adapt_input_conv
|
21 |
+
|
22 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
23 |
+
from bliva.models.base_model import BaseEncoder
|
24 |
+
|
25 |
+
|
26 |
+
class Mlp(nn.Module):
|
27 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
in_features,
|
32 |
+
hidden_features=None,
|
33 |
+
out_features=None,
|
34 |
+
act_layer=nn.GELU,
|
35 |
+
drop=0.0,
|
36 |
+
):
|
37 |
+
super().__init__()
|
38 |
+
out_features = out_features or in_features
|
39 |
+
hidden_features = hidden_features or in_features
|
40 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
41 |
+
self.act = act_layer()
|
42 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
43 |
+
self.drop = nn.Dropout(drop)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
x = self.fc1(x)
|
47 |
+
x = self.act(x)
|
48 |
+
x = self.drop(x)
|
49 |
+
x = self.fc2(x)
|
50 |
+
x = self.drop(x)
|
51 |
+
return x
|
52 |
+
|
53 |
+
|
54 |
+
class Attention(nn.Module):
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
dim,
|
58 |
+
num_heads=8,
|
59 |
+
qkv_bias=False,
|
60 |
+
qk_scale=None,
|
61 |
+
attn_drop=0.0,
|
62 |
+
proj_drop=0.0,
|
63 |
+
):
|
64 |
+
super().__init__()
|
65 |
+
self.num_heads = num_heads
|
66 |
+
head_dim = dim // num_heads
|
67 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
68 |
+
self.scale = qk_scale or head_dim**-0.5
|
69 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
70 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
71 |
+
self.proj = nn.Linear(dim, dim)
|
72 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
73 |
+
self.attn_gradients = None
|
74 |
+
self.attention_map = None
|
75 |
+
|
76 |
+
def save_attn_gradients(self, attn_gradients):
|
77 |
+
self.attn_gradients = attn_gradients
|
78 |
+
|
79 |
+
def get_attn_gradients(self):
|
80 |
+
return self.attn_gradients
|
81 |
+
|
82 |
+
def save_attention_map(self, attention_map):
|
83 |
+
self.attention_map = attention_map
|
84 |
+
|
85 |
+
def get_attention_map(self):
|
86 |
+
return self.attention_map
|
87 |
+
|
88 |
+
def forward(self, x, register_hook=False):
|
89 |
+
B, N, C = x.shape
|
90 |
+
qkv = (
|
91 |
+
self.qkv(x)
|
92 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
93 |
+
.permute(2, 0, 3, 1, 4)
|
94 |
+
)
|
95 |
+
q, k, v = (
|
96 |
+
qkv[0],
|
97 |
+
qkv[1],
|
98 |
+
qkv[2],
|
99 |
+
) # make torchscript happy (cannot use tensor as tuple)
|
100 |
+
|
101 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
102 |
+
attn = attn.softmax(dim=-1)
|
103 |
+
attn = self.attn_drop(attn)
|
104 |
+
|
105 |
+
if register_hook:
|
106 |
+
self.save_attention_map(attn)
|
107 |
+
attn.register_hook(self.save_attn_gradients)
|
108 |
+
|
109 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
110 |
+
x = self.proj(x)
|
111 |
+
x = self.proj_drop(x)
|
112 |
+
return x
|
113 |
+
|
114 |
+
|
115 |
+
class Block(nn.Module):
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
dim,
|
119 |
+
num_heads,
|
120 |
+
mlp_ratio=4.0,
|
121 |
+
qkv_bias=False,
|
122 |
+
qk_scale=None,
|
123 |
+
drop=0.0,
|
124 |
+
attn_drop=0.0,
|
125 |
+
drop_path=0.0,
|
126 |
+
act_layer=nn.GELU,
|
127 |
+
norm_layer=nn.LayerNorm,
|
128 |
+
use_grad_checkpointing=False,
|
129 |
+
):
|
130 |
+
super().__init__()
|
131 |
+
self.norm1 = norm_layer(dim)
|
132 |
+
self.attn = Attention(
|
133 |
+
dim,
|
134 |
+
num_heads=num_heads,
|
135 |
+
qkv_bias=qkv_bias,
|
136 |
+
qk_scale=qk_scale,
|
137 |
+
attn_drop=attn_drop,
|
138 |
+
proj_drop=drop,
|
139 |
+
)
|
140 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
141 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
142 |
+
self.norm2 = norm_layer(dim)
|
143 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
144 |
+
self.mlp = Mlp(
|
145 |
+
in_features=dim,
|
146 |
+
hidden_features=mlp_hidden_dim,
|
147 |
+
act_layer=act_layer,
|
148 |
+
drop=drop,
|
149 |
+
)
|
150 |
+
|
151 |
+
if use_grad_checkpointing:
|
152 |
+
self.attn = checkpoint_wrapper(self.attn)
|
153 |
+
self.mlp = checkpoint_wrapper(self.mlp)
|
154 |
+
|
155 |
+
def forward(self, x, register_hook=False):
|
156 |
+
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
157 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
158 |
+
return x
|
159 |
+
|
160 |
+
|
161 |
+
class VisionTransformer(nn.Module):
|
162 |
+
"""Vision Transformer
|
163 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
164 |
+
https://arxiv.org/abs/2010.11929
|
165 |
+
"""
|
166 |
+
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
img_size=224,
|
170 |
+
patch_size=16,
|
171 |
+
in_chans=3,
|
172 |
+
num_classes=1000,
|
173 |
+
embed_dim=768,
|
174 |
+
depth=12,
|
175 |
+
num_heads=12,
|
176 |
+
mlp_ratio=4.0,
|
177 |
+
qkv_bias=True,
|
178 |
+
qk_scale=None,
|
179 |
+
representation_size=None,
|
180 |
+
drop_rate=0.0,
|
181 |
+
attn_drop_rate=0.0,
|
182 |
+
drop_path_rate=0.0,
|
183 |
+
norm_layer=None,
|
184 |
+
use_grad_checkpointing=False,
|
185 |
+
ckpt_layer=0,
|
186 |
+
):
|
187 |
+
"""
|
188 |
+
Args:
|
189 |
+
img_size (int, tuple): input image size
|
190 |
+
patch_size (int, tuple): patch size
|
191 |
+
in_chans (int): number of input channels
|
192 |
+
num_classes (int): number of classes for classification head
|
193 |
+
embed_dim (int): embedding dimension
|
194 |
+
depth (int): depth of transformer
|
195 |
+
num_heads (int): number of attention heads
|
196 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
197 |
+
qkv_bias (bool): enable bias for qkv if True
|
198 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
199 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
200 |
+
drop_rate (float): dropout rate
|
201 |
+
attn_drop_rate (float): attention dropout rate
|
202 |
+
drop_path_rate (float): stochastic depth rate
|
203 |
+
norm_layer: (nn.Module): normalization layer
|
204 |
+
"""
|
205 |
+
super().__init__()
|
206 |
+
self.num_features = (
|
207 |
+
self.embed_dim
|
208 |
+
) = embed_dim # num_features for consistency with other models
|
209 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
210 |
+
|
211 |
+
self.patch_embed = PatchEmbed(
|
212 |
+
img_size=img_size,
|
213 |
+
patch_size=patch_size,
|
214 |
+
in_chans=in_chans,
|
215 |
+
embed_dim=embed_dim,
|
216 |
+
)
|
217 |
+
|
218 |
+
num_patches = self.patch_embed.num_patches
|
219 |
+
|
220 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
221 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
222 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
223 |
+
|
224 |
+
dpr = [
|
225 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
226 |
+
] # stochastic depth decay rule
|
227 |
+
self.blocks = nn.ModuleList(
|
228 |
+
[
|
229 |
+
Block(
|
230 |
+
dim=embed_dim,
|
231 |
+
num_heads=num_heads,
|
232 |
+
mlp_ratio=mlp_ratio,
|
233 |
+
qkv_bias=qkv_bias,
|
234 |
+
qk_scale=qk_scale,
|
235 |
+
drop=drop_rate,
|
236 |
+
attn_drop=attn_drop_rate,
|
237 |
+
drop_path=dpr[i],
|
238 |
+
norm_layer=norm_layer,
|
239 |
+
use_grad_checkpointing=(
|
240 |
+
use_grad_checkpointing and i >= depth - ckpt_layer
|
241 |
+
),
|
242 |
+
)
|
243 |
+
for i in range(depth)
|
244 |
+
]
|
245 |
+
)
|
246 |
+
self.norm = norm_layer(embed_dim)
|
247 |
+
|
248 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
249 |
+
trunc_normal_(self.cls_token, std=0.02)
|
250 |
+
self.apply(self._init_weights)
|
251 |
+
|
252 |
+
def _init_weights(self, m):
|
253 |
+
if isinstance(m, nn.Linear):
|
254 |
+
trunc_normal_(m.weight, std=0.02)
|
255 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
256 |
+
nn.init.constant_(m.bias, 0)
|
257 |
+
elif isinstance(m, nn.LayerNorm):
|
258 |
+
nn.init.constant_(m.bias, 0)
|
259 |
+
nn.init.constant_(m.weight, 1.0)
|
260 |
+
|
261 |
+
@torch.jit.ignore
|
262 |
+
def no_weight_decay(self):
|
263 |
+
return {"pos_embed", "cls_token"}
|
264 |
+
|
265 |
+
def forward(self, x, register_blk=-1):
|
266 |
+
B = x.shape[0]
|
267 |
+
x = self.patch_embed(x)
|
268 |
+
|
269 |
+
cls_tokens = self.cls_token.expand(
|
270 |
+
B, -1, -1
|
271 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
272 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
273 |
+
|
274 |
+
x = x + self.pos_embed[:, : x.size(1), :]
|
275 |
+
x = self.pos_drop(x)
|
276 |
+
|
277 |
+
for i, blk in enumerate(self.blocks):
|
278 |
+
x = blk(x, register_blk == i)
|
279 |
+
x = self.norm(x)
|
280 |
+
|
281 |
+
return x
|
282 |
+
|
283 |
+
@torch.jit.ignore()
|
284 |
+
def load_pretrained(self, checkpoint_path, prefix=""):
|
285 |
+
_load_weights(self, checkpoint_path, prefix)
|
286 |
+
|
287 |
+
|
288 |
+
@torch.no_grad()
|
289 |
+
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""):
|
290 |
+
"""Load weights from .npz checkpoints for official Google Brain Flax implementation"""
|
291 |
+
import numpy as np
|
292 |
+
|
293 |
+
def _n2p(w, t=True):
|
294 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
295 |
+
w = w.flatten()
|
296 |
+
if t:
|
297 |
+
if w.ndim == 4:
|
298 |
+
w = w.transpose([3, 2, 0, 1])
|
299 |
+
elif w.ndim == 3:
|
300 |
+
w = w.transpose([2, 0, 1])
|
301 |
+
elif w.ndim == 2:
|
302 |
+
w = w.transpose([1, 0])
|
303 |
+
return torch.from_numpy(w)
|
304 |
+
|
305 |
+
w = np.load(checkpoint_path)
|
306 |
+
if not prefix and "opt/target/embedding/kernel" in w:
|
307 |
+
prefix = "opt/target/"
|
308 |
+
|
309 |
+
if hasattr(model.patch_embed, "backbone"):
|
310 |
+
# hybrid
|
311 |
+
backbone = model.patch_embed.backbone
|
312 |
+
stem_only = not hasattr(backbone, "stem")
|
313 |
+
stem = backbone if stem_only else backbone.stem
|
314 |
+
stem.conv.weight.copy_(
|
315 |
+
adapt_input_conv(
|
316 |
+
stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"])
|
317 |
+
)
|
318 |
+
)
|
319 |
+
stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"]))
|
320 |
+
stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"]))
|
321 |
+
if not stem_only:
|
322 |
+
for i, stage in enumerate(backbone.stages):
|
323 |
+
for j, block in enumerate(stage.blocks):
|
324 |
+
bp = f"{prefix}block{i + 1}/unit{j + 1}/"
|
325 |
+
for r in range(3):
|
326 |
+
getattr(block, f"conv{r + 1}").weight.copy_(
|
327 |
+
_n2p(w[f"{bp}conv{r + 1}/kernel"])
|
328 |
+
)
|
329 |
+
getattr(block, f"norm{r + 1}").weight.copy_(
|
330 |
+
_n2p(w[f"{bp}gn{r + 1}/scale"])
|
331 |
+
)
|
332 |
+
getattr(block, f"norm{r + 1}").bias.copy_(
|
333 |
+
_n2p(w[f"{bp}gn{r + 1}/bias"])
|
334 |
+
)
|
335 |
+
if block.downsample is not None:
|
336 |
+
block.downsample.conv.weight.copy_(
|
337 |
+
_n2p(w[f"{bp}conv_proj/kernel"])
|
338 |
+
)
|
339 |
+
block.downsample.norm.weight.copy_(
|
340 |
+
_n2p(w[f"{bp}gn_proj/scale"])
|
341 |
+
)
|
342 |
+
block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"]))
|
343 |
+
embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"])
|
344 |
+
else:
|
345 |
+
embed_conv_w = adapt_input_conv(
|
346 |
+
model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"])
|
347 |
+
)
|
348 |
+
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
349 |
+
model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"]))
|
350 |
+
model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False))
|
351 |
+
pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False)
|
352 |
+
if pos_embed_w.shape != model.pos_embed.shape:
|
353 |
+
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
354 |
+
pos_embed_w,
|
355 |
+
model.pos_embed,
|
356 |
+
getattr(model, "num_tokens", 1),
|
357 |
+
model.patch_embed.grid_size,
|
358 |
+
)
|
359 |
+
model.pos_embed.copy_(pos_embed_w)
|
360 |
+
model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"]))
|
361 |
+
model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))
|
362 |
+
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
363 |
+
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
364 |
+
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
365 |
+
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
366 |
+
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
367 |
+
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
368 |
+
for i, block in enumerate(model.blocks.children()):
|
369 |
+
block_prefix = f"{prefix}Transformer/encoderblock_{i}/"
|
370 |
+
mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/"
|
371 |
+
block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"]))
|
372 |
+
block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"]))
|
373 |
+
block.attn.qkv.weight.copy_(
|
374 |
+
torch.cat(
|
375 |
+
[
|
376 |
+
_n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T
|
377 |
+
for n in ("query", "key", "value")
|
378 |
+
]
|
379 |
+
)
|
380 |
+
)
|
381 |
+
block.attn.qkv.bias.copy_(
|
382 |
+
torch.cat(
|
383 |
+
[
|
384 |
+
_n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1)
|
385 |
+
for n in ("query", "key", "value")
|
386 |
+
]
|
387 |
+
)
|
388 |
+
)
|
389 |
+
block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1))
|
390 |
+
block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"]))
|
391 |
+
for r in range(2):
|
392 |
+
getattr(block.mlp, f"fc{r + 1}").weight.copy_(
|
393 |
+
_n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"])
|
394 |
+
)
|
395 |
+
getattr(block.mlp, f"fc{r + 1}").bias.copy_(
|
396 |
+
_n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"])
|
397 |
+
)
|
398 |
+
block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"]))
|
399 |
+
block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"]))
|
400 |
+
|
401 |
+
|
402 |
+
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
403 |
+
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
|
404 |
+
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
405 |
+
print("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape)
|
406 |
+
ntok_new = posemb_new.shape[1]
|
407 |
+
if num_tokens:
|
408 |
+
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
|
409 |
+
ntok_new -= num_tokens
|
410 |
+
else:
|
411 |
+
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
412 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
413 |
+
if not len(gs_new): # backwards compatibility
|
414 |
+
gs_new = [int(math.sqrt(ntok_new))] * 2
|
415 |
+
assert len(gs_new) >= 2
|
416 |
+
print("Position embedding grid-size from %s to %s", [gs_old, gs_old], gs_new)
|
417 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
418 |
+
posemb_grid = F.interpolate(
|
419 |
+
posemb_grid, size=gs_new, mode="bicubic", align_corners=False
|
420 |
+
)
|
421 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
|
422 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
423 |
+
return
|
424 |
+
|
425 |
+
|
426 |
+
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
427 |
+
# interpolate position embedding
|
428 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
429 |
+
num_patches = visual_encoder.patch_embed.num_patches
|
430 |
+
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
431 |
+
# height (== width) for the checkpoint position embedding
|
432 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
433 |
+
# height (== width) for the new position embedding
|
434 |
+
new_size = int(num_patches**0.5)
|
435 |
+
|
436 |
+
if orig_size != new_size:
|
437 |
+
# class_token and dist_token are kept unchanged
|
438 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
439 |
+
# only the position tokens are interpolated
|
440 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
441 |
+
pos_tokens = pos_tokens.reshape(
|
442 |
+
-1, orig_size, orig_size, embedding_size
|
443 |
+
).permute(0, 3, 1, 2)
|
444 |
+
pos_tokens = torch.nn.functional.interpolate(
|
445 |
+
pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
|
446 |
+
)
|
447 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
448 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
449 |
+
print(
|
450 |
+
"reshape position embedding from %d to %d" % (orig_size**2, new_size**2)
|
451 |
+
)
|
452 |
+
|
453 |
+
return new_pos_embed
|
454 |
+
else:
|
455 |
+
return pos_embed_checkpoint
|
456 |
+
|
457 |
+
|
458 |
+
class VisionTransformerEncoder(VisionTransformer, BaseEncoder):
|
459 |
+
@classmethod
|
460 |
+
def from_config(cls, cfg, from_pretrained=False):
|
461 |
+
|
462 |
+
vit_type = cfg.get("vit_type", "base")
|
463 |
+
image_size = cfg.get("image_size", 384)
|
464 |
+
ckpt_layer = cfg.get("vit_ckpt_layer", 0)
|
465 |
+
drop_path_rate = cfg.get("vit_drop_path_rate", 0)
|
466 |
+
norm_layer_eps = cfg.get("vit_layer_norm_epsilon", -1)
|
467 |
+
use_grad_checkpointing = cfg.get("vit_grad_ckpt", False)
|
468 |
+
|
469 |
+
if norm_layer_eps == -1:
|
470 |
+
norm_layer = None
|
471 |
+
else:
|
472 |
+
norm_layer = partial(nn.LayerNorm, eps=norm_layer_eps)
|
473 |
+
|
474 |
+
# norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
475 |
+
assert vit_type in ["base", "large"], "vit parameter must be base or large"
|
476 |
+
if vit_type == "base":
|
477 |
+
vision_width = 768
|
478 |
+
visual_encoder = cls(
|
479 |
+
img_size=image_size,
|
480 |
+
patch_size=16,
|
481 |
+
embed_dim=vision_width,
|
482 |
+
depth=12,
|
483 |
+
num_heads=12,
|
484 |
+
use_grad_checkpointing=use_grad_checkpointing,
|
485 |
+
ckpt_layer=ckpt_layer,
|
486 |
+
drop_path_rate=0 or drop_path_rate,
|
487 |
+
norm_layer=norm_layer,
|
488 |
+
)
|
489 |
+
|
490 |
+
if from_pretrained:
|
491 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
492 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
|
493 |
+
map_location="cpu",
|
494 |
+
check_hash=True,
|
495 |
+
)
|
496 |
+
state_dict = checkpoint["model"]
|
497 |
+
state_dict["pos_embed"] = interpolate_pos_embed(
|
498 |
+
state_dict["pos_embed"], visual_encoder
|
499 |
+
)
|
500 |
+
msg = visual_encoder.load_state_dict(state_dict, strict=False)
|
501 |
+
|
502 |
+
elif vit_type == "large":
|
503 |
+
vision_width = 1024
|
504 |
+
visual_encoder = cls(
|
505 |
+
img_size=image_size,
|
506 |
+
patch_size=16,
|
507 |
+
embed_dim=vision_width,
|
508 |
+
depth=24,
|
509 |
+
num_heads=16,
|
510 |
+
use_grad_checkpointing=use_grad_checkpointing,
|
511 |
+
ckpt_layer=ckpt_layer,
|
512 |
+
drop_path_rate=0.1 or drop_path_rate,
|
513 |
+
norm_layer=norm_layer,
|
514 |
+
)
|
515 |
+
if from_pretrained:
|
516 |
+
from timm.models.helpers import load_custom_pretrained
|
517 |
+
from timm.models.vision_transformer import default_cfgs
|
518 |
+
|
519 |
+
load_custom_pretrained(
|
520 |
+
visual_encoder, default_cfgs["vit_large_patch16_224_in21k"]
|
521 |
+
)
|
522 |
+
|
523 |
+
visual_encoder.vision_width = vision_width
|
524 |
+
return visual_encoder
|
525 |
+
|
526 |
+
def forward_features(self, x, register_blk=-1):
|
527 |
+
return super().forward(x, register_blk)
|
bliva/processors/__init__.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
from bliva.processors.base_processor import BaseProcessor
|
9 |
+
|
10 |
+
from bliva.processors.blip_processors import (
|
11 |
+
BlipImageTrainProcessor,
|
12 |
+
Blip2ImageTrainProcessor,
|
13 |
+
BlipImageEvalProcessor,
|
14 |
+
BlipCaptionProcessor,
|
15 |
+
)
|
16 |
+
from bliva.processors.clip_processors import ClipImageTrainProcessor
|
17 |
+
|
18 |
+
from bliva.common.registry import registry
|
19 |
+
|
20 |
+
__all__ = [
|
21 |
+
"BaseProcessor",
|
22 |
+
"BlipImageTrainProcessor",
|
23 |
+
"Blip2ImageTrainProcessor",
|
24 |
+
"BlipImageEvalProcessor",
|
25 |
+
"BlipCaptionProcessor",
|
26 |
+
"ClipImageTrainProcessor",
|
27 |
+
]
|
28 |
+
|
29 |
+
|
30 |
+
def load_processor(name, cfg=None):
|
31 |
+
"""
|
32 |
+
Example
|
33 |
+
|
34 |
+
>>> processor = load_processor("alpro_video_train", cfg=None)
|
35 |
+
"""
|
36 |
+
processor = registry.get_processor_class(name).from_config(cfg)
|
37 |
+
|
38 |
+
return processor
|
bliva/processors/base_processor.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
|
11 |
+
class BaseProcessor:
|
12 |
+
def __init__(self):
|
13 |
+
self.transform = lambda x: x
|
14 |
+
return
|
15 |
+
|
16 |
+
def __call__(self, item):
|
17 |
+
return self.transform(item)
|
18 |
+
|
19 |
+
@classmethod
|
20 |
+
def from_config(cls, cfg=None):
|
21 |
+
return cls()
|
22 |
+
|
23 |
+
def build(self, **kwargs):
|
24 |
+
cfg = OmegaConf.create(kwargs)
|
25 |
+
|
26 |
+
return self.from_config(cfg)
|
bliva/processors/blip_processors.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import re
|
9 |
+
|
10 |
+
from bliva.common.registry import registry
|
11 |
+
from bliva.processors.base_processor import BaseProcessor
|
12 |
+
from bliva.processors.randaugment import RandomAugment
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
from torchvision import transforms
|
15 |
+
from torchvision.transforms.functional import InterpolationMode
|
16 |
+
|
17 |
+
|
18 |
+
class BlipImageBaseProcessor(BaseProcessor):
|
19 |
+
def __init__(self, mean=None, std=None):
|
20 |
+
if mean is None:
|
21 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
22 |
+
if std is None:
|
23 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
24 |
+
|
25 |
+
self.normalize = transforms.Normalize(mean, std)
|
26 |
+
|
27 |
+
|
28 |
+
@registry.register_processor("blip_caption")
|
29 |
+
class BlipCaptionProcessor(BaseProcessor):
|
30 |
+
def __init__(self, prompt="", max_words=50):
|
31 |
+
self.prompt = prompt
|
32 |
+
self.max_words = max_words
|
33 |
+
|
34 |
+
def __call__(self, caption):
|
35 |
+
caption = self.prompt + self.pre_caption(caption)
|
36 |
+
|
37 |
+
return caption
|
38 |
+
|
39 |
+
@classmethod
|
40 |
+
def from_config(cls, cfg=None):
|
41 |
+
if cfg is None:
|
42 |
+
cfg = OmegaConf.create()
|
43 |
+
|
44 |
+
prompt = cfg.get("prompt", "")
|
45 |
+
max_words = cfg.get("max_words", 50)
|
46 |
+
|
47 |
+
return cls(prompt=prompt, max_words=max_words)
|
48 |
+
|
49 |
+
def pre_caption(self, caption):
|
50 |
+
caption = re.sub(
|
51 |
+
r"([.!\"()*#:;~])",
|
52 |
+
" ",
|
53 |
+
caption.lower(),
|
54 |
+
)
|
55 |
+
caption = re.sub(
|
56 |
+
r"\s{2,}",
|
57 |
+
" ",
|
58 |
+
caption,
|
59 |
+
)
|
60 |
+
caption = caption.rstrip("\n")
|
61 |
+
caption = caption.strip(" ")
|
62 |
+
|
63 |
+
# truncate caption
|
64 |
+
caption_words = caption.split(" ")
|
65 |
+
if len(caption_words) > self.max_words:
|
66 |
+
caption = " ".join(caption_words[: self.max_words])
|
67 |
+
|
68 |
+
return caption
|
69 |
+
|
70 |
+
|
71 |
+
@registry.register_processor("blip_question")
|
72 |
+
class BlipQuestionProcessor(BaseProcessor):
|
73 |
+
def __init__(self, max_words=50):
|
74 |
+
self.max_words = max_words
|
75 |
+
|
76 |
+
def __call__(self, question):
|
77 |
+
return self.pre_question(question)
|
78 |
+
|
79 |
+
@classmethod
|
80 |
+
def from_config(cls, cfg=None):
|
81 |
+
if cfg is None:
|
82 |
+
cfg = OmegaConf.create()
|
83 |
+
|
84 |
+
max_words = cfg.get("max_words", 50)
|
85 |
+
|
86 |
+
return cls(max_words=max_words)
|
87 |
+
|
88 |
+
def pre_question(self, question):
|
89 |
+
question = re.sub(
|
90 |
+
r"([.!\"()*#:;~])",
|
91 |
+
"",
|
92 |
+
question.lower(),
|
93 |
+
)
|
94 |
+
question = question.rstrip(" ")
|
95 |
+
|
96 |
+
# truncate question
|
97 |
+
question_words = question.split(" ")
|
98 |
+
if len(question_words) > self.max_words:
|
99 |
+
question = " ".join(question_words[: self.max_words])
|
100 |
+
|
101 |
+
return question
|
102 |
+
|
103 |
+
|
104 |
+
@registry.register_processor("blip_image_train")
|
105 |
+
class BlipImageTrainProcessor(BlipImageBaseProcessor):
|
106 |
+
def __init__(
|
107 |
+
self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
|
108 |
+
):
|
109 |
+
super().__init__(mean=mean, std=std)
|
110 |
+
|
111 |
+
self.transform = transforms.Compose(
|
112 |
+
[
|
113 |
+
transforms.RandomResizedCrop(
|
114 |
+
image_size,
|
115 |
+
scale=(min_scale, max_scale),
|
116 |
+
interpolation=InterpolationMode.BICUBIC,
|
117 |
+
),
|
118 |
+
transforms.RandomHorizontalFlip(),
|
119 |
+
# RandomAugment(
|
120 |
+
# 2,
|
121 |
+
# 5,
|
122 |
+
# isPIL=True,
|
123 |
+
# augs=[
|
124 |
+
# "Identity",
|
125 |
+
# "AutoContrast",
|
126 |
+
# "Brightness",
|
127 |
+
# "Sharpness",
|
128 |
+
# "Equalize",
|
129 |
+
# "ShearX",
|
130 |
+
# "ShearY",
|
131 |
+
# "TranslateX",
|
132 |
+
# "TranslateY",
|
133 |
+
# "Rotate",
|
134 |
+
# ],
|
135 |
+
# ),
|
136 |
+
transforms.ToTensor(),
|
137 |
+
self.normalize,
|
138 |
+
]
|
139 |
+
)
|
140 |
+
|
141 |
+
def __call__(self, item):
|
142 |
+
return self.transform(item)
|
143 |
+
|
144 |
+
@classmethod
|
145 |
+
def from_config(cls, cfg=None):
|
146 |
+
if cfg is None:
|
147 |
+
cfg = OmegaConf.create()
|
148 |
+
|
149 |
+
image_size = cfg.get("image_size", 384)
|
150 |
+
|
151 |
+
mean = cfg.get("mean", None)
|
152 |
+
std = cfg.get("std", None)
|
153 |
+
|
154 |
+
min_scale = cfg.get("min_scale", 0.5)
|
155 |
+
max_scale = cfg.get("max_scale", 1.0)
|
156 |
+
|
157 |
+
return cls(
|
158 |
+
image_size=image_size,
|
159 |
+
mean=mean,
|
160 |
+
std=std,
|
161 |
+
min_scale=min_scale,
|
162 |
+
max_scale=max_scale,
|
163 |
+
)
|
164 |
+
|
165 |
+
|
166 |
+
@registry.register_processor("blip_image_eval")
|
167 |
+
class BlipImageEvalProcessor(BlipImageBaseProcessor):
|
168 |
+
def __init__(self, image_size=384, mean=None, std=None):
|
169 |
+
super().__init__(mean=mean, std=std)
|
170 |
+
|
171 |
+
self.transform = transforms.Compose(
|
172 |
+
[
|
173 |
+
transforms.Resize(
|
174 |
+
(image_size, image_size), interpolation=InterpolationMode.BICUBIC
|
175 |
+
),
|
176 |
+
transforms.ToTensor(),
|
177 |
+
self.normalize,
|
178 |
+
]
|
179 |
+
)
|
180 |
+
|
181 |
+
def __call__(self, item):
|
182 |
+
return self.transform(item)
|
183 |
+
|
184 |
+
@classmethod
|
185 |
+
def from_config(cls, cfg=None):
|
186 |
+
if cfg is None:
|
187 |
+
cfg = OmegaConf.create()
|
188 |
+
|
189 |
+
image_size = cfg.get("image_size", 384)
|
190 |
+
|
191 |
+
mean = cfg.get("mean", None)
|
192 |
+
std = cfg.get("std", None)
|
193 |
+
|
194 |
+
return cls(image_size=image_size, mean=mean, std=std)
|
195 |
+
|
196 |
+
|
197 |
+
@registry.register_processor("blip2_image_train")
|
198 |
+
class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
|
199 |
+
def __init__(
|
200 |
+
self, image_size=364, mean=None, std=None, min_scale=0.5, max_scale=1.0
|
201 |
+
):
|
202 |
+
super().__init__(mean=mean, std=std)
|
203 |
+
|
204 |
+
self.transform = transforms.Compose(
|
205 |
+
[
|
206 |
+
transforms.RandomResizedCrop(
|
207 |
+
image_size,
|
208 |
+
scale=(min_scale, max_scale),
|
209 |
+
interpolation=InterpolationMode.BICUBIC,
|
210 |
+
),
|
211 |
+
transforms.RandomHorizontalFlip(),
|
212 |
+
transforms.ToTensor(),
|
213 |
+
self.normalize,
|
214 |
+
]
|
215 |
+
)
|
216 |
+
|
217 |
+
def __call__(self, item):
|
218 |
+
return self.transform(item)
|
219 |
+
|
220 |
+
@classmethod
|
221 |
+
def from_config(cls, cfg=None):
|
222 |
+
if cfg is None:
|
223 |
+
cfg = OmegaConf.create()
|
224 |
+
|
225 |
+
image_size = cfg.get("image_size", 364)
|
226 |
+
|
227 |
+
mean = cfg.get("mean", None)
|
228 |
+
std = cfg.get("std", None)
|
229 |
+
|
230 |
+
min_scale = cfg.get("min_scale", 0.5)
|
231 |
+
max_scale = cfg.get("max_scale", 1.0)
|
232 |
+
|
233 |
+
return cls(
|
234 |
+
image_size=image_size,
|
235 |
+
mean=mean,
|
236 |
+
std=std,
|
237 |
+
min_scale=min_scale,
|
238 |
+
max_scale=max_scale,
|
239 |
+
)
|
bliva/processors/clip_processors.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
from bliva.common.registry import registry
|
9 |
+
from bliva.processors.blip_processors import BlipImageBaseProcessor
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
from torchvision import transforms
|
12 |
+
from torchvision.transforms.functional import InterpolationMode
|
13 |
+
|
14 |
+
|
15 |
+
def _convert_to_rgb(image):
|
16 |
+
return image.convert("RGB")
|
17 |
+
|
18 |
+
|
19 |
+
@registry.register_processor("clip_image_train")
|
20 |
+
class ClipImageTrainProcessor(BlipImageBaseProcessor):
|
21 |
+
def __init__(
|
22 |
+
self, image_size=224, mean=None, std=None, min_scale=0.9, max_scale=1.0
|
23 |
+
):
|
24 |
+
|
25 |
+
super().__init__(mean=mean, std=std)
|
26 |
+
|
27 |
+
self.transform = transforms.Compose(
|
28 |
+
[
|
29 |
+
transforms.RandomResizedCrop(
|
30 |
+
image_size,
|
31 |
+
scale=(min_scale, max_scale),
|
32 |
+
interpolation=InterpolationMode.BICUBIC,
|
33 |
+
),
|
34 |
+
_convert_to_rgb,
|
35 |
+
transforms.ToTensor(),
|
36 |
+
self.normalize,
|
37 |
+
]
|
38 |
+
)
|
39 |
+
|
40 |
+
@classmethod
|
41 |
+
def from_config(cls, cfg=None):
|
42 |
+
if cfg is None:
|
43 |
+
cfg = OmegaConf.create()
|
44 |
+
|
45 |
+
image_size = cfg.get("image_size", 224)
|
46 |
+
|
47 |
+
mean = cfg.get("mean", None)
|
48 |
+
std = cfg.get("std", None)
|
49 |
+
|
50 |
+
min_scale = cfg.get("min_scale", 0.9)
|
51 |
+
max_scale = cfg.get("max_scale", 1.0)
|
52 |
+
|
53 |
+
return cls(
|
54 |
+
image_size=image_size,
|
55 |
+
mean=mean,
|
56 |
+
std=std,
|
57 |
+
min_scale=min_scale,
|
58 |
+
max_scale=max_scale,
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
@registry.register_processor("clip_image_eval")
|
63 |
+
class ClipImageEvalProcessor(BlipImageBaseProcessor):
|
64 |
+
def __init__(self, image_size=224, mean=None, std=None):
|
65 |
+
|
66 |
+
super().__init__(mean=mean, std=std)
|
67 |
+
|
68 |
+
self.transform = transforms.Compose(
|
69 |
+
[
|
70 |
+
transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
71 |
+
transforms.CenterCrop(image_size),
|
72 |
+
_convert_to_rgb,
|
73 |
+
transforms.ToTensor(),
|
74 |
+
self.normalize,
|
75 |
+
]
|
76 |
+
)
|
77 |
+
|
78 |
+
@classmethod
|
79 |
+
def from_config(cls, cfg=None):
|
80 |
+
if cfg is None:
|
81 |
+
cfg = OmegaConf.create()
|
82 |
+
|
83 |
+
image_size = cfg.get("image_size", 224)
|
84 |
+
|
85 |
+
mean = cfg.get("mean", None)
|
86 |
+
std = cfg.get("std", None)
|
87 |
+
|
88 |
+
return cls(
|
89 |
+
image_size=image_size,
|
90 |
+
mean=mean,
|
91 |
+
std=std,
|
92 |
+
)
|
bliva/processors/randaugment.py
ADDED
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
## aug functions
|
15 |
+
def identity_func(img):
|
16 |
+
return img
|
17 |
+
|
18 |
+
|
19 |
+
def autocontrast_func(img, cutoff=0):
|
20 |
+
"""
|
21 |
+
same output as PIL.ImageOps.autocontrast
|
22 |
+
"""
|
23 |
+
n_bins = 256
|
24 |
+
|
25 |
+
def tune_channel(ch):
|
26 |
+
n = ch.size
|
27 |
+
cut = cutoff * n // 100
|
28 |
+
if cut == 0:
|
29 |
+
high, low = ch.max(), ch.min()
|
30 |
+
else:
|
31 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
32 |
+
low = np.argwhere(np.cumsum(hist) > cut)
|
33 |
+
low = 0 if low.shape[0] == 0 else low[0]
|
34 |
+
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
|
35 |
+
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
|
36 |
+
if high <= low:
|
37 |
+
table = np.arange(n_bins)
|
38 |
+
else:
|
39 |
+
scale = (n_bins - 1) / (high - low)
|
40 |
+
offset = -low * scale
|
41 |
+
table = np.arange(n_bins) * scale + offset
|
42 |
+
table[table < 0] = 0
|
43 |
+
table[table > n_bins - 1] = n_bins - 1
|
44 |
+
table = table.clip(0, 255).astype(np.uint8)
|
45 |
+
return table[ch]
|
46 |
+
|
47 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
48 |
+
out = cv2.merge(channels)
|
49 |
+
return out
|
50 |
+
|
51 |
+
|
52 |
+
def equalize_func(img):
|
53 |
+
"""
|
54 |
+
same output as PIL.ImageOps.equalize
|
55 |
+
PIL's implementation is different from cv2.equalize
|
56 |
+
"""
|
57 |
+
n_bins = 256
|
58 |
+
|
59 |
+
def tune_channel(ch):
|
60 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
61 |
+
non_zero_hist = hist[hist != 0].reshape(-1)
|
62 |
+
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
|
63 |
+
if step == 0:
|
64 |
+
return ch
|
65 |
+
n = np.empty_like(hist)
|
66 |
+
n[0] = step // 2
|
67 |
+
n[1:] = hist[:-1]
|
68 |
+
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
|
69 |
+
return table[ch]
|
70 |
+
|
71 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
72 |
+
out = cv2.merge(channels)
|
73 |
+
return out
|
74 |
+
|
75 |
+
|
76 |
+
def rotate_func(img, degree, fill=(0, 0, 0)):
|
77 |
+
"""
|
78 |
+
like PIL, rotate by degree, not radians
|
79 |
+
"""
|
80 |
+
H, W = img.shape[0], img.shape[1]
|
81 |
+
center = W / 2, H / 2
|
82 |
+
M = cv2.getRotationMatrix2D(center, degree, 1)
|
83 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
|
84 |
+
return out
|
85 |
+
|
86 |
+
|
87 |
+
def solarize_func(img, thresh=128):
|
88 |
+
"""
|
89 |
+
same output as PIL.ImageOps.posterize
|
90 |
+
"""
|
91 |
+
table = np.array([el if el < thresh else 255 - el for el in range(256)])
|
92 |
+
table = table.clip(0, 255).astype(np.uint8)
|
93 |
+
out = table[img]
|
94 |
+
return out
|
95 |
+
|
96 |
+
|
97 |
+
def color_func(img, factor):
|
98 |
+
"""
|
99 |
+
same output as PIL.ImageEnhance.Color
|
100 |
+
"""
|
101 |
+
## implementation according to PIL definition, quite slow
|
102 |
+
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
|
103 |
+
# out = blend(degenerate, img, factor)
|
104 |
+
# M = (
|
105 |
+
# np.eye(3) * factor
|
106 |
+
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
|
107 |
+
# )[np.newaxis, np.newaxis, :]
|
108 |
+
M = np.float32(
|
109 |
+
[[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
|
110 |
+
) * factor + np.float32([[0.114], [0.587], [0.299]])
|
111 |
+
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
|
112 |
+
return out
|
113 |
+
|
114 |
+
|
115 |
+
def contrast_func(img, factor):
|
116 |
+
"""
|
117 |
+
same output as PIL.ImageEnhance.Contrast
|
118 |
+
"""
|
119 |
+
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
|
120 |
+
table = (
|
121 |
+
np.array([(el - mean) * factor + mean for el in range(256)])
|
122 |
+
.clip(0, 255)
|
123 |
+
.astype(np.uint8)
|
124 |
+
)
|
125 |
+
out = table[img]
|
126 |
+
return out
|
127 |
+
|
128 |
+
|
129 |
+
def brightness_func(img, factor):
|
130 |
+
"""
|
131 |
+
same output as PIL.ImageEnhance.Contrast
|
132 |
+
"""
|
133 |
+
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
|
134 |
+
out = table[img]
|
135 |
+
return out
|
136 |
+
|
137 |
+
|
138 |
+
def sharpness_func(img, factor):
|
139 |
+
"""
|
140 |
+
The differences the this result and PIL are all on the 4 boundaries, the center
|
141 |
+
areas are same
|
142 |
+
"""
|
143 |
+
kernel = np.ones((3, 3), dtype=np.float32)
|
144 |
+
kernel[1][1] = 5
|
145 |
+
kernel /= 13
|
146 |
+
degenerate = cv2.filter2D(img, -1, kernel)
|
147 |
+
if factor == 0.0:
|
148 |
+
out = degenerate
|
149 |
+
elif factor == 1.0:
|
150 |
+
out = img
|
151 |
+
else:
|
152 |
+
out = img.astype(np.float32)
|
153 |
+
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
|
154 |
+
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
|
155 |
+
out = out.astype(np.uint8)
|
156 |
+
return out
|
157 |
+
|
158 |
+
|
159 |
+
def shear_x_func(img, factor, fill=(0, 0, 0)):
|
160 |
+
H, W = img.shape[0], img.shape[1]
|
161 |
+
M = np.float32([[1, factor, 0], [0, 1, 0]])
|
162 |
+
out = cv2.warpAffine(
|
163 |
+
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
|
164 |
+
).astype(np.uint8)
|
165 |
+
return out
|
166 |
+
|
167 |
+
|
168 |
+
def translate_x_func(img, offset, fill=(0, 0, 0)):
|
169 |
+
"""
|
170 |
+
same output as PIL.Image.transform
|
171 |
+
"""
|
172 |
+
H, W = img.shape[0], img.shape[1]
|
173 |
+
M = np.float32([[1, 0, -offset], [0, 1, 0]])
|
174 |
+
out = cv2.warpAffine(
|
175 |
+
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
|
176 |
+
).astype(np.uint8)
|
177 |
+
return out
|
178 |
+
|
179 |
+
|
180 |
+
def translate_y_func(img, offset, fill=(0, 0, 0)):
|
181 |
+
"""
|
182 |
+
same output as PIL.Image.transform
|
183 |
+
"""
|
184 |
+
H, W = img.shape[0], img.shape[1]
|
185 |
+
M = np.float32([[1, 0, 0], [0, 1, -offset]])
|
186 |
+
out = cv2.warpAffine(
|
187 |
+
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
|
188 |
+
).astype(np.uint8)
|
189 |
+
return out
|
190 |
+
|
191 |
+
|
192 |
+
def posterize_func(img, bits):
|
193 |
+
"""
|
194 |
+
same output as PIL.ImageOps.posterize
|
195 |
+
"""
|
196 |
+
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
|
197 |
+
return out
|
198 |
+
|
199 |
+
|
200 |
+
def shear_y_func(img, factor, fill=(0, 0, 0)):
|
201 |
+
H, W = img.shape[0], img.shape[1]
|
202 |
+
M = np.float32([[1, 0, 0], [factor, 1, 0]])
|
203 |
+
out = cv2.warpAffine(
|
204 |
+
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
|
205 |
+
).astype(np.uint8)
|
206 |
+
return out
|
207 |
+
|
208 |
+
|
209 |
+
def cutout_func(img, pad_size, replace=(0, 0, 0)):
|
210 |
+
replace = np.array(replace, dtype=np.uint8)
|
211 |
+
H, W = img.shape[0], img.shape[1]
|
212 |
+
rh, rw = np.random.random(2)
|
213 |
+
pad_size = pad_size // 2
|
214 |
+
ch, cw = int(rh * H), int(rw * W)
|
215 |
+
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
|
216 |
+
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
|
217 |
+
out = img.copy()
|
218 |
+
out[x1:x2, y1:y2, :] = replace
|
219 |
+
return out
|
220 |
+
|
221 |
+
|
222 |
+
### level to args
|
223 |
+
def enhance_level_to_args(MAX_LEVEL):
|
224 |
+
def level_to_args(level):
|
225 |
+
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
|
226 |
+
|
227 |
+
return level_to_args
|
228 |
+
|
229 |
+
|
230 |
+
def shear_level_to_args(MAX_LEVEL, replace_value):
|
231 |
+
def level_to_args(level):
|
232 |
+
level = (level / MAX_LEVEL) * 0.3
|
233 |
+
if np.random.random() > 0.5:
|
234 |
+
level = -level
|
235 |
+
return (level, replace_value)
|
236 |
+
|
237 |
+
return level_to_args
|
238 |
+
|
239 |
+
|
240 |
+
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
|
241 |
+
def level_to_args(level):
|
242 |
+
level = (level / MAX_LEVEL) * float(translate_const)
|
243 |
+
if np.random.random() > 0.5:
|
244 |
+
level = -level
|
245 |
+
return (level, replace_value)
|
246 |
+
|
247 |
+
return level_to_args
|
248 |
+
|
249 |
+
|
250 |
+
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
|
251 |
+
def level_to_args(level):
|
252 |
+
level = int((level / MAX_LEVEL) * cutout_const)
|
253 |
+
return (level, replace_value)
|
254 |
+
|
255 |
+
return level_to_args
|
256 |
+
|
257 |
+
|
258 |
+
def solarize_level_to_args(MAX_LEVEL):
|
259 |
+
def level_to_args(level):
|
260 |
+
level = int((level / MAX_LEVEL) * 256)
|
261 |
+
return (level,)
|
262 |
+
|
263 |
+
return level_to_args
|
264 |
+
|
265 |
+
|
266 |
+
def none_level_to_args(level):
|
267 |
+
return ()
|
268 |
+
|
269 |
+
|
270 |
+
def posterize_level_to_args(MAX_LEVEL):
|
271 |
+
def level_to_args(level):
|
272 |
+
level = int((level / MAX_LEVEL) * 4)
|
273 |
+
return (level,)
|
274 |
+
|
275 |
+
return level_to_args
|
276 |
+
|
277 |
+
|
278 |
+
def rotate_level_to_args(MAX_LEVEL, replace_value):
|
279 |
+
def level_to_args(level):
|
280 |
+
level = (level / MAX_LEVEL) * 30
|
281 |
+
if np.random.random() < 0.5:
|
282 |
+
level = -level
|
283 |
+
return (level, replace_value)
|
284 |
+
|
285 |
+
return level_to_args
|
286 |
+
|
287 |
+
|
288 |
+
func_dict = {
|
289 |
+
"Identity": identity_func,
|
290 |
+
"AutoContrast": autocontrast_func,
|
291 |
+
"Equalize": equalize_func,
|
292 |
+
"Rotate": rotate_func,
|
293 |
+
"Solarize": solarize_func,
|
294 |
+
"Color": color_func,
|
295 |
+
"Contrast": contrast_func,
|
296 |
+
"Brightness": brightness_func,
|
297 |
+
"Sharpness": sharpness_func,
|
298 |
+
"ShearX": shear_x_func,
|
299 |
+
"TranslateX": translate_x_func,
|
300 |
+
"TranslateY": translate_y_func,
|
301 |
+
"Posterize": posterize_func,
|
302 |
+
"ShearY": shear_y_func,
|
303 |
+
}
|
304 |
+
|
305 |
+
translate_const = 10
|
306 |
+
MAX_LEVEL = 10
|
307 |
+
replace_value = (128, 128, 128)
|
308 |
+
arg_dict = {
|
309 |
+
"Identity": none_level_to_args,
|
310 |
+
"AutoContrast": none_level_to_args,
|
311 |
+
"Equalize": none_level_to_args,
|
312 |
+
"Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
|
313 |
+
"Solarize": solarize_level_to_args(MAX_LEVEL),
|
314 |
+
"Color": enhance_level_to_args(MAX_LEVEL),
|
315 |
+
"Contrast": enhance_level_to_args(MAX_LEVEL),
|
316 |
+
"Brightness": enhance_level_to_args(MAX_LEVEL),
|
317 |
+
"Sharpness": enhance_level_to_args(MAX_LEVEL),
|
318 |
+
"ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
|
319 |
+
"TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
|
320 |
+
"TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
|
321 |
+
"Posterize": posterize_level_to_args(MAX_LEVEL),
|
322 |
+
"ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
|
323 |
+
}
|
324 |
+
|
325 |
+
|
326 |
+
class RandomAugment(object):
|
327 |
+
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
|
328 |
+
self.N = N
|
329 |
+
self.M = M
|
330 |
+
self.isPIL = isPIL
|
331 |
+
if augs:
|
332 |
+
self.augs = augs
|
333 |
+
else:
|
334 |
+
self.augs = list(arg_dict.keys())
|
335 |
+
|
336 |
+
def get_random_ops(self):
|
337 |
+
sampled_ops = np.random.choice(self.augs, self.N)
|
338 |
+
return [(op, 0.5, self.M) for op in sampled_ops]
|
339 |
+
|
340 |
+
def __call__(self, img):
|
341 |
+
if self.isPIL:
|
342 |
+
img = np.array(img)
|
343 |
+
ops = self.get_random_ops()
|
344 |
+
for name, prob, level in ops:
|
345 |
+
if np.random.random() > prob:
|
346 |
+
continue
|
347 |
+
args = arg_dict[name](level)
|
348 |
+
img = func_dict[name](img, *args)
|
349 |
+
return img
|
350 |
+
|
351 |
+
|
352 |
+
class VideoRandomAugment(object):
|
353 |
+
def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
|
354 |
+
self.N = N
|
355 |
+
self.M = M
|
356 |
+
self.p = p
|
357 |
+
self.tensor_in_tensor_out = tensor_in_tensor_out
|
358 |
+
if augs:
|
359 |
+
self.augs = augs
|
360 |
+
else:
|
361 |
+
self.augs = list(arg_dict.keys())
|
362 |
+
|
363 |
+
def get_random_ops(self):
|
364 |
+
sampled_ops = np.random.choice(self.augs, self.N, replace=False)
|
365 |
+
return [(op, self.M) for op in sampled_ops]
|
366 |
+
|
367 |
+
def __call__(self, frames):
|
368 |
+
assert (
|
369 |
+
frames.shape[-1] == 3
|
370 |
+
), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
|
371 |
+
|
372 |
+
if self.tensor_in_tensor_out:
|
373 |
+
frames = frames.numpy().astype(np.uint8)
|
374 |
+
|
375 |
+
num_frames = frames.shape[0]
|
376 |
+
|
377 |
+
ops = num_frames * [self.get_random_ops()]
|
378 |
+
apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
|
379 |
+
|
380 |
+
frames = torch.stack(
|
381 |
+
list(map(self._aug, frames, ops, apply_or_not)), dim=0
|
382 |
+
).float()
|
383 |
+
|
384 |
+
return frames
|
385 |
+
|
386 |
+
def _aug(self, img, ops, apply_or_not):
|
387 |
+
for i, (name, level) in enumerate(ops):
|
388 |
+
if not apply_or_not[i]:
|
389 |
+
continue
|
390 |
+
args = arg_dict[name](level)
|
391 |
+
img = func_dict[name](img, *args)
|
392 |
+
return torch.from_numpy(img)
|
393 |
+
|
394 |
+
|
395 |
+
if __name__ == "__main__":
|
396 |
+
a = RandomAugment()
|
397 |
+
img = np.random.randn(32, 32, 3)
|
398 |
+
a(img)
|
bliva_vicuna7b.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:512b0b5d6d98391e570ff6ee7778d7f39b267650e43098f8e6cc60539d56d037
|
3 |
+
size 15812679883
|
evaluate.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import numpy as np
|
3 |
+
import argparse
|
4 |
+
from PIL import Image
|
5 |
+
from bliva.models import load_model_and_preprocess
|
6 |
+
|
7 |
+
def disable_torch_init():
|
8 |
+
"""
|
9 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
10 |
+
"""
|
11 |
+
import torch
|
12 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
13 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
14 |
+
|
15 |
+
def parse_args():
|
16 |
+
"""
|
17 |
+
Parse arguments from command line.
|
18 |
+
"""
|
19 |
+
parser = argparse.ArgumentParser(description="Arguments for Evaluation")
|
20 |
+
parser.add_argument(
|
21 |
+
"--answer_mc",
|
22 |
+
action="store_true",
|
23 |
+
default=False,
|
24 |
+
help="Whether to evaluate multiple choice question with candidates."
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
"--answer_qs",
|
28 |
+
action="store_true",
|
29 |
+
default=False,
|
30 |
+
help="Whether to evaluate only one question image."
|
31 |
+
)
|
32 |
+
|
33 |
+
parser.add_argument("--model_name", type=str, default="bliva_vicuna")
|
34 |
+
parser.add_argument("--device", type=str, default="cuda:0", help="Specify which gpu device to use.")
|
35 |
+
parser.add_argument("--img_path", type=str, required=True, help="the path to the image")
|
36 |
+
parser.add_argument("--question", type=str, required=True, help="the question to ask")
|
37 |
+
parser.add_argument("--candidates", type=str, help="list of choices for mulitple choice question")
|
38 |
+
|
39 |
+
args = parser.parse_args()
|
40 |
+
return args
|
41 |
+
|
42 |
+
def eval_one(image, question, model):
|
43 |
+
"""
|
44 |
+
Evaluate one question
|
45 |
+
"""
|
46 |
+
outputs = model.generate({"image": image, "prompt": question})
|
47 |
+
print("=====================================")
|
48 |
+
print("Question:", question[0])
|
49 |
+
print("-------------------------------------")
|
50 |
+
print("Outputs: ", outputs[0])
|
51 |
+
|
52 |
+
|
53 |
+
def eval_candidates(image, question, candidates, model):
|
54 |
+
"""
|
55 |
+
Evaluate with candidates
|
56 |
+
"""
|
57 |
+
outputs = model.predict_class({"image": image, "prompt": question}, candidates)
|
58 |
+
print("=====================================")
|
59 |
+
print("Question:", question[0])
|
60 |
+
print("-------------------------------------")
|
61 |
+
print("Candidates:", candidates)
|
62 |
+
print("-------------------------------------")
|
63 |
+
print("Outputs: ", candidates[outputs[0][0]])
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
def main(args):
|
68 |
+
np.random.seed(0)
|
69 |
+
|
70 |
+
disable_torch_init()
|
71 |
+
|
72 |
+
if args.model_name == "bliva_vicuna":
|
73 |
+
model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="vicuna7b", is_eval=True, device=args.device)
|
74 |
+
if args.model_name == "bliva_flant5":
|
75 |
+
model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="flant5xxl", is_eval=True, device=args.device)
|
76 |
+
vis_processor = vis_processors["eval"]
|
77 |
+
|
78 |
+
image = Image.open(args.img_path).convert('RGB')
|
79 |
+
|
80 |
+
question = [args.question]
|
81 |
+
|
82 |
+
image = vis_processor(image).unsqueeze(0).to(args.device)
|
83 |
+
|
84 |
+
if args.answer_qs:
|
85 |
+
eval_one(image, question, model)
|
86 |
+
elif args.answer_mc:
|
87 |
+
candidates = [candidate.strip() for candidate in args.candidates.split(",")]
|
88 |
+
eval_candidates(image, question, candidates, model)
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == "__main__":
|
92 |
+
args = parse_args()
|
93 |
+
main(args)
|
hf_vicuna_7b/config.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "../data/hf_llama_7b/",
|
3 |
+
"architectures": [
|
4 |
+
"LlamaForCausalLM"
|
5 |
+
],
|
6 |
+
"bos_token_id": 1,
|
7 |
+
"eos_token_id": 2,
|
8 |
+
"hidden_act": "silu",
|
9 |
+
"hidden_size": 4096,
|
10 |
+
"initializer_range": 0.02,
|
11 |
+
"intermediate_size": 11008,
|
12 |
+
"max_position_embeddings": 2048,
|
13 |
+
"model_type": "llama",
|
14 |
+
"num_attention_heads": 32,
|
15 |
+
"num_hidden_layers": 32,
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"rms_norm_eps": 1e-06,
|
18 |
+
"tie_word_embeddings": false,
|
19 |
+
"torch_dtype": "float16",
|
20 |
+
"transformers_version": "4.28.1",
|
21 |
+
"use_cache": true,
|
22 |
+
"vocab_size": 32000
|
23 |
+
}
|
hf_vicuna_7b/generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 1,
|
4 |
+
"eos_token_id": 2,
|
5 |
+
"pad_token_id": 0,
|
6 |
+
"transformers_version": "4.28.1"
|
7 |
+
}
|
hf_vicuna_7b/pytorch_model-00001-of-00002.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5ed572be140240b212049a3e791271eed8a04c40bc732c91a4da4b1469db23b1
|
3 |
+
size 9976634558
|
hf_vicuna_7b/pytorch_model-00002-of-00002.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d9382da358f0ec38c4fca3bcf1e3e65274ae4c78090e0775b4bb4dea6a518e08
|
3 |
+
size 3500315539
|
hf_vicuna_7b/pytorch_model.bin.index.json
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 13476839424
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"lm_head.weight": "pytorch_model-00002-of-00002.bin",
|
7 |
+
"model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
|
8 |
+
"model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
9 |
+
"model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
10 |
+
"model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
11 |
+
"model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
12 |
+
"model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
13 |
+
"model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
14 |
+
"model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
15 |
+
"model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
16 |
+
"model.layers.0.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
17 |
+
"model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
18 |
+
"model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
19 |
+
"model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
20 |
+
"model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
21 |
+
"model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
22 |
+
"model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
23 |
+
"model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
24 |
+
"model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
25 |
+
"model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
26 |
+
"model.layers.1.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
27 |
+
"model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
28 |
+
"model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
29 |
+
"model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
30 |
+
"model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
31 |
+
"model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
32 |
+
"model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
33 |
+
"model.layers.10.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
34 |
+
"model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
35 |
+
"model.layers.10.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
36 |
+
"model.layers.10.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
37 |
+
"model.layers.10.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
38 |
+
"model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
39 |
+
"model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
40 |
+
"model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
41 |
+
"model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
42 |
+
"model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
43 |
+
"model.layers.11.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
44 |
+
"model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
45 |
+
"model.layers.11.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
46 |
+
"model.layers.11.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
47 |
+
"model.layers.11.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
48 |
+
"model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
49 |
+
"model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
50 |
+
"model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
51 |
+
"model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
52 |
+
"model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
53 |
+
"model.layers.12.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
54 |
+
"model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
55 |
+
"model.layers.12.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
56 |
+
"model.layers.12.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
57 |
+
"model.layers.12.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
58 |
+
"model.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
59 |
+
"model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
60 |
+
"model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
61 |
+
"model.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
62 |
+
"model.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
63 |
+
"model.layers.13.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
64 |
+
"model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
65 |
+
"model.layers.13.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
66 |
+
"model.layers.13.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
67 |
+
"model.layers.13.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
68 |
+
"model.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
69 |
+
"model.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
70 |
+
"model.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
71 |
+
"model.layers.14.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
72 |
+
"model.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
73 |
+
"model.layers.14.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
74 |
+
"model.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
75 |
+
"model.layers.14.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
76 |
+
"model.layers.14.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
77 |
+
"model.layers.14.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
78 |
+
"model.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
79 |
+
"model.layers.15.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
80 |
+
"model.layers.15.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
81 |
+
"model.layers.15.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
82 |
+
"model.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
83 |
+
"model.layers.15.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
84 |
+
"model.layers.15.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
85 |
+
"model.layers.15.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
86 |
+
"model.layers.15.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
87 |
+
"model.layers.15.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
88 |
+
"model.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
89 |
+
"model.layers.16.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
90 |
+
"model.layers.16.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
91 |
+
"model.layers.16.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
92 |
+
"model.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
93 |
+
"model.layers.16.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
94 |
+
"model.layers.16.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
95 |
+
"model.layers.16.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
96 |
+
"model.layers.16.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
97 |
+
"model.layers.16.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
98 |
+
"model.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
99 |
+
"model.layers.17.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
100 |
+
"model.layers.17.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
101 |
+
"model.layers.17.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
102 |
+
"model.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
103 |
+
"model.layers.17.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
104 |
+
"model.layers.17.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
105 |
+
"model.layers.17.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
106 |
+
"model.layers.17.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
107 |
+
"model.layers.17.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
108 |
+
"model.layers.18.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
109 |
+
"model.layers.18.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
110 |
+
"model.layers.18.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
111 |
+
"model.layers.18.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
112 |
+
"model.layers.18.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
113 |
+
"model.layers.18.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
114 |
+
"model.layers.18.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
115 |
+
"model.layers.18.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
116 |
+
"model.layers.18.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
117 |
+
"model.layers.18.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
118 |
+
"model.layers.19.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
119 |
+
"model.layers.19.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
120 |
+
"model.layers.19.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
121 |
+
"model.layers.19.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
122 |
+
"model.layers.19.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
123 |
+
"model.layers.19.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
124 |
+
"model.layers.19.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
125 |
+
"model.layers.19.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
126 |
+
"model.layers.19.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
127 |
+
"model.layers.19.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
128 |
+
"model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
129 |
+
"model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
130 |
+
"model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
131 |
+
"model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
132 |
+
"model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
133 |
+
"model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
134 |
+
"model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
135 |
+
"model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
136 |
+
"model.layers.2.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
137 |
+
"model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
138 |
+
"model.layers.20.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
139 |
+
"model.layers.20.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
140 |
+
"model.layers.20.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
141 |
+
"model.layers.20.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
142 |
+
"model.layers.20.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
143 |
+
"model.layers.20.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
144 |
+
"model.layers.20.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
145 |
+
"model.layers.20.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
146 |
+
"model.layers.20.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
147 |
+
"model.layers.20.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
148 |
+
"model.layers.21.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
149 |
+
"model.layers.21.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
150 |
+
"model.layers.21.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
151 |
+
"model.layers.21.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
152 |
+
"model.layers.21.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
153 |
+
"model.layers.21.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
154 |
+
"model.layers.21.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
155 |
+
"model.layers.21.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
156 |
+
"model.layers.21.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
157 |
+
"model.layers.21.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
158 |
+
"model.layers.22.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
159 |
+
"model.layers.22.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
160 |
+
"model.layers.22.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
161 |
+
"model.layers.22.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
162 |
+
"model.layers.22.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
163 |
+
"model.layers.22.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
164 |
+
"model.layers.22.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
165 |
+
"model.layers.22.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
166 |
+
"model.layers.22.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
167 |
+
"model.layers.22.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
168 |
+
"model.layers.23.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
169 |
+
"model.layers.23.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
170 |
+
"model.layers.23.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
171 |
+
"model.layers.23.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
172 |
+
"model.layers.23.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
173 |
+
"model.layers.23.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
174 |
+
"model.layers.23.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
175 |
+
"model.layers.23.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
176 |
+
"model.layers.23.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
177 |
+
"model.layers.23.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
178 |
+
"model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
179 |
+
"model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
180 |
+
"model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
181 |
+
"model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
182 |
+
"model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
183 |
+
"model.layers.24.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
184 |
+
"model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
185 |
+
"model.layers.24.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
186 |
+
"model.layers.24.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
187 |
+
"model.layers.24.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
188 |
+
"model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
189 |
+
"model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
190 |
+
"model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
191 |
+
"model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
192 |
+
"model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
193 |
+
"model.layers.25.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
194 |
+
"model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
195 |
+
"model.layers.25.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
196 |
+
"model.layers.25.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
197 |
+
"model.layers.25.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
198 |
+
"model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
199 |
+
"model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
200 |
+
"model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
201 |
+
"model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
202 |
+
"model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
203 |
+
"model.layers.26.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
204 |
+
"model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
205 |
+
"model.layers.26.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
206 |
+
"model.layers.26.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
207 |
+
"model.layers.26.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
208 |
+
"model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
209 |
+
"model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
210 |
+
"model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
211 |
+
"model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
212 |
+
"model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
213 |
+
"model.layers.27.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
214 |
+
"model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
215 |
+
"model.layers.27.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
216 |
+
"model.layers.27.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
217 |
+
"model.layers.27.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
218 |
+
"model.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
219 |
+
"model.layers.28.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
220 |
+
"model.layers.28.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
221 |
+
"model.layers.28.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
222 |
+
"model.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
223 |
+
"model.layers.28.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
224 |
+
"model.layers.28.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
225 |
+
"model.layers.28.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
226 |
+
"model.layers.28.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
227 |
+
"model.layers.28.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
228 |
+
"model.layers.29.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
229 |
+
"model.layers.29.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
230 |
+
"model.layers.29.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
231 |
+
"model.layers.29.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
232 |
+
"model.layers.29.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
233 |
+
"model.layers.29.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
234 |
+
"model.layers.29.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
235 |
+
"model.layers.29.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
236 |
+
"model.layers.29.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
237 |
+
"model.layers.29.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
238 |
+
"model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
239 |
+
"model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
240 |
+
"model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
241 |
+
"model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
242 |
+
"model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
243 |
+
"model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
244 |
+
"model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
245 |
+
"model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
246 |
+
"model.layers.3.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
247 |
+
"model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
248 |
+
"model.layers.30.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
249 |
+
"model.layers.30.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
250 |
+
"model.layers.30.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
251 |
+
"model.layers.30.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
252 |
+
"model.layers.30.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
253 |
+
"model.layers.30.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
254 |
+
"model.layers.30.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
255 |
+
"model.layers.30.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
256 |
+
"model.layers.30.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
257 |
+
"model.layers.30.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
258 |
+
"model.layers.31.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
259 |
+
"model.layers.31.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
260 |
+
"model.layers.31.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
261 |
+
"model.layers.31.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
262 |
+
"model.layers.31.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
263 |
+
"model.layers.31.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
264 |
+
"model.layers.31.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
265 |
+
"model.layers.31.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
266 |
+
"model.layers.31.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
267 |
+
"model.layers.31.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
268 |
+
"model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
269 |
+
"model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
270 |
+
"model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
271 |
+
"model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
272 |
+
"model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
273 |
+
"model.layers.4.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
274 |
+
"model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
275 |
+
"model.layers.4.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
276 |
+
"model.layers.4.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
277 |
+
"model.layers.4.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
278 |
+
"model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
279 |
+
"model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
280 |
+
"model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
281 |
+
"model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
282 |
+
"model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
283 |
+
"model.layers.5.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
284 |
+
"model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
285 |
+
"model.layers.5.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
286 |
+
"model.layers.5.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
287 |
+
"model.layers.5.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
288 |
+
"model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
289 |
+
"model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
290 |
+
"model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
291 |
+
"model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
292 |
+
"model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
293 |
+
"model.layers.6.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
294 |
+
"model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
295 |
+
"model.layers.6.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
296 |
+
"model.layers.6.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
297 |
+
"model.layers.6.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
298 |
+
"model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
299 |
+
"model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
300 |
+
"model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
301 |
+
"model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
302 |
+
"model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
303 |
+
"model.layers.7.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
304 |
+
"model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
305 |
+
"model.layers.7.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
306 |
+
"model.layers.7.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
307 |
+
"model.layers.7.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
308 |
+
"model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
309 |
+
"model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
310 |
+
"model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
311 |
+
"model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
312 |
+
"model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
313 |
+
"model.layers.8.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
314 |
+
"model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
315 |
+
"model.layers.8.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
316 |
+
"model.layers.8.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
317 |
+
"model.layers.8.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
318 |
+
"model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
319 |
+
"model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
320 |
+
"model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
321 |
+
"model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
322 |
+
"model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
323 |
+
"model.layers.9.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
324 |
+
"model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
325 |
+
"model.layers.9.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
326 |
+
"model.layers.9.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
327 |
+
"model.layers.9.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
328 |
+
"model.norm.weight": "pytorch_model-00002-of-00002.bin"
|
329 |
+
}
|
330 |
+
}
|
hf_vicuna_7b/special_tokens_map.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": true,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "</s>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": true,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"unk_token": {
|
17 |
+
"content": "<unk>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": true,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
}
|
23 |
+
}
|
hf_vicuna_7b/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
3 |
+
size 499723
|
hf_vicuna_7b/tokenizer_config.json
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": true,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"bos_token": {
|
5 |
+
"__type": "AddedToken",
|
6 |
+
"content": "<s>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": true,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false
|
11 |
+
},
|
12 |
+
"clean_up_tokenization_spaces": false,
|
13 |
+
"eos_token": {
|
14 |
+
"__type": "AddedToken",
|
15 |
+
"content": "</s>",
|
16 |
+
"lstrip": false,
|
17 |
+
"normalized": true,
|
18 |
+
"rstrip": false,
|
19 |
+
"single_word": false
|
20 |
+
},
|
21 |
+
"model_max_length": 1000000000000000019884624838656,
|
22 |
+
"pad_token": null,
|
23 |
+
"sp_model_kwargs": {},
|
24 |
+
"tokenizer_class": "LlamaTokenizer",
|
25 |
+
"unk_token": {
|
26 |
+
"__type": "AddedToken",
|
27 |
+
"content": "<unk>",
|
28 |
+
"lstrip": false,
|
29 |
+
"normalized": true,
|
30 |
+
"rstrip": false,
|
31 |
+
"single_word": false
|
32 |
+
}
|
33 |
+
}
|
images/example.jpg
ADDED
images/img1.jpg
ADDED
images/img2.jpg
ADDED
images/img3.jpg
ADDED
images/img4.jpg
ADDED
images/img5.jpg
ADDED
images/img6.jpg
ADDED