Upload 235 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- mllm/__init__.py +0 -0
- mllm/__pycache__/__init__.cpython-310.pyc +0 -0
- mllm/config/__init__.py +1 -0
- mllm/config/__pycache__/__init__.cpython-310.pyc +0 -0
- mllm/config/__pycache__/config.cpython-310.pyc +0 -0
- mllm/config/config.py +135 -0
- mllm/conversation/__init__.py +1 -0
- mllm/conversation/__pycache__/__init__.cpython-310.pyc +0 -0
- mllm/conversation/__pycache__/base_conversation.cpython-310.pyc +0 -0
- mllm/conversation/base_conversation.py +503 -0
- mllm/dataset/__init__.py +7 -0
- mllm/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
- mllm/dataset/__pycache__/builder.cpython-310.pyc +0 -0
- mllm/dataset/__pycache__/root.cpython-310.pyc +0 -0
- mllm/dataset/__pycache__/single_image_convsation.cpython-310.pyc +0 -0
- mllm/dataset/__pycache__/single_image_interactive.cpython-310.pyc +0 -0
- mllm/dataset/builder.py +118 -0
- mllm/dataset/process_function/__init__.py +13 -0
- mllm/dataset/process_function/__pycache__/__init__.cpython-310.pyc +0 -0
- mllm/dataset/process_function/__pycache__/box_process_function.cpython-310.pyc +0 -0
- mllm/dataset/process_function/__pycache__/shikra_process_function.cpython-310.pyc +0 -0
- mllm/dataset/process_function/box_process_function.py +326 -0
- mllm/dataset/process_function/shikra_process_function.py +178 -0
- mllm/dataset/root.py +67 -0
- mllm/dataset/single_image_convsation.py +284 -0
- mllm/dataset/single_image_dataset/__init__.py +13 -0
- mllm/dataset/single_image_dataset/__pycache__/__init__.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/caption.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/clevr.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/flickr.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/gpt_gen.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/gqa.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/instr.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/point_qa.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/pope.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/rec.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/reg.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/vcr.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/vqaex.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/__pycache__/vqav2.cpython-310.pyc +0 -0
- mllm/dataset/single_image_dataset/caption.py +31 -0
- mllm/dataset/single_image_dataset/clevr.py +116 -0
- mllm/dataset/single_image_dataset/flickr.py +68 -0
- mllm/dataset/single_image_dataset/gpt_gen.py +58 -0
- mllm/dataset/single_image_dataset/gqa.py +233 -0
- mllm/dataset/single_image_dataset/instr.py +24 -0
- mllm/dataset/single_image_dataset/point_qa.py +247 -0
- mllm/dataset/single_image_dataset/pope.py +36 -0
- mllm/dataset/single_image_dataset/rec.py +128 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+mllm/demo/assets/baseball.png filter=lfs diff=lfs merge=lfs -text
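For context, the added line is what `git lfs track "mllm/demo/assets/baseball.png"` would append to `.gitattributes` (assuming Git LFS is already initialized in the repository); it routes the binary demo asset through the LFS filter instead of storing it directly in git history.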
mllm/__init__.py
ADDED
File without changes
mllm/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (137 Bytes)
mllm/config/__init__.py
ADDED
@@ -0,0 +1 @@
from .config import prepare_args
mllm/config/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (187 Bytes)
mllm/config/__pycache__/config.cpython-310.pyc
ADDED
Binary file (4.24 kB)
mllm/config/config.py
ADDED
@@ -0,0 +1,135 @@
import os
import sys
import logging
import argparse
from dataclasses import dataclass, field
from typing import List, Tuple
from argparse import SUPPRESS

import datasets
import transformers
from mmengine.config import Config, DictAction
from transformers import HfArgumentParser, set_seed, add_start_docstrings
from transformers import Seq2SeqTrainingArguments as HFSeq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint, is_main_process

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout), ],
)


@dataclass
@add_start_docstrings(HFSeq2SeqTrainingArguments.__doc__)
class Seq2SeqTrainingArguments(HFSeq2SeqTrainingArguments):
    do_multi_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the multi-test set."})


def prepare_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
             'in xxx=yyy format will be merged into config file. If the value to '
             'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
             'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
             'Note that the quotation marks are necessary and that no white space '
             'is allowed.')

    hf_parser = HfArgumentParser((Seq2SeqTrainingArguments,))
    hf_parser, required = block_required_error(hf_parser)

    args, unknown_args = parser.parse_known_args(args)
    known_hf_args, unknown_args = hf_parser.parse_known_args(unknown_args)
    if unknown_args:
        raise ValueError(f"Some specified arguments are not used "
                         f"by the ArgumentParser or HfArgumentParser\n: {unknown_args}")

    # load 'cfg' and 'training_args' from file and cli
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    training_args = cfg.training_args
    training_args.update(vars(known_hf_args))

    # check that all required training_args are assigned
    req_but_not_assign = [item for item in required if item not in training_args]
    if req_but_not_assign:
        raise ValueError(f"Requires {req_but_not_assign} but not assigned.")

    # update cfg.training_args
    cfg.training_args = training_args

    # initialize and return
    training_args = Seq2SeqTrainingArguments(**training_args)
    training_args = check_output_dir(training_args)

    # logging
    if is_main_process(training_args.local_rank):
        to_logging_cfg = Config()
        to_logging_cfg.model_args = cfg.model_args
        to_logging_cfg.data_args = cfg.data_args
        to_logging_cfg.training_args = cfg.training_args
        logger.info(to_logging_cfg.pretty_text)

    # setup logger
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.logging.set_verbosity_info()
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.logging.set_verbosity(log_level)
    transformers.logging.enable_default_handler()
    transformers.logging.enable_explicit_format()
    # setup_print_for_distributed(is_main_process(training_args))

    # Log on each process the small summary:
    logger.info(f"Training/evaluation parameters {training_args}")
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        + f" distributed training: {bool(training_args.local_rank != -1)}, fp16 training: {training_args.fp16}"
    )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    return cfg, training_args


def block_required_error(hf_parser: HfArgumentParser) -> Tuple[HfArgumentParser, List]:
    required = []
    # noinspection PyProtectedMember
    for action in hf_parser._actions:
        if action.required:
            required.append(action.dest)
        action.required = False
        action.default = SUPPRESS
    return hf_parser, required


def check_output_dir(training_args):
    # Detecting last checkpoint.
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    return training_args


if __name__ == "__main__":
    _ = prepare_args()
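As a rough usage sketch of `prepare_args`: it layers an mmengine config file, `--cfg-options` overrides, and any recognized HuggingFace `Seq2SeqTrainingArguments` flags (whose argparse defaults are suppressed by `block_required_error`, so only explicitly passed flags override the config). The config path and override values below are hypothetical, not taken from this commit:

from mllm.config import prepare_args

# 'configs/example.py' is a placeholder for a real mmengine config defining
# model_args, data_args and a training_args dict; '--output_dir' satisfies the
# one required Seq2SeqTrainingArguments field if the config does not set it.
cfg, training_args = prepare_args([
    'configs/example.py',
    '--cfg-options', 'training_args.num_train_epochs=3',
    '--output_dir', 'exp/demo',
])
print(training_args.output_dir, training_args.num_train_epochs)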
mllm/conversation/__init__.py
ADDED
@@ -0,0 +1 @@
from .base_conversation import SeparatorStyle, Conversation, register_conv_template, get_conv_template
mllm/conversation/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (290 Bytes)
mllm/conversation/__pycache__/base_conversation.cpython-310.pyc
ADDED
Binary file (11.4 kB)
mllm/conversation/base_conversation.py
ADDED
@@ -0,0 +1,503 @@
# copy from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""
Conversation prompt templates.
"""

import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any, Dict


class SeparatorStyle(Enum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_SPACE_TWO = auto()
    NO_COLON_SINGLE = auto()
    BAIZE = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    NEW_LINE = auto()
    BILLA = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    # The name of this template
    name: str
    # System prompts
    system: str
    # Two roles
    roles: List[str]
    # All messages
    messages: List[List[str]]
    # Offset of few shot examples
    offset: int
    # Separators
    sep_style: SeparatorStyle
    sep: str
    sep2: str = None
    # Stop criteria (the default one is EOS token)
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    # Used for the state in the gradio servers.
    # TODO(lmzheng): move this out of this class.
    conv_id: Any = None
    skip_next: bool = False
    model_name: str = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_SPACE_TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + " " + message + seps[i % 2]
                else:
                    ret += role + ""
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.BAIZE:
            ret = self.system + "\n"
            for role, message in self.messages:
                if message:
                    ret += role + message + "\n"
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    if i % 2 == 1:
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.RWKV:
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += (
                        role
                        + ": "
                        + message.replace("\r\n", "\n").replace("\n\n", "\n")
                    )
                    ret += "\n\n"
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.PHOENIX:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + "<s>" + message + "</s>"
                else:
                    ret += role + ": " + "<s>"
            return ret
        elif self.sep_style == SeparatorStyle.NEW_LINE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.BILLA:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "  # must end with a space
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Convert the history to gradio chatbot format"""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        ret = [{"role": "system", "content": self.system}]

        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        return Conversation(
            name=self.name,
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
            conv_id=self.conv_id,
            model_name=self.model_name,
        )

    def dict(self):
        return {
            "name": self.name,
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "conv_id": self.conv_id,
            "model_name": self.model_name,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert template.name not in conv_templates, f"{template.name} has been registered."
    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


# A template with one conversation example
register_conv_template(
    Conversation(
        name="one_shot",
        system="A chat between a curious human and an artificial intelligence assistant. "
               "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("Human", "Assistant"),
        messages=(
            (
                "Human",
                "What are the key differences between renewable and non-renewable energy sources?",
            ),
            (
                "Assistant",
                "Renewable energy sources are those that can be replenished naturally in a relatively "
                "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
                "Non-renewable energy sources, on the other hand, are finite and will eventually be "
                "depleted, such as coal, oil, and natural gas. Here are some key differences between "
                "renewable and non-renewable energy sources:\n"
                "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
                "energy sources are finite and will eventually run out.\n"
                "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
                "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
                "and other negative effects.\n"
                "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
                "have lower operational costs than non-renewable sources.\n"
                "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
                "locations than non-renewable sources.\n"
                "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
                "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
                "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
                "non-renewable sources are not, and their depletion can lead to economic and social instability.",
            ),
        ),
        offset=2,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n### ",
        stop_str="###",
    )
)

# Vicuna v1.1 template
register_conv_template(
    Conversation(
        name="vicuna_v1.1",
        system="A chat between a curious user and an artificial intelligence assistant. "
               "The assistant gives helpful, detailed, and polite answers to the user's questions.",
        roles=("USER", "ASSISTANT"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep=" ",
        sep2="</s>",
    )
)

# Koala default template
register_conv_template(
    Conversation(
        name="koala_v1",
        system="BEGINNING OF CONVERSATION:",
        roles=("USER", "GPT"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep=" ",
        sep2="</s>",
    )
)

# Dolly V2 default template
register_conv_template(
    Conversation(
        name="dolly_v2",
        system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
        roles=("### Instruction", "### Response"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DOLLY,
        sep="\n\n",
        sep2="### End",
    )
)

# OpenAssistant Pythia default template
register_conv_template(
    Conversation(
        name="oasst_pythia",
        system="",
        roles=("<|prompter|>", "<|assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep="<|endoftext|>",
    )
)

# StableLM Alpha default template
register_conv_template(
    Conversation(
        name="stablelm",
        system="""<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""",
        roles=("<|USER|>", "<|ASSISTANT|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep="",
        stop_token_ids=[50278, 50279, 50277, 1, 0],
    )
)

# Baize default template
register_conv_template(
    Conversation(
        name="baize",
        system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.",
        roles=("[|Human|]", "[|AI|]"),
        messages=(
            ("[|Human|]", "Hello!"),
            ("[|AI|]", "Hi!"),
        ),
        offset=2,
        sep_style=SeparatorStyle.BAIZE,
        sep="[|Human|]",
        stop_str="[|Human|]",
    )
)

# RWKV-4-Raven default template
register_conv_template(
    Conversation(
        name="rwkv",
        system="The following is a coherent verbose detailed conversation between Bob and Alice.\n\n",
        roles=("Bob", "Alice"),
        messages=(
            ("Bob", "Hi"),
            (
                "Alice",
                "Hi. I am your assistant and I will answer all questions. Please feel free to ask any question and I will always answer it.",
            ),
        ),
        offset=2,
        sep_style=SeparatorStyle.RWKV,
        sep="",
        stop_str="\n\n",
    )
)

# Buddy default template
register_conv_template(
    Conversation(
        name="openbuddy",
        system="""Consider a conversation between User (a human) and Assistant (named Buddy).
Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
Buddy cannot access the Internet.
Buddy can fluently speak the user's language (e.g. English, Chinese).
Buddy can generate poems, stories, code, essays, songs, parodies, and more.
Buddy possesses vast knowledge about the world, history, and culture.
Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.

User: Hi.
Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""",
        roles=("User", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n",
    )
)

# Phoenix default template
register_conv_template(
    Conversation(
        name="phoenix",
        system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.PHOENIX,
        sep="</s>",
    )
)

# ChatGPT default template
register_conv_template(
    Conversation(
        name="chatgpt",
        system="You are a helpful assistant.",
        roles=("user", "assistant"),
        messages=(),
        offset=0,
        sep_style=None,
        sep=None,
    )
)

# Claude default template
register_conv_template(
    Conversation(
        name="claude",
        system="",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n\n",
    )
)

# MPT default template
register_conv_template(
    Conversation(
        name="mpt",
        system="""<|im_start|>system
- You are a helpful assistant chatbot trained by MosaicML.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.
""",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NEW_LINE,
        sep="<|im_end|>",
        stop_token_ids=[50278, 0],
    )
)

# Bard default template
# Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150
# https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40
register_conv_template(
    Conversation(
        name="bard",
        system="",
        roles=("0", "1"),
        messages=(),
        offset=0,
        sep_style=None,
        sep=None,
    )
)

# BiLLa default template
register_conv_template(
    Conversation(
        name="billa",
        system="",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.BILLA,
        sep="\n",
        stop_str="Human:",
    )
)

# custom otter template
register_conv_template(
    Conversation(
        name='otter',
        system='',
        roles=('User:', 'GPT:<answer>'),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_SPACE_TWO,
        sep=' ',
        sep2='<|endofchunk|>',
    )
)

if __name__ == "__main__":
    conv = get_conv_template("vicuna_v1.1")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi!")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())
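For reference, running the `__main__` demo above with the `vicuna_v1.1` template (ADD_COLON_TWO style, sep=" ", sep2="</s>") should print a single prompt string of this shape:

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hello! ASSISTANT: Hi!</s>USER: How are you? ASSISTANT:

Note the trailing "ASSISTANT:" left open by appending a None message; that is the generation cue for the model.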
mllm/dataset/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .root import *
from .utils import *
from .process_function import *
from .single_image_convsation import *
from .single_image_dataset import *

from .builder import prepare_data
mllm/dataset/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (322 Bytes)
mllm/dataset/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (2.96 kB)
mllm/dataset/__pycache__/root.cpython-310.pyc
ADDED
Binary file (2.42 kB)
mllm/dataset/__pycache__/single_image_convsation.cpython-310.pyc
ADDED
Binary file (11 kB)
mllm/dataset/__pycache__/single_image_interactive.cpython-310.pyc
ADDED
Binary file (4.15 kB)
mllm/dataset/builder.py
ADDED
@@ -0,0 +1,118 @@
from functools import partial
from typing import Callable, Dict, Tuple, Any, Optional

from torch.utils.data import Dataset
from transformers import EvalPrediction, TrainingArguments

from .root import DATASETS, METRICS, TRANSFORMS, FUNCTIONS
from .single_image_convsation import SingleImageConvDataset
from .single_image_interactive import SingleImageInteractive
from ..conversation import get_conv_template
from .utils import init_ceph_client_if_needed

DatasetDict = Dict[str, Dataset]
ComputeMetrics = Callable[[EvalPrediction], Dict]


def prepare_data(
        data_args,
        model_args,
        training_args: TrainingArguments,
        preprocessor: Dict[str, Any],
) -> Tuple[DatasetDict, Optional[ComputeMetrics]]:
    # raw dataset
    datasets = {
        'train': partial(DATASETS.build, data_args.train) if training_args.do_train else None,
        'validation': partial(DATASETS.build, data_args.validation) if training_args.do_eval else None,
        'test': partial(DATASETS.build, data_args.test) if training_args.do_predict else None,
    }
    # compute metric
    compute_metric_cfg = data_args.get('compute_metric', None)
    compute_metrics = build_compute_metric(compute_metric_cfg, preprocessor)
    # conv dataset wrap
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    conv_template = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=conv_template)
    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)
    # process func
    process_func = {}
    for k, v in model_args.process_func_args.items():
        process_func[k] = FUNCTIONS.build(cfg=v)

    conv_dataset_cls = partial(
        SingleImageConvDataset,
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=training_args,
        transforms=transforms,
    )
    ds = {
        'train': conv_dataset_cls(dataset_generator=datasets['train'], mode='train') if datasets['train'] is not None else None,
        'validation': conv_dataset_cls(dataset_generator=datasets['validation'], mode='validation') if datasets['validation'] is not None else None,
        'test': conv_dataset_cls(dataset_generator=datasets['test'], mode='test') if datasets['test'] is not None else None,
    }

    # multi test set
    if hasattr(data_args, 'multitest') and bool(data_args.multitest) \
            and hasattr(training_args, 'do_multi_predict') and training_args.do_multi_predict:
        print(f"processing multitest set")
        k2v = {}
        for k, item in data_args.multitest.items():
            _dataset_cls = partial(DATASETS.build, item['cfg'])
            _compute_metric = build_compute_metric(item['compute_metric'], preprocessor)
            k2v[k] = {
                "dataset": conv_dataset_cls(dataset_generator=_dataset_cls, mode='test'),
                "compute_metric": _compute_metric
            }
        ds['multitest'] = k2v
        print(f"processing multitest set. done.")

    # By default, the ceph client is initialized lazily; if eager init is requested,
    # it must happen at the beginning of the program, before dataloader workers fork.
    lazy_init = data_args.get('lazy_init', True)
    if not lazy_init:
        init_ceph_client_if_needed()
    return ds, compute_metrics


def build_compute_metric(compute_metric_cfg, preprocessor):
    if compute_metric_cfg is not None:
        compute_metric_cfg = dict(compute_metric_cfg)  # copy cfg because we modify it
        compute_metric_cfg.update(dict(preprocessor=preprocessor))
        compute_metrics = METRICS.build(cfg=compute_metric_cfg)
    else:
        compute_metrics = None
    return compute_metrics


def prepare_interactive(
        model_args,
        preprocessor: Dict[str, Any],
):
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    conv_template = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=conv_template)
    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)
    # process func
    process_func = {}
    for k, v in model_args.process_func_args.items():
        process_func[k] = FUNCTIONS.build(cfg=v)

    ds = SingleImageInteractive(
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=None,
        transforms=transforms,
        mode='test',
    )
    return ds
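A minimal sketch of how `prepare_args` and `prepare_data` might compose in a training entry point. The `preprocessor` layout here ('image', 'text', 'conv' keys) is inferred from the process functions in this commit, and the checkpoint/processor names are assumptions, not taken from this diff:

from transformers import CLIPImageProcessor, LlamaTokenizer

from mllm.config import prepare_args
from mllm.dataset import prepare_data

cfg, training_args = prepare_args()
# In the real pipeline this dict is assembled when the model is loaded.
preprocessor = {
    'image': CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14'),
    'text': LlamaTokenizer.from_pretrained('path/to/llama-checkpoint'),  # hypothetical path
    'conv': dict(image_token_len=256, sep_image_conv_front=False, use_im_start_end=True),  # illustrative values
}
datasets, compute_metrics = prepare_data(cfg.data_args, cfg.model_args, training_args, preprocessor)
train_dataset = datasets['train']  # a SingleImageConvDataset, or None if do_train is False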
mllm/dataset/process_function/__init__.py
ADDED
@@ -0,0 +1,13 @@
from .shikra_process_function import (
    ShikraConvProcess,
    ShikraImageProcessor,
    ShikraTextProcess,
)

from .box_process_function import (
    BoxFormatProcess,
    BoxFormatter,
    PlainBoxFormatter,
    TokenFormatter,
    prepare_target_processor,
)
mllm/dataset/process_function/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (458 Bytes)
mllm/dataset/process_function/__pycache__/box_process_function.cpython-310.pyc
ADDED
Binary file (10.7 kB)
mllm/dataset/process_function/__pycache__/shikra_process_function.cpython-310.pyc
ADDED
Binary file (6.02 kB)
mllm/dataset/process_function/box_process_function.py
ADDED
@@ -0,0 +1,326 @@
import re
import sys
import logging
import typing
from typing import List, Dict, Any, Tuple, Union

from ..utils.transform import norm_box_xyxy, norm_point_xyxy

from ..root import (
    FUNCTIONS,
    BaseTargetProcessFunc,
    BOXES_PLACEHOLDER,
    BOXES_PROCESSOR,
    POINTS_PLACEHOLDER,
)

from ...utils import smart_tokenizer_and_embedding_resize

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout), ],
)

Box = List[Union[float, int]]
Boxes = List[Box]
BoxesSeq = List[Boxes]


@FUNCTIONS.register_module()
class BoxFormatProcess(BaseTargetProcessFunc):
    def __call__(self, raw_conv: List[Dict[str, Any]], target: Dict[str, Any], preprocessor: Dict[str, Any],
                 multimage_mode=False) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        box_formatter = preprocessor['target']['boxes']

        if multimage_mode:
            target = typing.cast(list, target)
            outer_normalized_boxes = []
            for tgt in target:
                normalized_boxes = []
                if tgt is not None and 'boxes' in tgt:
                    for box in tgt['boxes']:
                        normalized_boxes.append(
                            norm_box_xyxy(box, w=tgt['width'], h=tgt['height'])
                        )
                outer_normalized_boxes.append(normalized_boxes)
            normalized_boxes = outer_normalized_boxes
            outer_normalized_points = []
            for tgt in target:
                normalized_points = []
                if tgt is not None and 'boxes' in tgt:
                    for box in tgt['boxes']:
                        normalized_points.append(
                            norm_box_xyxy(box, w=tgt['width'], h=tgt['height'])
                        )
                outer_normalized_points.append(normalized_points)
            normalized_points = outer_normalized_points
        else:
            # normalize target
            normalized_boxes = []
            if target is not None and 'boxes' in target:
                for box in target['boxes']:
                    normalized_boxes.append(
                        norm_box_xyxy(box, w=target['width'], h=target['height'])
                    )
            normalized_points = []
            if target is not None and 'points' in target:
                for point in target['points']:
                    normalized_points.append(
                        norm_point_xyxy(point, w=target['width'], h=target['height'])
                    )

        # convert bboxes_seq
        for sentence in raw_conv:
            words: str = sentence['value']
            boxes_seq: List[List[int]] = sentence.get('boxes_seq', None)
            if boxes_seq is not None:
                # map box seq
                boxes_seq: List[Boxes] = map_obj(normalized_boxes, boxes_seq)
                # reformat; replace <boxes> placeholder
                converted = box_formatter(words, boxes_seq)
                words = converted
            points_seq: List[List[int]] = sentence.get('points_seq', None)
            if points_seq is not None:
                # map point seq
                points_seq: List[Boxes] = map_obj(normalized_points, points_seq)
                # reformat; replace <points> placeholder
                converted = box_formatter.call_on_point(words, points_seq)
                words = converted
            if boxes_seq is not None or points_seq is not None:
                sentence['raw_value'] = sentence['value']
                sentence['value'] = words
        return raw_conv, target


def map_obj(boxes_value: List[List[float]], boxes_seq: List[List[int]]) -> List[List[List[float]]]:
    """
    >>> normalized_boxes = [[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2], [0.3, 0.3, 0.3, 0.3]]
    >>> boxes_seq_ = [[2, 0], [1]]
    >>> var = map_obj(normalized_boxes, boxes_seq_)
    >>> assert var == [[[0.3, 0.3, 0.3, 0.3], [0.1, 0.1, 0.1, 0.1]], [[0.2, 0.2, 0.2, 0.2]]]
    """
    try:
        ret = []
        for boxes in boxes_seq:
            boxes_ret = []
            for box_index in boxes:
                if isinstance(box_index, (list, tuple)):
                    boxes_ret.append(boxes_value[box_index[0]][box_index[1]])
                else:
                    boxes_ret.append(boxes_value[box_index])
            ret.append(boxes_ret)
        return ret
    except Exception as e:
        raise SystemExit(f"error: map obj {boxes_value} {boxes_seq}") from e


class BoxFormatter:
    def __init__(self, bboxes_token=BOXES_PLACEHOLDER, points_token=POINTS_PLACEHOLDER):
        self.bboxes_token = bboxes_token
        self.points_token = points_token
        # normally bboxes_token_pat matches bboxes_token literally, unless the token contains regex metacharacters
        self.bboxes_token_pat = re.compile(bboxes_token)
        self.points_token_pat = re.compile(points_token)

    def __call__(self, sentence: str, bboxes_seq: BoxesSeq) -> str:
        all_box = self.bboxes_token_pat.findall(sentence)
        assert len(all_box) == len(bboxes_seq), f"not match. sentence: {sentence}. boxes:{bboxes_seq}"
        if len(all_box) == 0:
            return sentence
        bboxes_strs = [self.format_box(bboxes) for bboxes in bboxes_seq]
        converted = sentence.replace(self.bboxes_token, '{}').format(*bboxes_strs)
        return converted

    def call_on_point(self, sentence: str, points_seq: BoxesSeq) -> str:
        all_box = self.points_token_pat.findall(sentence)
        assert len(all_box) == len(points_seq), f"not match. sentence: {sentence}. boxes:{points_seq}"
        if len(all_box) == 0:
            return sentence
        bboxes_strs = [self.format_point(bboxes) for bboxes in points_seq]
        converted = sentence.replace(self.points_token, '{}').format(*bboxes_strs)
        return converted

    def format_point(self, points) -> str:
        raise NotImplementedError

    def format_box(self, bboxes: Boxes) -> str:
        raise NotImplementedError

    def extract(self, string: str) -> List[Boxes]:
        raise NotImplementedError

    def extract_point(self, string: str) -> List[Boxes]:
        raise NotImplementedError


@BOXES_PROCESSOR.register_module()
class PlainBoxFormatter(BoxFormatter):

    def __init__(self, *args, precision=3, use_small_brackets=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.precision = precision
        self.use_small_brackets = use_small_brackets

        small_brackets_pat = re.compile(r'\(\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\)')
        small_brackets_point_pat = re.compile(r'\(\d(?:\.\d*)?(?:,\d(?:\.\d*)?)(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?))*\)')

        middle_brackets_pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\]')
        middle_brackets_point_pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?)(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?))*\]')

        self.pat = small_brackets_pat if use_small_brackets else middle_brackets_pat
        self.point_pat = small_brackets_point_pat if use_small_brackets else middle_brackets_point_pat

    def format_box(self, boxes: Boxes) -> str:
        box_strs = []
        for box in boxes:
            box_strs.append(','.join([f"{elem:.{self.precision}f}" for elem in box]))
        box_str = ';'.join(box_strs)
        if self.use_small_brackets:
            return "(" + box_str + ")"
        return "[" + box_str + "]"

    def format_point(self, points) -> str:
        return self.format_box(points)

    def extract(self, string: str) -> List[Boxes]:
        """ balabala<boxes>balabala<boxes> -> [boxes, boxes] """
        ret = []
        for bboxes_str in self.pat.findall(string):
            bboxes = []
            bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";")
            for bbox_str in bbox_strs:
                bbox = list(map(float, bbox_str.split(',')))
                bboxes.append(bbox)
            ret.append(bboxes)
        return ret

    def extract_point(self, string: str) -> List[Boxes]:
        """ balabala<boxes>balabala<boxes> -> [boxes, boxes] """
        ret = []
        for bboxes_str in self.point_pat.findall(string):
            bboxes = []
            bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";")
            for bbox_str in bbox_strs:
                bbox = list(map(float, bbox_str.split(',')))
                bboxes.append(bbox)
            ret.append(bboxes)
        return ret


@BOXES_PROCESSOR.register_module()
class TokenFormatter(BoxFormatter):

    def __init__(self, num_bins=1001):
        super().__init__()
        self.extract_box_pat = re.compile(r'<b_st><bin_\d*?>(?:<bin_\d*?>){3}(?:<b_sep><bin_\d*?>(?:<bin_\d*?>){3})*<b_ed>')
        self.extract_point_pat = re.compile(r'<p_st><bin_\d*?>(?:<bin_\d*?>){1}(?:<p_sep><bin_\d*?>(?:<bin_\d*?>){1})*<p_ed>')
        self.num_bins = num_bins
        self.use_sep = True
        self.use_begin_end = True

        self.box_begin = '<b_st>'
        self.box_sep = '<b_sep>'
        self.box_end = '<b_ed>'

        self.point_begin = '<p_st>'
        self.point_sep = '<p_sep>'
        self.point_end = '<p_ed>'

    def format_point(self, points) -> str:
        final_str = []
        for bbox in points:
            quant_x0 = "<bin_{}>".format(round((bbox[0] * (self.num_bins - 1))))
            quant_y0 = "<bin_{}>".format(round((bbox[1] * (self.num_bins - 1))))
            region_coord = "{} {}".format(quant_x0, quant_y0)
            final_str.append(region_coord)
        if self.use_sep:
            final_str = self.point_sep.join(final_str)
        else:
            final_str = ''.join(final_str)
        if self.use_begin_end:
            final_str = self.point_begin + final_str + self.point_end
        return final_str

    def format_box(self, bboxes: Boxes) -> str:
        final_str = []
        for bbox in bboxes:
            quant_x0 = "<bin_{}>".format(round((bbox[0] * (self.num_bins - 1))))
            quant_y0 = "<bin_{}>".format(round((bbox[1] * (self.num_bins - 1))))
            quant_x1 = "<bin_{}>".format(round((bbox[2] * (self.num_bins - 1))))
            quant_y1 = "<bin_{}>".format(round((bbox[3] * (self.num_bins - 1))))
            region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
            final_str.append(region_coord)
        if self.use_sep:
            final_str = self.box_sep.join(final_str)
        else:
            final_str = ''.join(final_str)
        if self.use_begin_end:
            final_str = self.box_begin + final_str + self.box_end
        return final_str

    def extract(self, string: str) -> List[Boxes]:
        ret = []
        for bboxes_str in self.extract_box_pat.findall(string.replace(" ", "")):
            bboxes = []
            bbox_strs = bboxes_str.replace(self.box_begin, "").replace(self.box_end, "").split(self.box_sep)
            for bbox_str in bbox_strs:
                elems = list(map(int, re.findall(r'<bin_(\d*?)>', bbox_str)))
                bbox = [elem / (self.num_bins - 1) for elem in elems]
                bboxes.append(bbox)
            ret.append(bboxes)
        return ret

    def extract_point(self, string: str) -> List[Boxes]:
        ret = []
        for bboxes_str in self.extract_point_pat.findall(string):
            bboxes = []
            bbox_strs = bboxes_str.replace(self.point_begin, "").replace(self.point_end, "").split(self.point_sep)
            for bbox_str in bbox_strs:
                elems = list(map(int, re.findall(r'<bin_(\d*?)>', bbox_str)))
                bbox = [elem / (self.num_bins - 1) for elem in elems]
                bboxes.append(bbox)
            ret.append(bboxes)
        return ret

    def post_process_model_tokenizer(self, model, preprocessor, model_args, training_args):
        tokenizer = preprocessor['text']

        additional_special_tokens = [
            self.box_begin, self.box_sep, self.box_end,
            self.point_begin, self.point_sep, self.point_end,
        ]
        for i in range(self.num_bins):
            additional_special_tokens.append(f'<bin_{i}>')

        smart_tokenizer_and_embedding_resize(
            {'additional_special_tokens': additional_special_tokens},
            tokenizer,
            model,
        )
        return model, preprocessor


# FIXME: merge into load_pretrained
def prepare_target_processor(
        model,  # multimodal llm
        preprocessor: Dict[str, Any],
        model_args,
        training_args,
):
    if not hasattr(model_args, 'target_processor'):
        return model, preprocessor

    target_processor = {}
    if 'boxes' in model_args['target_processor']:
        boxes_cfg = model_args['target_processor']['boxes']
        boxes_processor = BOXES_PROCESSOR.build(boxes_cfg)
        target_processor['boxes'] = boxes_processor
        if hasattr(boxes_processor, "post_process_model_tokenizer"):
            model, preprocessor = boxes_processor.post_process_model_tokenizer(
                model, preprocessor, model_args, training_args,
            )
    preprocessor['target'] = target_processor
    return model, preprocessor
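A small round-trip sketch of `PlainBoxFormatter` with its defaults (precision=3, square brackets). It assumes `BOXES_PLACEHOLDER` from `..root` is the literal `<boxes>` string that the comments in `BoxFormatProcess` refer to, and that this is run from the repo root so the package imports resolve:

from mllm.dataset.process_function import PlainBoxFormatter

fmt = PlainBoxFormatter()
sentence = "the dog <boxes> chases the ball <boxes>"
formatted = fmt(sentence, [[[0.11, 0.22, 0.33, 0.44]], [[0.5, 0.6, 0.7, 0.8]]])
# -> "the dog [0.110,0.220,0.330,0.440] chases the ball [0.500,0.600,0.700,0.800]"
boxes = fmt.extract(formatted)
# -> [[[0.11, 0.22, 0.33, 0.44]], [[0.5, 0.6, 0.7, 0.8]]]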
mllm/dataset/process_function/shikra_process_function.py
ADDED
@@ -0,0 +1,178 @@
import sys
import copy
import warnings
import logging
from typing import Dict, Any, List

import PIL.Image
import torch
from PIL import Image
from transformers import LlamaTokenizer

from ..root import (
    FUNCTIONS,
    IMAGE_PLACEHOLDER,
    BaseImageProcessFunc,
    BaseConvProcessFunc,
    BaseTextProcessFunc,
)
from ...conversation import SeparatorStyle, Conversation

IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = IMAGE_PLACEHOLDER
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout), ],
)


@FUNCTIONS.register_module()
class ShikraConvProcess(BaseConvProcessFunc):
    def __call__(self, raw_conv: List[Dict[str, Any]], preprocessor: Dict[str, Any], conv_template: Conversation) -> List[Dict[str, Any]]:
        conv_processor_cfg = preprocessor['conv']

        image_token_len = conv_processor_cfg['image_token_len']
        sep_image_conv_front = conv_processor_cfg.get('sep_image_conv_front', False)
        use_im_start_end = conv_processor_cfg.get('use_im_start_end', False)
        # assert DEFAULT_IMAGE_PATCH_TOKEN in preprocessor['text'].get_vocab()
        # if use_im_start_end:
        #     assert DEFAULT_IM_START_TOKEN in preprocessor['text'].get_vocab()
        #     assert DEFAULT_IM_END_TOKEN in preprocessor['text'].get_vocab()

        if sep_image_conv_front:
            raw_conv[0]['value'] = raw_conv[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
            raw_conv[0]['value'] = DEFAULT_IMAGE_TOKEN + conv_template.sep + conv_template.roles[0] + ": " + raw_conv[0]['value']
        for sentence in raw_conv:
            replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
            if use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

        return raw_conv


@FUNCTIONS.register_module()
class ShikraTextProcess(BaseTextProcessFunc):

    def __call__(self, conv: Conversation, preprocessor: Dict[str, Any], mode: str, **tokenize_kwargs) -> Dict[str, Any]:
        tokenizer = preprocessor['text']
        assert isinstance(tokenizer, LlamaTokenizer), "only work for LlamaTokenizer"

        _truncation_size = tokenize_kwargs.pop('truncation_size', None)
        _kwargs = {'return_tensors': 'pt'}
        _kwargs.update(tokenize_kwargs)

        if conv.sep_style == SeparatorStyle.ADD_COLON_TWO:
            if mode in ['train']:
                ret = self.tk_conv_colon_two_train(conv, tokenizer, **_kwargs)
            else:
                ret = self.tk_conv_colon_two_eval(conv, tokenizer, **_kwargs)
        else:
            raise ValueError(f"unrecognized conv_style: {conv.sep_style}.\n the conv is {conv}")

        if _truncation_size is None:
            return ret
        if len(ret['input_ids']) <= _truncation_size:
            return ret

        origin_len = len(ret['input_ids'])
        ids_to_remove_num = origin_len - _truncation_size
        # truncation. should carefully not truncate <img_token>
        ids_should_not_remove = list(map(
            tokenizer.convert_tokens_to_ids,
            (DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN)
        ))
        back_no_image = all(ids not in ids_should_not_remove for ids in ret['input_ids'][_truncation_size:])
        if back_no_image:
            tgt_ids = list(range(_truncation_size))
        else:
            ids_to_remove = set()
            for idx in range(origin_len - 1, -1, -1):
                if ret['input_ids'][idx] not in ids_should_not_remove:
                    ids_to_remove.add(idx)
                    if len(ids_to_remove) >= ids_to_remove_num:
                        break
            tgt_ids = [_ for _ in range(origin_len) if _ not in ids_to_remove]
        logger.warning(f"truncate sample size from {origin_len} to {len(tgt_ids)}.")
        assert len(tgt_ids) == _truncation_size, f"{len(tgt_ids)}, {_truncation_size}, {ret['input_ids'].tolist()}"
        truncated_ret = {k: v[tgt_ids] for k, v in ret.items()}
        return truncated_ret

    # noinspection PyMethodMayBeStatic
    def tk_conv_colon_two_train(self, conv, tokenizer, **kwargs):
        conversation = conv.get_prompt()
        input_ids = tokenizer([conversation, ], **kwargs).input_ids[0]
        target = copy.deepcopy(input_ids)
        assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
        # Mask targets
        sep = conv.sep + conv.roles[1] + ": "
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = len(tokenizer(rou).input_ids)
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2  # <s> <space>
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                warnings.warn(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored):\n{conversation}")
        return dict(
            input_ids=input_ids,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
            labels=target,
        )

    # noinspection PyMethodMayBeStatic
    def tk_conv_colon_two_eval(self, conv, tokenizer, **kwargs):
        assert len(conv.messages) >= 2
        # target = conv.messages[-1][-1]
        target = conv.get_prompt()

        conv.messages[-1][-1] = ""
        conversation = conv.get_prompt()
        input_ids = tokenizer([conversation, ], **kwargs).input_ids[0]

        target = tokenizer([target, ], add_special_tokens=False, **kwargs).input_ids[0]
        target[target == tokenizer.pad_token_id] = IGNORE_INDEX
        return dict(
            input_ids=input_ids,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
            labels=target,
        )


@FUNCTIONS.register_module()
class ShikraImageProcessor(BaseImageProcessFunc):
    def __call__(self, image: Image.Image, preprocessor: Dict[str, Any]) -> Dict[str, Any]:
        image_processor = preprocessor['image']

        if isinstance(image, (list, tuple)):
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
            assert False, 'Shikra not support MultiImage'
        elif isinstance(image, PIL.Image.Image):
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        else:
            if hasattr(image_processor, 'crop_size'):
                crop_size = image_processor.crop_size
                height, width = crop_size['height'], crop_size['width']
            else:
                raise ValueError("got empty image. and don't know how to pad")
            image = torch.zeros(3, height, width)
        return {'image': image}
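The trickiest part of the file above is the train-time label masking in `tk_conv_colon_two_train`: the prompt is split into rounds on `conv.sep2`, and within each round everything up to and including `sep + roles[1] + ": "` is masked out. A self-contained sketch of just that index arithmetic, on plain lists with toy lengths (not the real tokenizer):

IGNORE_INDEX = -100

def mask_rounds(labels, round_lens, instr_lens):
    # Mirror of the masking loop in tk_conv_colon_two_train, on plain lists.
    cur = 1                                                       # position 0 is <s>, masked below
    labels[:cur] = [IGNORE_INDEX] * cur
    for round_len, instr_len in zip(round_lens, instr_lens):
        labels[cur:cur + instr_len] = [IGNORE_INDEX] * instr_len  # mask the instruction
        cur += round_len                                          # keep the response tokens
    labels[cur:] = [IGNORE_INDEX] * (len(labels) - cur)           # mask any tail
    return labels

print(mask_rounds(list(range(12)), round_lens=[6, 5], instr_lens=[3, 2]))
# -> [-100, -100, -100, -100, 4, 5, 6, -100, -100, 9, 10, 11]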
mllm/dataset/root.py
ADDED
@@ -0,0 +1,67 @@
from typing import Dict, Any, List, Tuple

from PIL import Image
from mmengine import DATASETS, TRANSFORMS, METRICS, FUNCTIONS, Registry

from ..conversation import Conversation

IMAGE_PLACEHOLDER = '<image>'
BOXES_PLACEHOLDER = '<boxes>'
EXPR_PLACEHOLDER = '<expr>'
OBJS_PLACEHOLDER = '<objs>'
QUESTION_PLACEHOLDER = '<question>'
POINTS_PLACEHOLDER = '<points>'
# processor
BOXES_PROCESSOR = Registry('Processor for Boxes')


# only for static type checking
class BaseConvProcessFunc:
    def __call__(
            self,
            raw_conv: List[Dict[str, Any]],
            preprocessor: Dict[str, Any],
            conv_template: Conversation,
    ) -> List[Dict[str, Any]]:
        raise NotImplementedError


class BaseTargetProcessFunc:
    def __call__(
            self,
            raw_conv: List[Dict[str, Any]],
            target: Dict[str, Any],
            preprocessor: Dict[str, Any],
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        raise NotImplementedError


class BaseTextProcessFunc:
    def __call__(
            self,
            conv: Conversation,
            preprocessor: Dict[str, Any],
            mode: str,
            **tokenize_kwargs,
    ) -> Dict[str, Any]:
        raise NotImplementedError


class BaseImageProcessFunc:
    def __call__(
            self,
            image: Image.Image,
            preprocessor: Dict[str, Any],
    ) -> Dict[str, Any]:
        raise NotImplementedError


__all__ = [
    'IMAGE_PLACEHOLDER', 'BOXES_PLACEHOLDER', 'EXPR_PLACEHOLDER', 'OBJS_PLACEHOLDER', 'QUESTION_PLACEHOLDER', 'POINTS_PLACEHOLDER',
    'FUNCTIONS',
    'DATASETS',
    'TRANSFORMS',
    'METRICS',
    'BOXES_PROCESSOR',
    'BaseConvProcessFunc', 'BaseTargetProcessFunc', 'BaseTextProcessFunc', 'BaseImageProcessFunc',
]
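`root.py` is the single place every registry comes from, so new components plug in by decoration alone. A hedged sketch of the pattern (the class name `MyBoxFormatter` is invented for illustration):

from mllm.dataset.root import BOXES_PROCESSOR

@BOXES_PROCESSOR.register_module()
class MyBoxFormatter:
    # Hypothetical formatter; each keyword in the cfg dict becomes an __init__ arg.
    def __init__(self, precision: int = 3):
        self.precision = precision

# dict(type='MyBoxFormatter', precision=2) can now be passed to
# BOXES_PROCESSOR.build(cfg) to get an instance back.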
mllm/dataset/single_image_convsation.py
ADDED
@@ -0,0 +1,284 @@
import warnings
from functools import partial
from typing import Dict, Any, Callable, List, Optional, Tuple, Type

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import TrainingArguments

from .root import IMAGE_PLACEHOLDER, BOXES_PLACEHOLDER
from ..conversation import Conversation, get_conv_template
from ..utils import post_process_generate_ids


class SingleImageConvDatasetMixin:

    def __init__(
            self,
            *args,
            preprocessor: Dict[str, Any],
            process_func: Dict[str, Any],
            conv_template: Callable[[], Conversation] = partial(get_conv_template, name='vicuna_v1.1'),
            mode='train',
            tokenize_kwargs: dict = None,
            training_args: TrainingArguments = None,
            transforms: Optional[Callable] = None,
            **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert mode in ['train', 'validation', 'test']

        self.preprocessor = preprocessor
        self.process_func = process_func
        self.conv_template = conv_template
        self.mode = mode
        self.tokenize_kwargs = tokenize_kwargs if tokenize_kwargs is not None else {}
        self.training_args = training_args
        self.transforms = transforms

    def __getitem__(self, index, debug_mode=False, return_conv=False) -> Dict[str, Any]:
        # getitem
        item = self.get_raw_item(index)
        image: Image.Image = item.get('image', None)
        target: Dict[str, Any] = item.get('target', None)
        raw_conv: List[Dict[str, Any]] = item['conversations']

        # transform
        assert isinstance(image, list) == isinstance(target, list)
        multimage_mode = isinstance(image, list)
        if isinstance(image, list):
            # TODO: validate raw item
            transformed_image, transformed_target = [], []
            for img, tgt in zip(image, target):
                if self.transforms is not None and image is not None:
                    img, tgt = self.transforms(img, tgt)
                if tgt is not None:
                    tgt['width'], tgt['height'] = img.width, img.height
                transformed_image.append(img)
                transformed_target.append(tgt)
            image, target = transformed_image, transformed_target
        else:
            self.validate_raw_item(item)  # only validate for single image.
            if self.transforms is not None and image is not None:
                image, target = self.transforms(image, target)
            has_image = 'image' in item and bool(item['image'])
            has_target = 'target' in item and bool(item['target']) and any(bool(elem) for elem in item['target'].values())
            if has_target and has_image:
                target['width'], target['height'] = image.width, image.height

        # preprocess
        raw_conv = self.process_conv(raw_conv)
        raw_conv, image = self.process_conv_multimage(raw_conv, image)
        raw_conv, _ = self.process_target(raw_conv, target, multimage_mode=multimage_mode)
        conv = self.build_conv(raw_conv)
        if return_conv:
            # noinspection PyTypeChecker
            return conv
        text_dict = self.process_text(conv)
        image_dict = self.process_image(image)

        # return
        ret_dict = {}
        ret_dict.update(text_dict)
        ret_dict.update(image_dict)
        self._print_sample(ret_dict, raw_conv, conv)
        if debug_mode:
            return {'ret': ret_dict, 'raw_conv': raw_conv, 'conv': conv, 'image': image}
        return ret_dict

    def __len__(self):
        raise NotImplementedError

    # noinspection PyMethodMayBeStatic
    def process_conv_multimage(self, raw_conv, image):
        # re-sort multi image
        if image is None:
            return raw_conv, image
        if not isinstance(image, (list, tuple)):
            return raw_conv, image
        image_seqs = []
        for conv in raw_conv:
            image_seqs.extend(conv['image_seq'] if 'image_seq' in conv else [])
        images = []
        for idx in image_seqs:
            images.append(image[idx])
        return raw_conv, images

    def get_raw_item(self, index) -> Dict[str, Any]:
        """
        return item format like this.
        item = {
            'image': # PIL.Image.Image,
            'target': {
                # xmin, ymin, xmax, ymax
                'boxes': [
                    [10, 10, 256, 265],  # dog1
                    [24, 18, 378, 768],  # dog2
                    [100, 310, 670, 653],  # man
                    [278, 320, 809, 673],  # rope
                ],
            }

            "conversations": [
                {
                    'from': 'human',
                    'value': 'What is the relation between the two dogs <boxes> and the man <boxes> in the image <image> ?',
                    'boxes_seq': [[0, 1], [2], ],
                },
                {
                    'from': 'gpt',
                    'value': 'a rope <boxes> is connecting the left dog <boxes> with the man <boxes>. '
                             'So the man <boxes> is walking the dog <boxes>.'
                             'And the man <boxes> has no relationship with the right dog <boxes>',
                    'boxes_seq': [[3], [0], [2], [2], [0], [2], [1]],
                }
            ]
        }
        # placeholder: <image> <boxes>
        """
        raise NotImplementedError

    # noinspection PyMethodMayBeStatic
    def validate_raw_item(self, item):
        has_image = 'image' in item and bool(item['image'])
        has_target = 'target' in item and bool(item['target']) and any(bool(elem) for elem in item['target'].values())
        has_target_boxes = 'boxes' in item['target'] if has_target else False
        raw_conv: List[Dict[str, Any]] = item['conversations']

        # check image
        human_input_has_image_placeholder = any(
            sentence['from'] == 'human' and IMAGE_PLACEHOLDER in sentence['value'] for sentence in raw_conv
        )
        if human_input_has_image_placeholder:
            assert has_image
        if has_image and (not human_input_has_image_placeholder):
            warnings.warn(f'item has image but the question has no image placeholder.\n{item}')
        gpt_input_has_image_placeholder = any(
            sentence['from'] == 'gpt' and IMAGE_PLACEHOLDER in sentence['value'] for sentence in raw_conv
        )
        assert not gpt_input_has_image_placeholder

        # check target
        has_boxes_placeholder = any(
            BOXES_PLACEHOLDER in sentence['value'] for sentence in raw_conv
        )
        if has_boxes_placeholder:
            assert has_target_boxes
        # not check box placeholder num this will be checked in format process

    def build_conv(self, source: List[Dict[str, Any]]) -> Conversation:
        conv = self.conv_template()
        role_map = {"human": conv.roles[0], "gpt": conv.roles[1]}
        assert len(source) > 0
        assert source[0]['from'] == 'human'
        for sentence in source:
            role = role_map[sentence['from']]
            conv.append_message(role, sentence['value'])
        return conv

    def process_conv(self, raw_conv: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        some utils preprocess for raw_conv.
        e.g. replace <image> placeholder to sequence <im_start> <im_patch>*256 <im_end>
        """
        return self.process_func['conv'](raw_conv, self.preprocessor, self.conv_template)

    def process_target(self, raw_conv: List[Dict[str, Any]], target: Dict[str, Any], multimage_mode=False) -> Tuple[
        List[Dict[str, Any]], Dict[str, Any]]:
        """
        convert target placeholder to actual information in raw_conv.
        e.g. normalize bounding boxes; convert bounding boxes format; replace <boxes> placeholder
        """
        return self.process_func['target'](raw_conv, target, self.preprocessor, multimage_mode=multimage_mode)

    def process_text(self, conv: Conversation) -> Dict[str, Any]:
        """
        convert Conversation object to torch.Tensor, e.g. input_ids, labels, attention_mask, etc.
        self.tokenize_kwargs control something like padding/truncation behavior.
        """
        return self.process_func['text'](conv, self.preprocessor, self.mode, **self.tokenize_kwargs)

    def process_image(self, image: Image.Image) -> Dict[str, Any]:
        """
        convert Image.Image object to torch.Tensor
        """
        return self.process_func['image'](image, self.preprocessor)

    def _print_sample(self, ret_dict, raw_conv, conv):
        if not hasattr(self, '_printed_sample'):
            self._printed_sample = True
            post_processed_labels = post_process_generate_ids(self.preprocessor['text'], ret_dict['labels'])
            print(f"=================== {self.mode} sample ===================", flush=True)
            print(f"        input_ids: {self.preprocessor['text'].convert_ids_to_tokens(ret_dict['input_ids'])}")
            print(f"           labels: {self.preprocessor['text'].convert_ids_to_tokens(post_processed_labels)}")
            print(f"decoded input_ids: {self.preprocessor['text'].decode(ret_dict['input_ids'])}")
            print(f"   decoded labels: {self.preprocessor['text'].decode(post_processed_labels)}")
            if 'image' in ret_dict and ret_dict['image'] is not None:
                image = ret_dict['image']
                if isinstance(image, torch.Tensor):
                    print(f"            image: {image.shape}")
                elif isinstance(image, dict):
                    print(f"            image: {image.keys()}")
                elif isinstance(image, list) and len(image) > 0:
                    print(f"            image: {len(image)}, {type(image[0])}")
                else:
                    print(f"            image: {type(image)}")
            print("====================================================", flush=True)
            try:
                if self.training_args is not None:
                    _save_obj = {
                        'ret_dict': ret_dict,
                        'raw_conv': raw_conv,
                        'conv': conv.get_prompt(),
                    }
                    from pathlib import Path
                    output_dir = Path(self.training_args.output_dir)
                    output_dir.mkdir(exist_ok=True, parents=True)
                    _local_rank = self.training_args.local_rank
                    _word_size = self.training_args.world_size
                    _file_path = str(output_dir / f'sample_check_{self.mode}_{_local_rank}_{_word_size}.pt')
                    print(f'saving some sample to {_file_path} for check.')
                    torch.save(_save_obj, _file_path)
            except Exception as e:
                warnings.warn(f'try to save samples but get exception: {e.args}. ignored.')


class SingleImageConvDataset(SingleImageConvDatasetMixin, Dataset):
    _repr_indent = 4

    def __init__(self, *args, dataset_generator: Type[Dataset], **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_generator = dataset_generator
        self.dataset = None

    def initialize_if_needed(self):
        """
        lazy initialize for big in-memory python object due to python 'copy-on-read' behavior
        when num_worker > 0. refer: https://github.com/pytorch/pytorch/issues/13246
        """
        if self.dataset is None:
            # warnings.warn("it's highly recommended that set persistent_workers=True, "
            #               "otherwise this initialize code will run in every epoch beginning."
            #               "(ignore me if set)")
            self.dataset = self.dataset_generator()

    def __len__(self):
        self.initialize_if_needed()
        return len(self.dataset)

    def get_raw_item(self, index) -> Dict[str, Any]:
        self.initialize_if_needed()
        return self.dataset[index]

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = [
            f"Number of datapoints: {self.__len__()}",
        ]
        body += self.dataset.__repr__().splitlines()
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)


__all__ = ['SingleImageConvDatasetMixin', 'SingleImageConvDataset']
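The `initialize_if_needed` indirection above exists because a big Python object built in the parent process gets effectively copied into every forked DataLoader worker (the "copy-on-read" issue linked in the docstring). A self-contained sketch of the same pattern in isolation:

from torch.utils.data import Dataset

class LazyDataset(Dataset):
    """Defer building the heavy payload until first access, so each
    worker process constructs its own copy after the fork."""

    def __init__(self, generator):
        self.generator = generator
        self.dataset = None

    def _initialize_if_needed(self):
        if self.dataset is None:
            self.dataset = self.generator()

    def __len__(self):
        self._initialize_if_needed()
        return len(self.dataset)

    def __getitem__(self, index):
        self._initialize_if_needed()
        return self.dataset[index]

ds = LazyDataset(lambda: list(range(10)))
print(len(ds), ds[3])  # 10 3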
mllm/dataset/single_image_dataset/__init__.py
ADDED
@@ -0,0 +1,13 @@
from .flickr import FlickrParser, FlickrDataset
from .rec import RECDataset, RECComputeMetrics
from .reg import REGDataset, GCDataset
from .caption import CaptionDataset
from .instr import InstructDataset
from .gqa import GQADataset, GQAComputeMetrics
from .clevr import ClevrDataset
from .point_qa import Point_QA_local, Point_QA_twice, V7W_POINT, PointQAComputeMetrics
from .gpt_gen import GPT4Gen
from .vcr import VCRDataset, VCRPredDataset
from .vqav2 import VQAv2Dataset
from .vqaex import VQAEXDataset
from .pope import POPEVQADataset
mllm/dataset/single_image_dataset/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (909 Bytes)
mllm/dataset/single_image_dataset/__pycache__/caption.cpython-310.pyc
ADDED
Binary file (1.07 kB)
mllm/dataset/single_image_dataset/__pycache__/clevr.cpython-310.pyc
ADDED
Binary file (3.89 kB)
mllm/dataset/single_image_dataset/__pycache__/flickr.cpython-310.pyc
ADDED
Binary file (2.65 kB)
mllm/dataset/single_image_dataset/__pycache__/gpt_gen.cpython-310.pyc
ADDED
Binary file (1.64 kB)
mllm/dataset/single_image_dataset/__pycache__/gqa.cpython-310.pyc
ADDED
Binary file (6.73 kB)
mllm/dataset/single_image_dataset/__pycache__/instr.cpython-310.pyc
ADDED
Binary file (1.08 kB)
mllm/dataset/single_image_dataset/__pycache__/point_qa.cpython-310.pyc
ADDED
Binary file (5.11 kB)
mllm/dataset/single_image_dataset/__pycache__/pope.cpython-310.pyc
ADDED
Binary file (1.2 kB)
mllm/dataset/single_image_dataset/__pycache__/rec.cpython-310.pyc
ADDED
Binary file (3.73 kB)
mllm/dataset/single_image_dataset/__pycache__/reg.cpython-310.pyc
ADDED
Binary file (1.39 kB)
mllm/dataset/single_image_dataset/__pycache__/vcr.cpython-310.pyc
ADDED
Binary file (5.39 kB)
mllm/dataset/single_image_dataset/__pycache__/vqaex.cpython-310.pyc
ADDED
Binary file (1.48 kB)
mllm/dataset/single_image_dataset/__pycache__/vqav2.cpython-310.pyc
ADDED
Binary file (1.29 kB)
mllm/dataset/single_image_dataset/caption.py
ADDED
@@ -0,0 +1,31 @@
from ..root import DATASETS, IMAGE_PLACEHOLDER
from ..utils import MInstrDataset


@DATASETS.register_module()
class CaptionDataset(MInstrDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER,))

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        img_path = item['img_path']
        caption = item['caption']

        image = self.get_image(img_path)
        question = self.get_template()

        ret = {
            'image': image,
            'conversations': [
                {
                    'from': 'human',
                    'value': question,
                },
                {
                    'from': 'gpt',
                    'value': caption,
                }
            ]
        }
        return ret
mllm/dataset/single_image_dataset/clevr.py
ADDED
@@ -0,0 +1,116 @@
import json

from ..root import DATASETS, IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER, POINTS_PLACEHOLDER
from ..utils import MInstrDataset


@DATASETS.register_module()
class ClevrDataset(MInstrDataset):
    def __init__(self, *args, scene_graph_file, version, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        self.scene_graph_file = scene_graph_file
        self.version = version
        qtype, atype = version.split('-')
        assert qtype in ['q']
        assert atype in ['a', 's', 'bs']
        self.qtype = qtype
        self.atype = atype

        if scene_graph_file is None:
            self.scene_graph = None
        else:
            self.scene_graph = [line for line in open(scene_graph_file, 'r', encoding='utf8')]

    def get_raw_item(self, index):
        question = json.loads(self.data[index])
        if self.scene_graph is None:
            scene = None
        else:
            scene = json.loads(self.scene_graph[question['image_index']])
        return question, scene

    def __getitem__(self, index):
        question, scene = self.get_raw_item(index)
        img_path = question['image_filename']
        image = self.get_image(img_path)

        if self.atype == 'a':
            boxes = []
            answer = f"The answer is {question['answer']}."
            answer_boxes_seq = []
        elif self.atype == 's':
            answer, boxes, answer_boxes_seq = clevr_ss_cot(obj=question, scene=scene, add_ref=False)
            answer += f" The answer is {question['answer']}."
        elif self.atype == 'bs':
            answer, boxes, answer_boxes_seq = clevr_ss_cot(obj=question, scene=scene, add_ref=True)
            answer += f" The answer is {question['answer']}."
        else:
            assert False

        if self.qtype == 'q':
            query_boxes_seq = []
            final_query = self.get_template().replace(QUESTION_PLACEHOLDER, question['question'])
        else:
            assert False

        ret = {
            'image': image,
            'target': {'points': boxes},
            'conversations': [
                {
                    'from': 'human',
                    'value': final_query,
                    'points_seq': query_boxes_seq,
                },
                {
                    'from': 'gpt',
                    'value': answer,
                    'points_seq': answer_boxes_seq,
                }
            ]
        }
        return ret


def get_boxes_idx(boxes_list, refs):
    def get_idx(boxes_list, box):
        if box in boxes_list:
            return boxes_list.index(box)
        else:
            boxes_list.append(box)
            return len(boxes_list) - 1

    idx = [get_idx(boxes_list, box) for box in refs]
    return idx


def clevr_ss_cot(obj, scene, add_ref=False):
    cot = []
    boxes = []
    seq = []

    def can_add_ref():
        if p['function'] in ['unique', 'union', 'intersect', 'relate', 'same_size', 'same_shape', 'same_material', 'same_color']:
            return True
        if p['function'] in ['scene', 'filter_color', 'filter_material', 'filter_shape', 'filter_size']:
            if idx + 1 < len(obj['program']) and obj['program'][idx + 1]['function'] in ['exist', 'count']:
                return True
        return False

    for idx, p in enumerate(obj['program']):
        func = f"{p['function']}:{p['value_inputs'][0]}" if 'value_inputs' in p and p['value_inputs'] else p['function']
        inputs = f"[{','.join(map(str, p['inputs']))}]" if p['inputs'] else ""

        if add_ref and can_add_ref():
            if p['ans']:
                objs = POINTS_PLACEHOLDER
                idx = get_boxes_idx(boxes_list=boxes, refs=[scene['objects'][_]['pixel_coords'][:2] for _ in p['ans']])
                seq.append(idx)
            else:
                objs = f" Found no object."
        else:
            objs = ""
        cot.append(f"{func}{inputs}{objs}")

    ret = " -> ".join(cot)
    return ret, boxes, seq
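To make the CoT string built by `clevr_ss_cot` concrete, here is a hand-worked trace on an invented two-step program (not real CLEVR data):

# Hypothetical program, for illustration only.
obj = {
    'program': [
        {'function': 'filter_color', 'value_inputs': ['red'], 'inputs': [], 'ans': [0]},
        {'function': 'count', 'value_inputs': [], 'inputs': [0], 'ans': []},
    ],
}
# With add_ref=True, step 0 passes can_add_ref() (a filter_* followed by count),
# so its matched objects render as the <points> placeholder and their
# pixel_coords are appended to `boxes`; step 1 gets no placeholder.
# The joined CoT string is:
#   "filter_color:red<points> -> count[0]"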
mllm/dataset/single_image_dataset/flickr.py
ADDED
@@ -0,0 +1,68 @@
from torch.utils.data import Dataset

from ..root import DATASETS, BOXES_PLACEHOLDER, IMAGE_PLACEHOLDER
from ..utils import MInstrDataset
from ..utils.flickr30k_entities_utils import (
    flatten_annotation,
    PHRASE_ED_PLACEHOLDER,
    PHRASE_ST_PLACEHOLDER,
)


class FlickrParser(Dataset):
    def __init__(self, filename, annotation_dir):
        self.filename = filename
        self.annotation_dir = annotation_dir

        self.indexes = [line.strip() for line in open(filename, 'r', encoding='utf8')]
        self.data = flatten_annotation(self.annotation_dir, self.indexes)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def dump(self, filename):
        import json
        with open(filename, 'w', encoding='utf8') as f:
            for obj in self.data:
                obj_str = json.dumps(obj)
                f.write(obj_str)
                f.write('\n')


@DATASETS.register_module()
class FlickrDataset(MInstrDataset):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        img_path = f"{item['image_id']}.jpg"
        caption = item['sentence']

        image = self.get_image(img_path)
        caption = caption.replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
        question = self.get_template()

        ret = {
            'image': image,
            'target': {'boxes': item['boxes']},
            'conversations': [
                {
                    'from': 'human',
                    'value': question,
                },
                {
                    'from': 'gpt',
                    'value': caption,
                    'boxes_seq': item['boxes_seq'],
                }
            ]
        }
        return ret
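The one-line caption rewrite in `FlickrDataset.__getitem__` is what turns Flickr30k Entities phrase spans into groundable `<boxes>` slots. A self-contained sketch (the placeholder strings here are assumptions about what `flickr30k_entities_utils` defines):

# Assumed placeholder values, for illustration only.
PHRASE_ST_PLACEHOLDER = '<ph_st>'
PHRASE_ED_PLACEHOLDER = '<ph_ed>'
BOXES_PLACEHOLDER = '<boxes>'

caption = "<ph_st>A man<ph_ed> is walking <ph_st>two dogs<ph_ed>."
caption = caption.replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
print(caption)  # A man<boxes> is walking two dogs<boxes>.
# item['boxes_seq'] then maps each <boxes> occurrence, in order, to indices
# into item['boxes'].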
mllm/dataset/single_image_dataset/gpt_gen.py
ADDED
@@ -0,0 +1,58 @@
from ..root import (
    DATASETS,
    QUESTION_PLACEHOLDER,
    IMAGE_PLACEHOLDER,
    BOXES_PLACEHOLDER,
)
from ..utils import MInstrDataset
from ..utils.flickr30k_entities_utils import PHRASE_ST_PLACEHOLDER, PHRASE_ED_PLACEHOLDER


@DATASETS.register_module()
class GPT4Gen(MInstrDataset):
    def __init__(self, *args, version, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        self.version = version
        assert version in ['a', 'c', 'bc']

    def __getitem__(self, item):
        raw = self.get_raw_item(item)
        #
        image = self.get_image(raw['img_path'])
        #
        boxes = raw['boxes']
        #
        question = raw['question']
        question = question.replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
        final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
        query_boxes_seq = raw['question_boxes_seq']

        if self.version == 'a':
            final_answer = raw['answer']
            answer_boxes_seq = None
        elif self.version == 'c':
            final_answer = raw['cot_with_ans'].replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, '')
            answer_boxes_seq = None
        elif self.version == 'bc':
            final_answer = raw['cot_with_ans'].replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
            answer_boxes_seq = raw['answer_boxes_seq']
        else:
            assert False

        ret = {
            'image': image,
            'target': {'boxes': boxes},
            'conversations': [
                {
                    'from': 'human',
                    'value': final_question,
                    'boxes_seq': query_boxes_seq,
                },
                {
                    'from': 'gpt',
                    'value': final_answer,
                    'boxes_seq': answer_boxes_seq,
                }
            ]
        }
        return ret
mllm/dataset/single_image_dataset/gqa.py
ADDED
@@ -0,0 +1,233 @@
import json
import re

from ..root import DATASETS, IMAGE_PLACEHOLDER, BOXES_PLACEHOLDER, QUESTION_PLACEHOLDER, METRICS
from ..utils.flickr30k_entities_utils import PHRASE_ST_PLACEHOLDER, PHRASE_ED_PLACEHOLDER
from ..utils import MInstrDataset, BaseComputeMetrics

REFID_PAT = re.compile(r'(\s\((?:(?:\d+(?:,\d+)*)|-)\)\s?)')
ANS_EXTRACT_PAT = re.compile(r'(?:(?:(?:(?:(?:So t)|(?:T)|(?:t))he answer is)|(?:Answer:)) (.+))')


@DATASETS.register_module()
class GQADataset(MInstrDataset):
    def __init__(
            self,
            *args,
            scene_graph_file,
            scene_graph_index,
            version,
            question_box_prob=0.5,
            **kwargs
    ):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        self.scene_graph_file = scene_graph_file
        self.scene_graph_index = scene_graph_index
        self.version = version
        self.question_box_prob = question_box_prob
        qtype, atype = version.split('-')
        assert qtype in ['q', 'qb', 'qbp']
        assert atype in ['a', 'c', 'bc', 's', 'bs', 'l', 'bl']
        self.qtype = qtype
        self.atype = atype

        assert bool(scene_graph_file) == bool(scene_graph_index)
        if scene_graph_file is not None and scene_graph_index is not None:
            self.scene_graph = [line for line in open(scene_graph_file, 'r', encoding='utf8')]
            self.scene_index = json.load(open(scene_graph_index, 'r', encoding='utf8'))
        else:
            self.scene_graph = None
            self.scene_index = None

    def get_raw_item(self, index):
        question = json.loads(self.data[index])
        if self.scene_graph is None:
            return question, None
        scene = json.loads(self.scene_graph[self.scene_index[question['imageId']]])
        return question, scene

    def __getitem__(self, index):
        question, scene = self.get_raw_item(index)
        img_path = f"{question['imageId']}.jpg"
        image = self.get_image(img_path)

        # answer
        if self.atype == 'bc':
            boxes = question['cot']['boxes']
            answer = question['cot']['value'].replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
            answer_boxes_seq = question['cot']['seq']
        elif self.atype == 'c':
            boxes = []
            answer = question['cot']['value'].replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, "")
            answer_boxes_seq = []
        elif self.atype == 'bs':
            boxes, bss, answer_boxes_seq = get_bss_example(question, scene)
            answer = f"{bss}. The answer is {question['answer']}."
        elif self.atype == 's':
            boxes = []
            ss = REFID_PAT.sub('', question['semanticStr'])
            answer = f"{ss}. The answer is {question['answer']}."
            answer_boxes_seq = []
        elif self.atype == 'bl':
            boxes, answer, answer_boxes_seq = get_bl_example(question, scene)
        elif self.atype == 'l':
            boxes = []
            _, answer, _ = get_bl_example(question, scene)
            answer = answer.replace(BOXES_PLACEHOLDER, "")
            answer_boxes_seq = []
        elif self.atype == 'a':
            boxes = []
            answer = f"The answer is {question['answer']}."
            answer_boxes_seq = []
        else:
            assert False

        # question
        if self.qtype == 'q':
            boxes, query, query_boxes_seq = prepare_query_dummy(boxes, question, scene)
        elif self.qtype == 'qb':
            boxes, query, query_boxes_seq = prepare_query_box(boxes, question, scene)
        elif self.qtype == 'qbp':
            if self.rng.uniform() > self.question_box_prob:
                boxes, query, query_boxes_seq = prepare_query_dummy(boxes, question, scene)
            else:
                boxes, query, query_boxes_seq = prepare_query_box(boxes, question, scene)
        else:
            assert False

        final_query = self.get_template().replace(QUESTION_PLACEHOLDER, query)

        ret = {
            'image': image,
            'target': {'boxes': boxes},
            'conversations': [
                {
                    'from': 'human',
                    'value': final_query,
                    'boxes_seq': query_boxes_seq,
                },
                {
                    'from': 'gpt',
                    'value': answer,
                    'boxes_seq': answer_boxes_seq,
                }
            ]
        }
        return ret


def prepare_query_dummy(boxes_list, q, scene):
    return boxes_list, q['question'], []


def prepare_query_box(boxes_list, q, scene):
    def get_boxes_idx(box):
        if box in boxes_list:
            return boxes_list.index(box)
        else:
            boxes_list.append(box)
            return len(boxes_list) - 1

    def add_boxes_by_rids(rids):
        def get_box_xyxy(obj):
            x, y, w, h = obj['x'], obj['y'], obj['w'], obj['h']
            return x, y, x + w, y + h

        boxes_idx = []
        for rid in rids:
            ref = scene['objects'][rid]
            ref_box = list(get_box_xyxy(ref))
            boxes_idx.append(get_boxes_idx(ref_box))
        return boxes_idx

    sent = list(q['question'].split())
    query_boxes_seq = []
    for span, rids_str in q['annotations']['question'].items():
        span = tuple(map(int, span.split(':')))
        if len(span) == 1:
            span = [span[0], span[0] + 1]
        sent[span[1] - 1] = f"{sent[span[1] - 1]}{BOXES_PLACEHOLDER}"
        boxes_idx = add_boxes_by_rids(rids_str.split(','))
        query_boxes_seq.append(boxes_idx)
    sent_converted = " ".join(sent).strip()
    return boxes_list, sent_converted, query_boxes_seq


def add_boxes_by_rids(boxes_list, rids, scene):
    def get_boxes_idx(boxes_list, box):
        if box in boxes_list:
            return boxes_list.index(box)
        else:
            boxes_list.append(box)
            return len(boxes_list) - 1

    def get_box_xyxy(obj):
        x, y, w, h = obj['x'], obj['y'], obj['w'], obj['h']
        return x, y, x + w, y + h

    boxes_idx = []
    for rid in rids:
        ref = scene['objects'][rid]
        ref_box = list(get_box_xyxy(ref))
        boxes_idx.append(get_boxes_idx(boxes_list, ref_box))
    return boxes_idx


def get_bss_example(question, scene):
    def format_refids(item):
        item = item.strip()[1:-1]
        return item.split(',')

    s = question['semanticStr']
    print(REFID_PAT.findall(s))
    formats = []
    boxes = []
    seqs = []

    for item in REFID_PAT.findall(s):
        if '-' in item:
            formats.append('')
        else:
            formats.append('<boxes>')
            refids = format_refids(item)
            idx = add_boxes_by_rids(boxes, refids, scene)
            seqs.append(idx)
    answer = REFID_PAT.sub('{}', s).format(*formats)

    print(answer)
    print(boxes)
    print(seqs)
    return boxes, answer, seqs


def get_bl_example(ann, scene):
    boxes = []
    boxes_seq = []

    origin_sent = ann['fullAnswer']
    origin_sent = re.sub('(?:^Yes,)|(?:^No,)', '', origin_sent).strip()
    sent = list(origin_sent.split())
    for span, rids_str in ann['annotations']['fullAnswer'].items():
        span = tuple(map(int, span.split(':')))
        if len(span) == 1:
            span = [span[0], span[0] + 1]
        sent[span[1] - 1] = f"{sent[span[1] - 1]}{BOXES_PLACEHOLDER}"
        rids = rids_str.split(',')
        boxes_idx = add_boxes_by_rids(boxes, rids, scene)
        boxes_seq.append(boxes_idx)

    answer = " ".join(sent)
    answer += f" The answer is {ann['answer']}."
    return boxes, answer, boxes_seq


@METRICS.register_module()
class GQAComputeMetrics(BaseComputeMetrics):
    def extract_ans(self, string: str):
        try:
            found = ANS_EXTRACT_PAT.findall(string.strip())
            if len(found) != 1:
                return None
            return found[0].strip().rstrip('.').strip()
        except (IndexError, AttributeError):
            return None
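Both `prepare_query_box` and `get_bl_example` rely on GQA's `annotations` dict, which maps whitespace-token spans like `'3:4'` to scene-graph object ids. A self-contained sketch of the span-to-placeholder step (the id is invented):

BOXES_PLACEHOLDER = '<boxes>'  # same value as in root.py

question = "What is the man holding ?"
annotations = {'3:4': '271881'}  # tokens [3, 4) refer to object id 271881

sent = question.split()
for span_str, rids_str in annotations.items():
    span = tuple(map(int, span_str.split(':')))
    if len(span) == 1:                      # single-token spans come as '3'
        span = (span[0], span[0] + 1)
    sent[span[1] - 1] = f"{sent[span[1] - 1]}{BOXES_PLACEHOLDER}"
print(" ".join(sent))  # What is the man<boxes> holding ?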
mllm/dataset/single_image_dataset/instr.py
ADDED
@@ -0,0 +1,24 @@
from ..root import DATASETS
from ..utils import MInstrDataset


@DATASETS.register_module()
class InstructDataset(MInstrDataset):
    def __init__(self, *args, add_coco_prefix=False, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(), template_string='', template_file=None)
        self.add_coco_prefix = add_coco_prefix

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        if self.add_coco_prefix:
            img_path = f"COCO_train2014_{item['image']}"
        else:
            img_path = item['image']
        conversations = item['conversations']

        image = self.get_image(img_path)
        ret = {
            'image': image,
            'conversations': conversations,
        }
        return ret
mllm/dataset/single_image_dataset/point_qa.py
ADDED
@@ -0,0 +1,247 @@
import re

from .. import BaseComputeMetrics
from ..root import (
    DATASETS,
    METRICS,
    QUESTION_PLACEHOLDER,
    IMAGE_PLACEHOLDER,
    BOXES_PLACEHOLDER,
    POINTS_PLACEHOLDER,
)
from ..utils import MInstrDataset


# noinspection PyPep8Naming
@DATASETS.register_module()
class Point_QA_local(MInstrDataset):
    def __init__(self, *args, version='p', qbp_p_prob=0.5, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        assert version in ['b', 'p', 'bp']
        self.version = version
        self.qbp_p_prob = qbp_p_prob

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        # image
        img_path = item['file_path']
        image = self.get_image(img_path)
        # answer
        answer = item['answer']
        # question
        question = item['question']
        bbox = item['bbox']
        point = item['point']

        version = self.version
        if version == 'bp':
            version = 'p' if self.rng.random() < self.qbp_p_prob else 'b'
        if version == 'b':
            question = question + BOXES_PLACEHOLDER
            query_boxes_seq = [[0]]
            query_points_seq = None
        elif version == 'p':
            question = question + POINTS_PLACEHOLDER
            query_boxes_seq = None
            query_points_seq = [[0]]
        else:
            assert False
        final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)

        ret = {
            'image': image,
            'target': {
                'boxes': [bbox],
                'points': [point],
            },
            'conversations': [
                {
                    'from': 'human',
                    'value': final_question,
                    'boxes_seq': query_boxes_seq,
                    'points_seq': query_points_seq,
                },
                {
                    'from': 'gpt',
                    'value': f'The answer is {answer} .',
                }
            ]
        }
        return ret


# noinspection PyPep8Naming
@DATASETS.register_module()
class Point_QA_twice(MInstrDataset):
    def __init__(self, *args, version='gq-p', bp_p_prob=0.5, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        self.version = version
        self.bp_p_prob = bp_p_prob
        qtype, rtype = version.split('-')
        assert qtype in ['oq', 'sq', 'gq']
        assert rtype in ['b', 'p', 'bp']
        self.qtype = qtype
        self.rtype = rtype

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        # image
        img_path = item['file_path']
        image = self.get_image(img_path)
        # answer
        answer = item['answer']
        # question
        bbox = item['bbox']
        point = item['point']
        if self.qtype == 'oq':
            question = item['obj_question']
        elif self.qtype == 'sq':
            question = item['super_question']
        elif self.qtype == 'gq':
            question = item['general_question']
        else:
            assert False
        rtype = self.rtype
        if rtype == 'bp':
            rtype = 'p' if self.rng.random() < self.bp_p_prob else 'b'
        if rtype == 'p':
            question = question + POINTS_PLACEHOLDER
            query_boxes_seq = None
            query_points_seq = [[0]]
        elif rtype == 'b':
            question = question + BOXES_PLACEHOLDER
            query_boxes_seq = [[0]]
            query_points_seq = None
        else:
            assert False
        final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)

        ret = {
            'image': image,
            'target': {
                'boxes': [bbox],
                'points': [point],
            },
            'conversations': [
                {
                    'from': 'human',
                    'value': final_question,
                    'boxes_seq': query_boxes_seq,
                    'points_seq': query_points_seq,
                },
                {
                    'from': 'gpt',
                    'value': f'The answer is {answer} .',
                }
            ]
        }
        return ret


# noinspection PyPep8Naming
@DATASETS.register_module()
class V7W_POINT(MInstrDataset):
    def __init__(self, *args, version, do_shuffle_choice=True, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        self.version = version
        self.do_shuffle_choice = do_shuffle_choice
        assert version in ['p', 'b']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        # image
        img_path = item['file_path']
        image = self.get_image(img_path)
        # question
        bboxes = item['candidates']
        points = []
        final_question = item['question'] + ' Candidates: ' + " ".join([BOXES_PLACEHOLDER for _ in range(len(bboxes))])
        query_boxes_seq = []
        for _ in range(len(bboxes)):
            query_boxes_seq.append([_])
        # answer
        if self.version == 'p':
            final_question += f" answer in point format."
            points.append(item['point'])
            final_answer = f"The answer is {POINTS_PLACEHOLDER} ."
            answer_boxes_seq = None
            answer_points_seq = [[0]]
        elif self.version == 'b':
            final_question += f" answer in box format."
            idx = bboxes.index(item['answer'])
            final_answer = f"The answer is {BOXES_PLACEHOLDER} ."
            answer_boxes_seq = [[idx]]
            answer_points_seq = None
        else:
            assert False
        final_question = self.get_template().replace(QUESTION_PLACEHOLDER, final_question)
        if self.do_shuffle_choice:
            self.rng.shuffle(query_boxes_seq)
            # bboxes, query_boxes_seq, answer_boxes_seq = self.shuffle_boxes(bboxes, query_boxes_seq, answer_boxes_seq)

        ret = {
            'image': image,
            'target': {
                'boxes': bboxes,
                'points': points,
            },
            'conversations': [
                {
                    'from': 'human',
                    'value': final_question,
                    'boxes_seq': query_boxes_seq,
                },
                {
                    'from': 'gpt',
                    'value': final_answer,
                    'boxes_seq': answer_boxes_seq,
                    'points_seq': answer_points_seq,
                }
            ]
        }
        return ret

    # def shuffle_boxes(self, bboxes, query_boxes_seq, answer_boxes_seq):
    #     idx_mapping = list(range(len(bboxes)))
    #     self.rng.shuffle(idx_mapping)
    #
    #     new_bboxes = [None for _ in range(len(bboxes))]
    #     for idx_old, idx_new in enumerate(idx_mapping):
    #         new_bboxes[idx_new] = bboxes[idx_old]
    #
    #     if query_boxes_seq is None:
    #         new_query_boxes_seq = None
    #     else:
    #         new_query_boxes_seq = []
    #         for boxes in query_boxes_seq:
    #             new_boxes = [idx_mapping[box_idx] for box_idx in boxes]
    #             new_query_boxes_seq.append(new_boxes)
    #
    #     if answer_boxes_seq is None:
    #         new_answer_boxes_seq = None
    #     else:
    #         new_answer_boxes_seq = []
    #         for boxes in answer_boxes_seq:
    #             new_boxes = [idx_mapping[box_idx] for box_idx in boxes]
    #             new_answer_boxes_seq.append(new_boxes)
    #
    #     return new_bboxes, new_query_boxes_seq, new_answer_boxes_seq


ANS_EXTRACT_PAT = re.compile(r'(?:The answer is (.+?)\.)')


@METRICS.register_module()
class PointQAComputeMetrics(BaseComputeMetrics):
    def extract_ans(self, string: str):
        try:
            found = ANS_EXTRACT_PAT.findall(string.strip())
            if len(found) != 1:
                return None
            return found[0].strip()
        except (IndexError, AttributeError):
            return None
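`V7W_POINT` renders every candidate box as its own `<boxes>` slot with a parallel one-element `boxes_seq` entry; a self-contained sketch of that layout on toy data:

BOXES_PLACEHOLDER = '<boxes>'  # same value as in root.py

bboxes = [[10, 10, 50, 50], [60, 10, 90, 40], [20, 60, 70, 95]]
question = 'Which object is red ?'
final_question = question + ' Candidates: ' + " ".join(BOXES_PLACEHOLDER for _ in bboxes)
query_boxes_seq = [[i] for i in range(len(bboxes))]
print(final_question)   # Which object is red ? Candidates: <boxes> <boxes> <boxes>
print(query_boxes_seq)  # [[0], [1], [2]]
# Note: with do_shuffle_choice=True only query_boxes_seq is shuffled, which
# changes the order in which candidates are referenced, not the boxes list itself.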
mllm/dataset/single_image_dataset/pope.py
ADDED
@@ -0,0 +1,36 @@
from ..root import (
    DATASETS,
    QUESTION_PLACEHOLDER,
    IMAGE_PLACEHOLDER,
)
from ..utils import MInstrDataset


@DATASETS.register_module()
class POPEVQADataset(MInstrDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        image = self.get_image(image_path=item['image'])

        question = item['text']
        final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)

        label = str(item['label']).lower()

        ret = {
            'image': image,
            'conversations': [
                {
                    'from': 'human',
                    'value': final_question,
                },
                {
                    'from': 'gpt',
                    'value': f"The answer is {label} .",
                },
            ]
        }
        return ret
mllm/dataset/single_image_dataset/rec.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys
import logging
import warnings
from typing import Dict, Any, Sequence

import torch
from torchvision.ops import box_iou

from ..utils import (
    MInstrDataset,
    BaseComputeMetrics,
)

from ..process_function import (
    BoxFormatter,
)

from ..root import (
    DATASETS,
    METRICS,
    IMAGE_PLACEHOLDER,
    BOXES_PLACEHOLDER,
    EXPR_PLACEHOLDER,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout), ],
)


@DATASETS.register_module()
class RECDataset(MInstrDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, EXPR_PLACEHOLDER))

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        img_path = item['img_path']
        expr = item['expression']
        bbox = item['bbox']

        image = self.get_image(img_path)
        question = self.get_template().replace(EXPR_PLACEHOLDER, expr)

        ret = {
            'image': image,
            'target': {
                'boxes': [bbox],
            },
            'conversations': [
                {
                    'from': 'human',
                    'value': question,
                },
                {
                    'from': 'gpt',
                    'value': f'Answer: {BOXES_PLACEHOLDER} .',
                    'boxes_seq': [[0]],
                }
            ]
        }
        return ret
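The 'boxes_seq': [[0]] on the gpt turn ties the single BOXES_PLACEHOLDER occurrence in its value to index 0 of target['boxes']. A minimal sketch of that wiring for a hypothetical raw item ('<image>', '<expr>', and '<boxes>' are stand-ins for the real placeholder tokens from ..root, and the template is illustrative only):

# Hypothetical REC record -- field values are assumptions for illustration.
item = {'img_path': 'train2014/COCO_train2014_000000000009.jpg',
        'expression': 'the red mug on the left',
        'bbox': [0.12, 0.30, 0.45, 0.80]}

template = 'In <image>, where is "<expr>"?'  # illustrative template only
question = template.replace('<expr>', item['expression'])

ret = {
    'target': {'boxes': [item['bbox']]},
    'conversations': [
        {'from': 'human', 'value': question},
        # boxes_seq [[0]] maps the one '<boxes>' token to target['boxes'][0]
        {'from': 'gpt', 'value': 'Answer: <boxes> .', 'boxes_seq': [[0]]},
    ],
}
print(ret['conversations'][0]['value'])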
@METRICS.register_module()
class RECComputeMetrics(BaseComputeMetrics):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.box_formatter: BoxFormatter = self.preprocessor['target']['boxes']

    def calculate_metric(self, preds: Sequence[str], targets: Sequence[str]) -> Dict[str, Any]:
        failed = 0
        target_failed = 0

        pred_boxes, target_boxes = [], []
        for pred, target in zip(preds, targets):
            extract_pred = self.extract_ans(pred)
            extract_target = self.extract_ans(target)
            if extract_target is None:
                target_failed += 1
                logger.warning(f"failed to extract ans for target: {target}")
                continue
            if extract_pred is None:
                failed += 1
                logger.warning(f"failed to extract ans for pred: {pred}")
                extract_pred = [0, 0, 0, 0]
            target_boxes.append(extract_target)
            pred_boxes.append(extract_pred)

        with torch.no_grad():
            target_boxes = torch.tensor(target_boxes)
            pred_boxes = torch.tensor(pred_boxes)
            # normalized box values are tiny, so scale up to avoid zero-area boxes
            ious = box_iou(pred_boxes * 1000, target_boxes * 1000)
            ious = torch.einsum('i i -> i', ious)  # take the diagonal elements
            # NOTE: IoU is only computed over targets that were successfully extracted
            iou = ious.mean().item()
            correct = (ious > 0.5).sum().item()

        # HACK: images are currently expanded to squares, so this IoU matches the real IoU.
        warn_message = "this iou is calculated on normalized boxes, just for non-rigorous training progress checking; " \
                       "the value is consistent with the real iou only if image.width == image.height."
        warnings.warn(warn_message)

        return {
            'accuracy': 1.0 * correct / len(targets),
            'target_failed': target_failed,
            'failed': failed,
            'iou': iou,
            'warning': warn_message,
        }

    def extract_ans(self, string: str):
        try:
            list_of_boxes = self.box_formatter.extract(string)
            if len(list_of_boxes) != 1 or len(list_of_boxes[0]) != 1:
                return None
            box = list_of_boxes[0][0]
            if len(box) != 4:
                return None
            return box
        except Exception as e:
            logger.warning(f"extract_ans for {string} but get exception: {e}")
            return None
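Two details of calculate_metric are worth making explicit: box_iou returns the full N x N matrix of pairwise IoUs, so the diagonal is what pairs prediction i with target i (the einsum 'i i -> i' is equivalent to .diagonal()); and scaling both box sets by 1000 leaves the IoU unchanged, since IoU is scale-invariant, while sidestepping degenerate zero areas from coordinates normalized to [0, 1]. A minimal self-contained sketch:

import torch
from torchvision.ops import box_iou

# Two paired prediction/target boxes in normalized xyxy coordinates.
pred = torch.tensor([[0.10, 0.10, 0.50, 0.50],
                     [0.20, 0.20, 0.40, 0.60]])
target = torch.tensor([[0.10, 0.10, 0.50, 0.50],
                       [0.60, 0.60, 0.90, 0.90]])

ious = box_iou(pred * 1000, target * 1000)   # N x N pairwise IoU matrix
paired = torch.einsum('i i -> i', ious)      # diagonal: pred[i] vs target[i]
assert torch.allclose(paired, ious.diagonal())
print(paired)                         # tensor([1., 0.])
print((paired > 0.5).sum().item())    # 1 "correct" prediction at the 0.5 threshold

Note that 'accuracy' above divides by len(targets) while the IoU vector only covers successfully extracted targets, so extraction failures on the target side depress accuracy but not the reported mean IoU.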