CYF200127 committed on
Commit 3e1d9f3 · verified · 1 Parent(s): 6b368e8

Upload 235 files

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. mllm/__init__.py +0 -0
  3. mllm/__pycache__/__init__.cpython-310.pyc +0 -0
  4. mllm/config/__init__.py +1 -0
  5. mllm/config/__pycache__/__init__.cpython-310.pyc +0 -0
  6. mllm/config/__pycache__/config.cpython-310.pyc +0 -0
  7. mllm/config/config.py +135 -0
  8. mllm/conversation/__init__.py +1 -0
  9. mllm/conversation/__pycache__/__init__.cpython-310.pyc +0 -0
  10. mllm/conversation/__pycache__/base_conversation.cpython-310.pyc +0 -0
  11. mllm/conversation/base_conversation.py +503 -0
  12. mllm/dataset/__init__.py +7 -0
  13. mllm/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
  14. mllm/dataset/__pycache__/builder.cpython-310.pyc +0 -0
  15. mllm/dataset/__pycache__/root.cpython-310.pyc +0 -0
  16. mllm/dataset/__pycache__/single_image_convsation.cpython-310.pyc +0 -0
  17. mllm/dataset/__pycache__/single_image_interactive.cpython-310.pyc +0 -0
  18. mllm/dataset/builder.py +118 -0
  19. mllm/dataset/process_function/__init__.py +13 -0
  20. mllm/dataset/process_function/__pycache__/__init__.cpython-310.pyc +0 -0
  21. mllm/dataset/process_function/__pycache__/box_process_function.cpython-310.pyc +0 -0
  22. mllm/dataset/process_function/__pycache__/shikra_process_function.cpython-310.pyc +0 -0
  23. mllm/dataset/process_function/box_process_function.py +326 -0
  24. mllm/dataset/process_function/shikra_process_function.py +178 -0
  25. mllm/dataset/root.py +67 -0
  26. mllm/dataset/single_image_convsation.py +284 -0
  27. mllm/dataset/single_image_dataset/__init__.py +13 -0
  28. mllm/dataset/single_image_dataset/__pycache__/__init__.cpython-310.pyc +0 -0
  29. mllm/dataset/single_image_dataset/__pycache__/caption.cpython-310.pyc +0 -0
  30. mllm/dataset/single_image_dataset/__pycache__/clevr.cpython-310.pyc +0 -0
  31. mllm/dataset/single_image_dataset/__pycache__/flickr.cpython-310.pyc +0 -0
  32. mllm/dataset/single_image_dataset/__pycache__/gpt_gen.cpython-310.pyc +0 -0
  33. mllm/dataset/single_image_dataset/__pycache__/gqa.cpython-310.pyc +0 -0
  34. mllm/dataset/single_image_dataset/__pycache__/instr.cpython-310.pyc +0 -0
  35. mllm/dataset/single_image_dataset/__pycache__/point_qa.cpython-310.pyc +0 -0
  36. mllm/dataset/single_image_dataset/__pycache__/pope.cpython-310.pyc +0 -0
  37. mllm/dataset/single_image_dataset/__pycache__/rec.cpython-310.pyc +0 -0
  38. mllm/dataset/single_image_dataset/__pycache__/reg.cpython-310.pyc +0 -0
  39. mllm/dataset/single_image_dataset/__pycache__/vcr.cpython-310.pyc +0 -0
  40. mllm/dataset/single_image_dataset/__pycache__/vqaex.cpython-310.pyc +0 -0
  41. mllm/dataset/single_image_dataset/__pycache__/vqav2.cpython-310.pyc +0 -0
  42. mllm/dataset/single_image_dataset/caption.py +31 -0
  43. mllm/dataset/single_image_dataset/clevr.py +116 -0
  44. mllm/dataset/single_image_dataset/flickr.py +68 -0
  45. mllm/dataset/single_image_dataset/gpt_gen.py +58 -0
  46. mllm/dataset/single_image_dataset/gqa.py +233 -0
  47. mllm/dataset/single_image_dataset/instr.py +24 -0
  48. mllm/dataset/single_image_dataset/point_qa.py +247 -0
  49. mllm/dataset/single_image_dataset/pope.py +36 -0
  50. mllm/dataset/single_image_dataset/rec.py +128 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ mllm/demo/assets/baseball.png filter=lfs diff=lfs merge=lfs -text
mllm/__init__.py ADDED
File without changes
mllm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (137 Bytes).
 
mllm/config/__init__.py ADDED
@@ -0,0 +1 @@
from .config import prepare_args
mllm/config/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (187 Bytes).
 
mllm/config/__pycache__/config.cpython-310.pyc ADDED
Binary file (4.24 kB).
 
mllm/config/config.py ADDED
@@ -0,0 +1,135 @@
import os
import sys
import logging
import argparse
from dataclasses import dataclass, field
from typing import List, Tuple
from argparse import SUPPRESS

import datasets
import transformers
from mmengine.config import Config, DictAction
from transformers import HfArgumentParser, set_seed, add_start_docstrings
from transformers import Seq2SeqTrainingArguments as HFSeq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint, is_main_process

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout), ],
)


@dataclass
@add_start_docstrings(HFSeq2SeqTrainingArguments.__doc__)
class Seq2SeqTrainingArguments(HFSeq2SeqTrainingArguments):
    do_multi_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the multi-test set."})


def prepare_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
             'in xxx=yyy format will be merged into config file. If the value to '
             'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
             'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
             'Note that the quotation marks are necessary and that no white space '
             'is allowed.')

    hf_parser = HfArgumentParser((Seq2SeqTrainingArguments,))
    hf_parser, required = block_required_error(hf_parser)

    args, unknown_args = parser.parse_known_args(args)
    known_hf_args, unknown_args = hf_parser.parse_known_args(unknown_args)
    if unknown_args:
        raise ValueError(f"Some specified arguments are not used "
                         f"by the ArgumentParser or HfArgumentParser\n: {unknown_args}")

    # load 'cfg' and 'training_args' from file and cli
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    training_args = cfg.training_args
    training_args.update(vars(known_hf_args))

    # check training_args require
    req_but_not_assign = [item for item in required if item not in training_args]
    if req_but_not_assign:
        raise ValueError(f"Requires {req_but_not_assign} but not assign.")

    # update cfg.training_args
    cfg.training_args = training_args

    # initialize and return
    training_args = Seq2SeqTrainingArguments(**training_args)
    training_args = check_output_dir(training_args)

    # logging
    if is_main_process(training_args.local_rank):
        to_logging_cfg = Config()
        to_logging_cfg.model_args = cfg.model_args
        to_logging_cfg.data_args = cfg.data_args
        to_logging_cfg.training_args = cfg.training_args
        logger.info(to_logging_cfg.pretty_text)

    # setup logger
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.logging.set_verbosity_info()
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.logging.set_verbosity(log_level)
    transformers.logging.enable_default_handler()
    transformers.logging.enable_explicit_format()
    # setup_print_for_distributed(is_main_process(training_args))

    # Log on each process the small summary:
    logger.info(f"Training/evaluation parameters {training_args}")
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        + f" distributed training: {bool(training_args.local_rank != -1)}, fp16 training: {training_args.fp16}"
    )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    return cfg, training_args


def block_required_error(hf_parser: HfArgumentParser) -> Tuple[HfArgumentParser, List]:
    required = []
    # noinspection PyProtectedMember
    for action in hf_parser._actions:
        if action.required:
            required.append(action.dest)
        action.required = False
        action.default = SUPPRESS
    return hf_parser, required


def check_output_dir(training_args):
    # Detecting last checkpoint.
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    return training_args


if __name__ == "__main__":
    _ = prepare_args()
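Note: `prepare_args` above is meant to be driven from the command line: the positional `config` argument points at an mmengine config file, and `--cfg-options` plus any HuggingFace `Seq2SeqTrainingArguments` flags override it. A minimal sketch of how a training entry point might call it (the config path and override values below are hypothetical, not taken from this commit):

# hypothetical invocation, e.g.:
#   python train.py configs/example.py --cfg-options training_args.output_dir=exp/run1 --per_device_train_batch_size 4
from mllm.config import prepare_args

def main():
    cfg, training_args = prepare_args()   # parses CLI, merges config file + HF training args
    print(training_args.output_dir)       # fully initialized Seq2SeqTrainingArguments
    print(cfg.model_args, cfg.data_args)  # remaining sections stay in the mmengine Config

if __name__ == '__main__':
    main()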
mllm/conversation/__init__.py ADDED
@@ -0,0 +1 @@
from .base_conversation import SeparatorStyle, Conversation, register_conv_template, get_conv_template
mllm/conversation/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (290 Bytes).
 
mllm/conversation/__pycache__/base_conversation.cpython-310.pyc ADDED
Binary file (11.4 kB).
 
mllm/conversation/base_conversation.py ADDED
@@ -0,0 +1,503 @@
# copy from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""
Conversation prompt templates.
"""

import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any, Dict


class SeparatorStyle(Enum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_SPACE_TWO = auto()
    NO_COLON_SINGLE = auto()
    BAIZE = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    NEW_LINE = auto()
    BILLA = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    # The name of this template
    name: str
    # System prompts
    system: str
    # Two roles
    roles: List[str]
    # All messages
    messages: List[List[str]]
    # Offset of few shot examples
    offset: int
    # Separators
    sep_style: SeparatorStyle
    sep: str
    sep2: str = None
    # Stop criteria (the default one is EOS token)
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    # Used for the state in the gradio servers.
    # TODO(lmzheng): move this out of this class.
    conv_id: Any = None
    skip_next: bool = False
    model_name: str = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_SPACE_TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + " " + message + seps[i % 2]
                else:
                    ret += role + ""
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.BAIZE:
            ret = self.system + "\n"
            for role, message in self.messages:
                if message:
                    ret += role + message + "\n"
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    if i % 2 == 1:
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.RWKV:
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += (
                        role
                        + ": "
                        + message.replace("\r\n", "\n").replace("\n\n", "\n")
                    )
                    ret += "\n\n"
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.PHOENIX:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + "<s>" + message + "</s>"
                else:
                    ret += role + ": " + "<s>"
            return ret
        elif self.sep_style == SeparatorStyle.NEW_LINE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.BILLA:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "  # must be end with a space
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Convert the history to gradio chatbot format"""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        ret = [{"role": "system", "content": self.system}]

        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        return Conversation(
            name=self.name,
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
            conv_id=self.conv_id,
            model_name=self.model_name,
        )

    def dict(self):
        return {
            "name": self.name,
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "conv_id": self.conv_id,
            "model_name": self.model_name,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert template.name not in conv_templates, f"{template.name} has been registered."
    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


# A template with one conversation example
register_conv_template(
    Conversation(
        name="one_shot",
        system="A chat between a curious human and an artificial intelligence assistant. "
               "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("Human", "Assistant"),
        messages=(
            (
                "Human",
                "What are the key differences between renewable and non-renewable energy sources?",
            ),
            (
                "Assistant",
                "Renewable energy sources are those that can be replenished naturally in a relatively "
                "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
                "Non-renewable energy sources, on the other hand, are finite and will eventually be "
                "depleted, such as coal, oil, and natural gas. Here are some key differences between "
                "renewable and non-renewable energy sources:\n"
                "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
                "energy sources are finite and will eventually run out.\n"
                "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
                "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
                "and other negative effects.\n"
                "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
                "have lower operational costs than non-renewable sources.\n"
                "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
                "locations than non-renewable sources.\n"
                "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
                "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
                "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
                "non-renewable sources are not, and their depletion can lead to economic and social instability.",
            ),
        ),
        offset=2,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n### ",
        stop_str="###",
    )
)

# Vicuna v1.1 template
register_conv_template(
    Conversation(
        name="vicuna_v1.1",
        system="A chat between a curious user and an artificial intelligence assistant. "
               "The assistant gives helpful, detailed, and polite answers to the user's questions.",
        roles=("USER", "ASSISTANT"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep=" ",
        sep2="</s>",
    )
)

# Koala default template
register_conv_template(
    Conversation(
        name="koala_v1",
        system="BEGINNING OF CONVERSATION:",
        roles=("USER", "GPT"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep=" ",
        sep2="</s>",
    )
)

# Dolly V2 default template
register_conv_template(
    Conversation(
        name="dolly_v2",
        system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
        roles=("### Instruction", "### Response"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DOLLY,
        sep="\n\n",
        sep2="### End",
    )
)

# OpenAssistant Pythia default template
register_conv_template(
    Conversation(
        name="oasst_pythia",
        system="",
        roles=("<|prompter|>", "<|assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep="<|endoftext|>",
    )
)

# StableLM Alpha default template
register_conv_template(
    Conversation(
        name="stablelm",
        system="""<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""",
        roles=("<|USER|>", "<|ASSISTANT|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep="",
        stop_token_ids=[50278, 50279, 50277, 1, 0],
    )
)

# Baize default template
register_conv_template(
    Conversation(
        name="baize",
        system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.",
        roles=("[|Human|]", "[|AI|]"),
        messages=(
            ("[|Human|]", "Hello!"),
            ("[|AI|]", "Hi!"),
        ),
        offset=2,
        sep_style=SeparatorStyle.BAIZE,
        sep="[|Human|]",
        stop_str="[|Human|]",
    )
)

# RWKV-4-Raven default template
register_conv_template(
    Conversation(
        name="rwkv",
        system="The following is a coherent verbose detailed conversation between Bob and Alice.\n\n",
        roles=("Bob", "Alice"),
        messages=(
            ("Bob", "Hi"),
            (
                "Alice",
                "Hi. I am your assistant and I will answer all questions. Please feel free to ask any question and I will always answer it.",
            ),
        ),
        offset=2,
        sep_style=SeparatorStyle.RWKV,
        sep="",
        stop_str="\n\n",
    )
)

# Buddy default template
register_conv_template(
    Conversation(
        name="openbuddy",
        system="""Consider a conversation between User (a human) and Assistant (named Buddy).
Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
Buddy cannot access the Internet.
Buddy can fluently speak the user's language (e.g. English, Chinese).
Buddy can generate poems, stories, code, essays, songs, parodies, and more.
Buddy possesses vast knowledge about the world, history, and culture.
Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.

User: Hi.
Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""",
        roles=("User", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n",
    )
)

# Phoenix default template
register_conv_template(
    Conversation(
        name="phoenix",
        system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.PHOENIX,
        sep="</s>",
    )
)

# ChatGPT default template
register_conv_template(
    Conversation(
        name="chatgpt",
        system="You are a helpful assistant.",
        roles=("user", "assistant"),
        messages=(),
        offset=0,
        sep_style=None,
        sep=None,
    )
)

# Claude default template
register_conv_template(
    Conversation(
        name="claude",
        system="",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n\n",
    )
)

# MPT default template
register_conv_template(
    Conversation(
        name="mpt",
        system="""<|im_start|>system
- You are a helpful assistant chatbot trained by MosaicML.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.
""",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NEW_LINE,
        sep="<|im_end|>",
        stop_token_ids=[50278, 0],
    )
)

# Bard default template
# Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150
# https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40
register_conv_template(
    Conversation(
        name="bard",
        system="",
        roles=("0", "1"),
        messages=(),
        offset=0,
        sep_style=None,
        sep=None,
    )
)

# BiLLa default template
register_conv_template(
    Conversation(
        name="billa",
        system="",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.BILLA,
        sep="\n",
        stop_str="Human:",
    )
)

# custom otter template
register_conv_template(
    Conversation(
        name='otter',
        system='',
        roles=('User:', 'GPT:<answer>'),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_SPACE_TWO,
        sep=' ',
        sep2='<|endofchunk|>',
    )
)

if __name__ == "__main__":
    conv = get_conv_template("vicuna_v1.1")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi!")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())
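Note: for reference, with the vicuna_v1.1 separators defined above (ADD_COLON_TWO, sep=" ", sep2="</s>"), the __main__ demo should print a single line of roughly this shape:

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hello! ASSISTANT: Hi!</s>USER: How are you? ASSISTANT: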
mllm/dataset/__init__.py ADDED
@@ -0,0 +1,7 @@
from .root import *
from .utils import *
from .process_function import *
from .single_image_convsation import *
from .single_image_dataset import *

from .builder import prepare_data
mllm/dataset/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (322 Bytes).
 
mllm/dataset/__pycache__/builder.cpython-310.pyc ADDED
Binary file (2.96 kB).
 
mllm/dataset/__pycache__/root.cpython-310.pyc ADDED
Binary file (2.42 kB).
 
mllm/dataset/__pycache__/single_image_convsation.cpython-310.pyc ADDED
Binary file (11 kB).
 
mllm/dataset/__pycache__/single_image_interactive.cpython-310.pyc ADDED
Binary file (4.15 kB).
 
mllm/dataset/builder.py ADDED
@@ -0,0 +1,118 @@
from functools import partial
from typing import Callable, Dict, Tuple, Any, Optional

from torch.utils.data import Dataset
from transformers import EvalPrediction, TrainingArguments

from .root import DATASETS, METRICS, TRANSFORMS, FUNCTIONS
from .single_image_convsation import SingleImageConvDataset
from .single_image_interactive import SingleImageInteractive
from ..conversation import get_conv_template
from .utils import init_ceph_client_if_needed

DatasetDict = Dict[str, Dataset]
ComputeMetrics = Callable[[EvalPrediction], Dict]


def prepare_data(
        data_args,
        model_args,
        training_args: TrainingArguments,
        preprocessor: Dict[str, Any],
) -> Tuple[DatasetDict, Optional[ComputeMetrics]]:
    # raw dataset
    datasets = {
        'train': partial(DATASETS.build, data_args.train) if training_args.do_train else None,
        'validation': partial(DATASETS.build, data_args.validation) if training_args.do_eval else None,
        'test': partial(DATASETS.build, data_args.test) if training_args.do_predict else None,
    }
    # compute metric
    compute_metric_cfg = data_args.get('compute_metric', None)
    compute_metrics = build_compute_metric(compute_metric_cfg, preprocessor)
    # conv dataset wrap
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    conv_template = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=conv_template)
    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)
    # process func
    process_func = {}
    for k, v in model_args.process_func_args.items():
        process_func[k] = FUNCTIONS.build(cfg=v)

    conv_dataset_cls = partial(
        SingleImageConvDataset,
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=training_args,
        transforms=transforms,
    )
    ds = {
        'train': conv_dataset_cls(dataset_generator=datasets['train'], mode='train') if datasets['train'] is not None else None,
        'validation': conv_dataset_cls(dataset_generator=datasets['validation'], mode='validation') if datasets['validation'] is not None else None,
        'test': conv_dataset_cls(dataset_generator=datasets['test'], mode='test') if datasets['test'] is not None else None,
    }

    # multi test set
    if hasattr(data_args, 'multitest') and bool(data_args.multitest) \
            and hasattr(training_args, 'do_multi_predict') and training_args.do_multi_predict:
        print(f"processing multitest set")
        k2v = {}
        for k, item in data_args.multitest.items():
            _dataset_cls = partial(DATASETS.build, item['cfg'])
            _compute_metric = build_compute_metric(item['compute_metric'], preprocessor)
            k2v[k] = {
                "dataset": conv_dataset_cls(dataset_generator=_dataset_cls, mode='test'),
                "compute_metric": _compute_metric
            }
        ds['multitest'] = k2v
        print(f"processing multitest set. done.")

    # by default, the ceph client is initialized at the beginning of the program,
    # importantly, before the dataloader workers fork.
    lazy_init = data_args.get('lazy_init', True)
    if not lazy_init:
        init_ceph_client_if_needed()
    return ds, compute_metrics


def build_compute_metric(compute_metric_cfg, preprocessor):
    if compute_metric_cfg is not None:
        compute_metric_cfg = dict(compute_metric_cfg)  # copy cfg because we modify it
        compute_metric_cfg.update(dict(preprocessor=preprocessor))
        compute_metrics = METRICS.build(cfg=compute_metric_cfg)
    else:
        compute_metrics = None
    return compute_metrics


def prepare_interactive(
        model_args,
        preprocessor: Dict[str, Any],
):
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    conv_template = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=conv_template)
    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)
    # process func
    process_func = {}
    for k, v in model_args.process_func_args.items():
        process_func[k] = FUNCTIONS.build(cfg=v)

    ds = SingleImageInteractive(
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=None,
        transforms=transforms,
        mode='test',
    )
    return ds
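Note: as a rough sketch of how `prepare_data` is wired into a training script (the attribute names mirror what the function reads above; the surrounding glue and the concrete config contents are hypothetical, not part of this commit):

# hypothetical glue code, assuming cfg/training_args come from mllm.config.prepare_args
from mllm.dataset.builder import prepare_data

cfg, training_args = ...   # e.g. returned by prepare_args()
preprocessor = ...         # dict with 'image', 'text', 'conv', 'target' entries, built at model load time
ds, compute_metrics = prepare_data(cfg.data_args, cfg.model_args, training_args, preprocessor)
train_ds = ds['train']             # SingleImageConvDataset in 'train' mode (None if do_train is off)
multi = ds.get('multitest', {})    # {name: {'dataset': ..., 'compute_metric': ...}} when do_multi_predict is set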
mllm/dataset/process_function/__init__.py ADDED
@@ -0,0 +1,13 @@
from .shikra_process_function import (
    ShikraConvProcess,
    ShikraImageProcessor,
    ShikraTextProcess,
)

from .box_process_function import (
    BoxFormatProcess,
    BoxFormatter,
    PlainBoxFormatter,
    TokenFormatter,
    prepare_target_processor,
)
mllm/dataset/process_function/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (458 Bytes).
 
mllm/dataset/process_function/__pycache__/box_process_function.cpython-310.pyc ADDED
Binary file (10.7 kB).
 
mllm/dataset/process_function/__pycache__/shikra_process_function.cpython-310.pyc ADDED
Binary file (6.02 kB).
 
mllm/dataset/process_function/box_process_function.py ADDED
@@ -0,0 +1,326 @@
import re
import sys
import logging
import typing
from typing import List, Dict, Any, Tuple, Union

from ..utils.transform import norm_box_xyxy, norm_point_xyxy

from ..root import (
    FUNCTIONS,
    BaseTargetProcessFunc,
    BOXES_PLACEHOLDER,
    BOXES_PROCESSOR,
    POINTS_PLACEHOLDER,
)

from ...utils import smart_tokenizer_and_embedding_resize

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout), ],
)

Box = List[Union[float, int]]
Boxes = List[Box]
BoxesSeq = List[Boxes]


@FUNCTIONS.register_module()
class BoxFormatProcess(BaseTargetProcessFunc):
    def __call__(self, raw_conv: List[Dict[str, Any]], target: Dict[str, Any], preprocessor: Dict[str, Any],
                 multimage_mode=False) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        box_formatter = preprocessor['target']['boxes']

        if multimage_mode:
            target = typing.cast(list, target)
            outer_normalized_boxes = []
            for tgt in target:
                normalized_boxes = []
                if tgt is not None and 'boxes' in tgt:
                    for box in tgt['boxes']:
                        normalized_boxes.append(
                            norm_box_xyxy(box, w=tgt['width'], h=tgt['height'])
                        )
                outer_normalized_boxes.append(normalized_boxes)
            normalized_boxes = outer_normalized_boxes
            outer_normalized_points = []
            for tgt in target:
                normalized_points = []
                if tgt is not None and 'boxes' in tgt:
                    for box in tgt['boxes']:
                        normalized_points.append(
                            norm_box_xyxy(box, w=tgt['width'], h=tgt['height'])
                        )
                outer_normalized_points.append(normalized_points)
            normalized_points = outer_normalized_points
        else:
            # normalize target
            normalized_boxes = []
            if target is not None and 'boxes' in target:
                for box in target['boxes']:
                    normalized_boxes.append(
                        norm_box_xyxy(box, w=target['width'], h=target['height'])
                    )
            normalized_points = []
            if target is not None and 'points' in target:
                for point in target['points']:
                    normalized_points.append(
                        norm_point_xyxy(point, w=target['width'], h=target['height'])
                    )

        # convert bboxes_seq
        for sentence in raw_conv:
            words: str = sentence['value']
            boxes_seq: List[List[int]] = sentence.get('boxes_seq', None)
            if boxes_seq is not None:
                # map box seq
                boxes_seq: List[Boxes] = map_obj(normalized_boxes, boxes_seq)
                # reformat; replace <boxes> placeholder
                converted = box_formatter(words, boxes_seq)
                words = converted
            points_seq: List[List[int]] = sentence.get('points_seq', None)
            if points_seq is not None:
                # map point seq
                points_seq: List[Boxes] = map_obj(normalized_points, points_seq)
                # reformat; replace <points> placeholder
                converted = box_formatter.call_on_point(words, points_seq)
                words = converted
            if boxes_seq is not None or points_seq is not None:
                sentence['raw_value'] = sentence['value']
                sentence['value'] = words
        return raw_conv, target


def map_obj(boxes_value: List[List[float]], boxes_seq: List[List[int]]) -> List[List[List[float]]]:
    """
    >>> normalized_boxes = [[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2], [0.3, 0.3, 0.3, 0.3]]
    >>> boxes_seq_ = [[3, 1], [2]]
    >>> var = map_obj(normalized_boxes, boxes_seq_)
    >>> assert var == [[[0.3,0.3,0.3,0.3], [0.1,0.1,0.1,0.1]], [0.2,0.2,0.2,0.2]]
    """
    try:
        ret = []
        for boxes in boxes_seq:
            boxes_ret = []
            for box_index in boxes:
                if isinstance(box_index, (list, tuple)):
                    boxes_ret.append(boxes_value[box_index[0]][box_index[1]])
                else:
                    boxes_ret.append(boxes_value[box_index])
            ret.append(boxes_ret)
        return ret
    except:
        raise SystemExit(f"error: map obj {boxes_value} {boxes_seq}")


class BoxFormatter:
    def __init__(self, bboxes_token=BOXES_PLACEHOLDER, points_token=POINTS_PLACEHOLDER):
        self.bboxes_token = bboxes_token
        self.points_token = points_token
        # normally bboxes_token_pat is the same as bboxes_token, unless some unusual token is used
        self.bboxes_token_pat = re.compile(bboxes_token)
        self.points_token_pat = re.compile(points_token)

    def __call__(self, sentence: str, bboxes_seq: BoxesSeq) -> str:
        all_box = self.bboxes_token_pat.findall(sentence)
        assert len(all_box) == len(bboxes_seq), f"not match. sentence: {sentence}. boxes:{bboxes_seq}"
        if len(all_box) == 0:
            return sentence
        bboxes_strs = [self.format_box(bboxes) for bboxes in bboxes_seq]
        converted = sentence.replace(self.bboxes_token, '{}').format(*bboxes_strs)
        return converted

    def call_on_point(self, sentence: str, points_seq: BoxesSeq) -> str:
        all_box = self.points_token_pat.findall(sentence)
        assert len(all_box) == len(points_seq), f"not match. sentence: {sentence}. boxes:{points_seq}"
        if len(all_box) == 0:
            return sentence
        bboxes_strs = [self.format_point(bboxes) for bboxes in points_seq]
        converted = sentence.replace(self.points_token, '{}').format(*bboxes_strs)
        return converted

    def format_point(self, points) -> str:
        raise NotImplementedError

    def format_box(self, bboxes: Boxes) -> str:
        raise NotImplementedError

    def extract(self, string: str) -> List[Boxes]:
        raise NotImplementedError

    def extract_point(self, string: str) -> List[Boxes]:
        raise NotImplementedError


@BOXES_PROCESSOR.register_module()
class PlainBoxFormatter(BoxFormatter):

    def __init__(self, *args, precision=3, use_small_brackets=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.precision = precision
        self.use_small_brackets = use_small_brackets

        small_brackets_pat = re.compile(r'\(\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\)')
        small_brackets_point_pat = re.compile(r'\(\d(?:\.\d*)?(?:,\d(?:\.\d*)?)(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?))*\)')

        middle_brackets_pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\]')
        middle_brackets_point_pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?)(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?))*\]')

        self.pat = small_brackets_pat if use_small_brackets else middle_brackets_pat
        self.point_pat = small_brackets_point_pat if use_small_brackets else middle_brackets_point_pat

    def format_box(self, boxes: Boxes) -> str:
        box_strs = []
        for box in boxes:
            box_strs.append(','.join([f"{elem:.{self.precision}f}" for elem in box]))
        box_str = ';'.join(box_strs)
        if self.use_small_brackets:
            return "(" + box_str + ")"
        return "[" + box_str + "]"

    def format_point(self, points) -> str:
        return self.format_box(points)

    def extract(self, string: str) -> List[Boxes]:
        """ balabala<boxes>balabala<boxes> -> [boxes, boxes] """
        ret = []
        for bboxes_str in self.pat.findall(string):
            bboxes = []
            bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";")
            for bbox_str in bbox_strs:
                bbox = list(map(float, bbox_str.split(',')))
                bboxes.append(bbox)
            ret.append(bboxes)
        return ret

    def extract_point(self, string: str) -> List[Boxes]:
        """ balabala<boxes>balabala<boxes> -> [boxes, boxes] """
        ret = []
        for bboxes_str in self.point_pat.findall(string):
            bboxes = []
            bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";")
            for bbox_str in bbox_strs:
                bbox = list(map(float, bbox_str.split(',')))
                bboxes.append(bbox)
            ret.append(bboxes)
        return ret


@BOXES_PROCESSOR.register_module()
class TokenFormatter(BoxFormatter):

    def __init__(self, num_bins=1001):
        super().__init__()
        self.extract_box_pat = re.compile(r'<b_st><bin_\d*?>(?:<bin_\d*?>){3}(?:<b_sep><bin_\d*?>(?:<bin_\d*?>){3})*<b_ed>')
        self.extract_point_pat = re.compile(r'<p_st><bin_\d*?>(?:<bin_\d*?>){1}(?:<p_sep><bin_\d*?>(?:<bin_\d*?>){1})*<p_ed>')
        self.num_bins = num_bins
        self.use_sep = True
        self.use_begin_end = True

        self.box_begin = '<b_st>'
        self.box_sep = '<b_sep>'
        self.box_end = '<b_ed>'

        self.point_begin = '<p_st>'
        self.point_sep = '<p_sep>'
        self.point_end = '<p_ed>'

    def format_point(self, points) -> str:
        final_str = []
        for bbox in points:
            quant_x0 = "<bin_{}>".format(round((bbox[0] * (self.num_bins - 1))))
            quant_y0 = "<bin_{}>".format(round((bbox[1] * (self.num_bins - 1))))
            region_coord = "{} {}".format(quant_x0, quant_y0)
            final_str.append(region_coord)
        if self.use_sep:
            final_str = self.point_sep.join(final_str)
        else:
            final_str = ''.join(final_str)
        if self.use_begin_end:
            final_str = self.point_begin + final_str + self.point_end
        return final_str

    def format_box(self, bboxes: Boxes) -> str:
        final_str = []
        for bbox in bboxes:
            quant_x0 = "<bin_{}>".format(round((bbox[0] * (self.num_bins - 1))))
            quant_y0 = "<bin_{}>".format(round((bbox[1] * (self.num_bins - 1))))
            quant_x1 = "<bin_{}>".format(round((bbox[2] * (self.num_bins - 1))))
            quant_y1 = "<bin_{}>".format(round((bbox[3] * (self.num_bins - 1))))
            region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
            final_str.append(region_coord)
        if self.use_sep:
            final_str = self.box_sep.join(final_str)
        else:
            final_str = ''.join(final_str)
        if self.use_begin_end:
            final_str = self.box_begin + final_str + self.box_end
        return final_str

    def extract(self, string: str) -> List[Boxes]:
        ret = []
        for bboxes_str in self.extract_box_pat.findall(string.replace(" ", "")):
            bboxes = []
            bbox_strs = bboxes_str.replace(self.box_begin, "").replace(self.box_end, "").split(self.box_sep)
            for bbox_str in bbox_strs:
                elems = list(map(int, re.findall(r'<bin_(\d*?)>', bbox_str)))
                bbox = [elem / (self.num_bins - 1) for elem in elems]
                bboxes.append(bbox)
            ret.append(bboxes)
        return ret

    def extract_point(self, string: str) -> List[Boxes]:
        ret = []
        for bboxes_str in self.extract_point_pat.findall(string):
            bboxes = []
            bbox_strs = bboxes_str.replace(self.point_begin, "").replace(self.point_end, "").split(self.point_sep)
            for bbox_str in bbox_strs:
                elems = list(map(int, re.findall(r'<bin_(\d*?)>', bbox_str)))
                bbox = [elem / (self.num_bins - 1) for elem in elems]
                bboxes.append(bbox)
            ret.append(bboxes)
        return ret

    def post_process_model_tokenizer(self, model, preprocessor, model_args, training_args):
        tokenizer = preprocessor['text']

        additional_special_tokens = [
            self.box_begin, self.box_sep, self.box_end,
            self.point_begin, self.point_sep, self.point_end,
        ]
        for i in range(self.num_bins):
            additional_special_tokens.append(f'<bin_{i}>')

        smart_tokenizer_and_embedding_resize(
            {'additional_special_tokens': additional_special_tokens},
            tokenizer,
            model,
        )
        return model, preprocessor


# FIXME: merge into load_pretrained
def prepare_target_processor(
        model,  # multimodal llm
        preprocessor: Dict[str, Any],
        model_args,
        training_args,
):
    if not hasattr(model_args, 'target_processor'):
        return model, preprocessor

    target_processor = {}
    if 'boxes' in model_args['target_processor']:
        boxes_cfg = model_args['target_processor']['boxes']
        boxes_processor = BOXES_PROCESSOR.build(boxes_cfg)
        target_processor['boxes'] = boxes_processor
        if hasattr(boxes_processor, "post_process_model_tokenizer"):
            model, preprocessor = boxes_processor.post_process_model_tokenizer(
                model, preprocessor, model_args, training_args,
            )
    preprocessor['target'] = target_processor
    return model, preprocessor
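Note: to make the formatting contract above concrete, here is a small worked example of `PlainBoxFormatter` with the default `precision=3` and square brackets; the sentence and coordinates are illustrative only:

formatter = PlainBoxFormatter()
sentence = 'A dog <boxes> chasing a ball <boxes>.'
boxes_seq = [[[0.1, 0.2, 0.5, 0.6]], [[0.7, 0.7, 0.9, 0.9]]]   # one normalized box per <boxes> placeholder
print(formatter(sentence, boxes_seq))
# -> A dog [0.100,0.200,0.500,0.600] chasing a ball [0.700,0.700,0.900,0.900].
print(formatter.extract('A dog [0.100,0.200,0.500,0.600].'))
# -> [[[0.1, 0.2, 0.5, 0.6]]]   (extract() round-trips the formatted string back to float boxes)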
mllm/dataset/process_function/shikra_process_function.py ADDED
@@ -0,0 +1,178 @@
import sys
import copy
import warnings
import logging
from typing import Dict, Any, List

import PIL.Image
import torch
from PIL import Image
from transformers import LlamaTokenizer

from ..root import (
    FUNCTIONS,
    IMAGE_PLACEHOLDER,
    BaseImageProcessFunc,
    BaseConvProcessFunc,
    BaseTextProcessFunc,
)
from ...conversation import SeparatorStyle, Conversation

IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = IMAGE_PLACEHOLDER
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout), ],
)


@FUNCTIONS.register_module()
class ShikraConvProcess(BaseConvProcessFunc):
    def __call__(self, raw_conv: List[Dict[str, Any]], preprocessor: Dict[str, Any], conv_template: Conversation) -> List[Dict[str, Any]]:
        conv_processor_cfg = preprocessor['conv']

        image_token_len = conv_processor_cfg['image_token_len']
        sep_image_conv_front = conv_processor_cfg.get('sep_image_conv_front', False)
        use_im_start_end = conv_processor_cfg.get('use_im_start_end', False)
        # assert DEFAULT_IMAGE_PATCH_TOKEN in preprocessor['text'].get_vocab()
        # if use_im_start_end:
        #     assert DEFAULT_IM_START_TOKEN in preprocessor['text'].get_vocab()
        #     assert DEFAULT_IM_END_TOKEN in preprocessor['text'].get_vocab()

        if sep_image_conv_front:
            raw_conv[0]['value'] = raw_conv[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
            raw_conv[0]['value'] = DEFAULT_IMAGE_TOKEN + conv_template.sep + conv_template.roles[0] + ": " + raw_conv[0]['value']
        for sentence in raw_conv:
            replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
            if use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

        return raw_conv


@FUNCTIONS.register_module()
class ShikraTextProcess(BaseTextProcessFunc):

    def __call__(self, conv: Conversation, preprocessor: Dict[str, Any], mode: str, **tokenize_kwargs) -> Dict[str, Any]:
        tokenizer = preprocessor['text']
        assert isinstance(tokenizer, LlamaTokenizer), "only work for LlamaTokenizer"

        _truncation_size = tokenize_kwargs.pop('truncation_size', None)
        _kwargs = {'return_tensors': 'pt'}
        _kwargs.update(tokenize_kwargs)

        if conv.sep_style == SeparatorStyle.ADD_COLON_TWO:
            if mode in ['train']:
                ret = self.tk_conv_colon_two_train(conv, tokenizer, **_kwargs)
            else:
                ret = self.tk_conv_colon_two_eval(conv, tokenizer, **_kwargs)
        else:
            raise ValueError(f"unrecognized conv_style: {conv.sep_style}.\n the conv is {conv}")

        if _truncation_size is None:
            return ret
        if len(ret['input_ids']) <= _truncation_size:
            return ret

        origin_len = len(ret['input_ids'])
        ids_to_remove_num = origin_len - _truncation_size
        # truncation: take care not to truncate <img_token>
        ids_should_not_remove = list(map(
            tokenizer.convert_tokens_to_ids,
            (DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN)
        ))
        back_no_image = all(ids not in ids_should_not_remove for ids in ret['input_ids'][_truncation_size:])
        if back_no_image:
            tgt_ids = list(range(_truncation_size))
        else:
            ids_to_remove = set()
            for idx in range(origin_len - 1, -1, -1):
                if ret['input_ids'][idx] not in ids_should_not_remove:
                    ids_to_remove.add(idx)
                    if len(ids_to_remove) >= ids_to_remove_num:
                        break
            tgt_ids = [_ for _ in range(origin_len) if _ not in ids_to_remove]
        logger.warning(f"truncate sample size from {origin_len} to {len(tgt_ids)}.")
        assert len(tgt_ids) == _truncation_size, f"{len(tgt_ids)}, {_truncation_size}, {ret['input_ids'].tolist()}"
        truncated_ret = {k: v[tgt_ids] for k, v in ret.items()}
        return truncated_ret

    # noinspection PyMethodMayBeStatic
    def tk_conv_colon_two_train(self, conv, tokenizer, **kwargs):
        conversation = conv.get_prompt()
        input_ids = tokenizer([conversation, ], **kwargs).input_ids[0]
        target = copy.deepcopy(input_ids)
        assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
        # Mask targets
        sep = conv.sep + conv.roles[1] + ": "
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = len(tokenizer(rou).input_ids)
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2  # <s> <space>
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                warnings.warn(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored):\n{conversation}")
        return dict(
            input_ids=input_ids,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
            labels=target,
        )

    # noinspection PyMethodMayBeStatic
    def tk_conv_colon_two_eval(self, conv, tokenizer, **kwargs):
        assert len(conv.messages) >= 2
        # target = conv.messages[-1][-1]
        target = conv.get_prompt()

        conv.messages[-1][-1] = ""
        conversation = conv.get_prompt()
        input_ids = tokenizer([conversation, ], **kwargs).input_ids[0]

        target = tokenizer([target, ], add_special_tokens=False, **kwargs).input_ids[0]
        target[target == tokenizer.pad_token_id] = IGNORE_INDEX
        return dict(
            input_ids=input_ids,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
            labels=target,
        )


@FUNCTIONS.register_module()
class ShikraImageProcessor(BaseImageProcessFunc):
    def __call__(self, image: Image.Image, preprocessor: Dict[str, Any]) -> Dict[str, Any]:
        image_processor = preprocessor['image']

        if isinstance(image, (list, tuple)):
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
            assert False, 'Shikra not support MultiImage'
        elif isinstance(image, PIL.Image.Image):
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        else:
            if hasattr(image_processor, 'crop_size'):
                crop_size = image_processor.crop_size
                height, width = crop_size['height'], crop_size['width']
            else:
                raise ValueError("got empty image. and don't know how to pad")
            image = torch.zeros(3, height, width)
        return {'image': image}
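Note: for intuition, `ShikraConvProcess` above expands the `<image>` placeholder into a run of patch tokens. With hypothetical values for illustration (the real `image_token_len` and `use_im_start_end` come from `preprocessor['conv']`):

image_token_len, use_im_start_end = 3, True
replace_token = '<im_patch>' * image_token_len
if use_im_start_end:
    replace_token = '<im_start>' + replace_token + '<im_end>'
print('<image> What is in the picture?'.replace('<image>', replace_token))
# -> <im_start><im_patch><im_patch><im_patch><im_end> What is in the picture?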
mllm/dataset/root.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List, Tuple
2
+
3
+ from PIL import Image
4
+ from mmengine import DATASETS, TRANSFORMS, METRICS, FUNCTIONS, Registry
5
+
6
+ from ..conversation import Conversation
7
+
8
+ IMAGE_PLACEHOLDER = '<image>'
9
+ BOXES_PLACEHOLDER = '<boxes>'
10
+ EXPR_PLACEHOLDER = '<expr>'
11
+ OBJS_PLACEHOLDER = '<objs>'
12
+ QUESTION_PLACEHOLDER = '<question>'
13
+ POINTS_PLACEHOLDER = '<points>'
14
+ # processor
15
+ BOXES_PROCESSOR = Registry('Processor for Boxes')
16
+
17
+
18
+ # only for static type checking
19
+ class BaseConvProcessFunc:
20
+ def __call__(
21
+ self,
22
+ raw_conv: List[Dict[str, Any]],
23
+ preprocessor: Dict[str, Any],
24
+ conv_template: Conversation,
25
+ ) -> List[Dict[str, Any]]:
26
+ raise NotImplementedError
27
+
28
+
29
+ class BaseTargetProcessFunc:
30
+ def __call__(
31
+ self,
32
+ raw_conv: List[Dict[str, Any]],
33
+ target: Dict[str, Any],
34
+ preprocessor: Dict[str, Any],
35
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
36
+ raise NotImplementedError
37
+
38
+
39
+ class BaseTextProcessFunc:
40
+ def __call__(
41
+ self,
42
+ conv: Conversation,
43
+ preprocessor: Dict[str, Any],
44
+ mode: str,
45
+ **tokenize_kwargs,
46
+ ) -> Dict[str, Any]:
47
+ raise NotImplementedError
48
+
49
+
50
+ class BaseImageProcessFunc:
51
+ def __call__(
52
+ self,
53
+ image: Image.Image,
54
+ preprocessor: Dict[str, Any],
55
+ ) -> Dict[str, Any]:
56
+ raise NotImplementedError
57
+
58
+
59
+ __all__ = [
60
+ 'IMAGE_PLACEHOLDER', 'BOXES_PLACEHOLDER', 'EXPR_PLACEHOLDER', 'OBJS_PLACEHOLDER', 'QUESTION_PLACEHOLDER', 'POINTS_PLACEHOLDER',
61
+ 'FUNCTIONS',
62
+ 'DATASETS',
63
+ 'TRANSFORMS',
64
+ 'METRICS',
65
+ 'BOXES_PROCESSOR',
66
+ 'BaseConvProcessFunc', 'BaseTargetProcessFunc', 'BaseTextProcessFunc', 'BaseImageProcessFunc',
67
+ ]
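As a quick illustration of how these mmengine-style registries are typically used: classes register themselves with a decorator and are later built from a config dict keyed by 'type'. ToyBoxFormatter, its precision argument, and the toy registry name are made up for illustration and are not part of this repo.

from mmengine import Registry

TOY_PROCESSOR = Registry('Processor for Boxes (toy copy)')

@TOY_PROCESSOR.register_module()
class ToyBoxFormatter:
    def __init__(self, precision: int = 2):
        self.precision = precision

    def __call__(self, box):
        return [round(float(v), self.precision) for v in box]

# configs refer to registered classes by their 'type' name
formatter = TOY_PROCESSOR.build(dict(type='ToyBoxFormatter', precision=3))
print(formatter([0.12345, 0.5, 0.98765, 1.0]))  # -> [0.123, 0.5, 0.988, 1.0]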
mllm/dataset/single_image_convsation.py ADDED
@@ -0,0 +1,284 @@
1
+ import warnings
2
+ from functools import partial
3
+ from typing import Dict, Any, Callable, List, Optional, Tuple, Type
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset
8
+ from transformers import TrainingArguments
9
+
10
+ from .root import IMAGE_PLACEHOLDER, BOXES_PLACEHOLDER
11
+ from ..conversation import Conversation, get_conv_template
12
+ from ..utils import post_process_generate_ids
13
+
14
+
15
+ class SingleImageConvDatasetMixin:
16
+
17
+ def __init__(
18
+ self,
19
+ *args,
20
+ preprocessor: Dict[str, Any],
21
+ process_func: Dict[str, Any],
22
+ conv_template: Callable[[], Conversation] = partial(get_conv_template, name='vicuna_v1.1'),
23
+ mode='train',
24
+ tokenize_kwargs: dict = None,
25
+ training_args: TrainingArguments = None,
26
+ transforms: Optional[Callable] = None,
27
+ **kwargs,
28
+ ):
29
+ super().__init__(*args, **kwargs)
30
+ assert mode in ['train', 'validation', 'test']
31
+
32
+ self.preprocessor = preprocessor
33
+ self.process_func = process_func
34
+ self.conv_template = conv_template
35
+ self.mode = mode
36
+ self.tokenize_kwargs = tokenize_kwargs if tokenize_kwargs is not None else {}
37
+ self.training_args = training_args
38
+ self.transforms = transforms
39
+
40
+ def __getitem__(self, index, debug_mode=False, return_conv=False) -> Dict[str, Any]:
41
+ # getitem
42
+ item = self.get_raw_item(index)
43
+ image: Image.Image = item.get('image', None)
44
+ target: Dict[str, Any] = item.get('target', None)
45
+ raw_conv: List[Dict[str, Any]] = item['conversations']
46
+
47
+ # transform
48
+ assert isinstance(image, list) == isinstance(target, list)
49
+ multimage_mode = isinstance(image, list)
50
+ if isinstance(image, list):
51
+ # TODO: validate raw item
52
+ transformed_image, transformed_target = [], []
53
+ for img, tgt in zip(image, target):
54
+ if self.transforms is not None and image is not None:
55
+ img, tgt = self.transforms(img, tgt)
56
+ if tgt is not None:
57
+ tgt['width'], tgt['height'] = img.width, img.height
58
+ transformed_image.append(img)
59
+ transformed_target.append(tgt)
60
+ image, target = transformed_image, transformed_target
61
+ else:
62
+ self.validate_raw_item(item) # only validate for single image.
63
+ if self.transforms is not None and image is not None:
64
+ image, target = self.transforms(image, target)
65
+ has_image = 'image' in item and bool(item['image'])
66
+ has_target = 'target' in item and bool(item['target']) and any(bool(elem) for elem in item['target'].values())
67
+ if has_target and has_image:
68
+ target['width'], target['height'] = image.width, image.height
69
+
70
+ # preprocess
71
+ raw_conv = self.process_conv(raw_conv)
72
+ raw_conv, image = self.process_conv_multimage(raw_conv, image)
73
+ raw_conv, _ = self.process_target(raw_conv, target, multimage_mode=multimage_mode)
74
+ conv = self.build_conv(raw_conv)
75
+ if return_conv:
76
+ # noinspection PyTypeChecker
77
+ return conv
78
+ text_dict = self.process_text(conv)
79
+ image_dict = self.process_image(image)
80
+
81
+ # return
82
+ ret_dict = {}
83
+ ret_dict.update(text_dict)
84
+ ret_dict.update(image_dict)
85
+ self._print_sample(ret_dict, raw_conv, conv)
86
+ if debug_mode:
87
+ return {'ret': ret_dict, 'raw_conv': raw_conv, 'conv': conv, 'image': image}
88
+ return ret_dict
89
+
90
+ def __len__(self):
91
+ raise NotImplementedError
92
+
93
+ # noinspection PyMethodMayBeStatic
94
+ def process_conv_multimage(self, raw_conv, image):
95
+ # re-sort multi image
96
+ if image is None:
97
+ return raw_conv, image
98
+ if not isinstance(image, (list, tuple)):
99
+ return raw_conv, image
100
+ image_seqs = []
101
+ for conv in raw_conv:
102
+ image_seqs.extend(conv['image_seq'] if 'image_seq' in conv else [])
103
+ images = []
104
+ for idx in image_seqs:
105
+ images.append(image[idx])
106
+ return raw_conv, images
107
+
108
+ def get_raw_item(self, index) -> Dict[str, Any]:
109
+ """
110
+ return item format like this.
111
+ item = {
112
+ 'image': # PIL.Image.Image,
113
+ 'target': {
114
+ # xmin, ymin, xmax, ymax
115
+ 'boxes': [
116
+ [10, 10, 256, 265], # dog1
117
+ [24, 18, 378, 768], # dog2
118
+ [100, 310, 670, 653], # man
119
+ [278, 320, 809, 673], # rope
120
+ ],
121
+ }
122
+
123
+ "conversations": [
124
+ {
125
+ 'from': 'human',
126
+ 'value': 'What is the relation between the two dogs <boxes> and the man <boxes> in the image <image> ?',
127
+ 'boxes_seq': [[0, 1], [2], ],
128
+ },
129
+ {
130
+ 'from': 'gpt',
131
+ 'value': 'a rope <boxes> is connecting the left dog <boxes> with the man <boxes>. '
132
+ 'So the man <boxes> is walking the dog <boxes>. '
133
+ 'And the man <boxes> has no relationship with the right dog <boxes>',
134
+ 'boxes_seq': [[3], [0], [2], [2], [0], [2], [1]],
135
+ }
136
+ ]
137
+ }
138
+ # placeholder: <image> <boxes>
139
+ """
140
+ raise NotImplementedError
141
+
142
+ # noinspection PyMethodMayBeStatic
143
+ def validate_raw_item(self, item):
144
+ has_image = 'image' in item and bool(item['image'])
145
+ has_target = 'target' in item and bool(item['target']) and any(bool(elem) for elem in item['target'].values())
146
+ has_target_boxes = 'boxes' in item['target'] if has_target else False
147
+ raw_conv: List[Dict[str, Any]] = item['conversations']
148
+
149
+ # check image
150
+ human_input_has_image_placeholder = any(
151
+ sentence['from'] == 'human' and IMAGE_PLACEHOLDER in sentence['value'] for sentence in raw_conv
152
+ )
153
+ if human_input_has_image_placeholder:
154
+ assert has_image
155
+ if has_image and (not human_input_has_image_placeholder):
156
+ warnings.warn(f'item has image but the question has no image placeholder.\n{item}')
157
+ gpt_input_has_image_placeholder = any(
158
+ sentence['from'] == 'gpt' and IMAGE_PLACEHOLDER in sentence['value'] for sentence in raw_conv
159
+ )
160
+ assert not gpt_input_has_image_placeholder
161
+
162
+ # check target
163
+ has_boxes_placeholder = any(
164
+ BOXES_PLACEHOLDER in sentence['value'] for sentence in raw_conv
165
+ )
166
+ if has_boxes_placeholder:
167
+ assert has_target_boxes
168
+ # not check box placeholder num this will be checked in format process
169
+
170
+ def build_conv(self, source: List[Dict[str, Any]]) -> Conversation:
171
+ conv = self.conv_template()
172
+ role_map = {"human": conv.roles[0], "gpt": conv.roles[1]}
173
+ assert len(source) > 0
174
+ assert source[0]['from'] == 'human'
175
+ for sentence in source:
176
+ role = role_map[sentence['from']]
177
+ conv.append_message(role, sentence['value'])
178
+ return conv
179
+
180
+ def process_conv(self, raw_conv: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
181
+ """
182
+ some utils preprocess for raw_conv.
183
+ e.g. replace <image> placeholder to sequence <im_start> <im_patch>*256 <im_end>
184
+ """
185
+ return self.process_func['conv'](raw_conv, self.preprocessor, self.conv_template)
186
+
187
+ def process_target(self, raw_conv: List[Dict[str, Any]], target: Dict[str, Any], multimage_mode=False) -> Tuple[
188
+ List[Dict[str, Any]], Dict[str, Any]]:
189
+ """
190
+ convert target placeholder to actual information in raw_conv.
191
+ e.g. normalize bounding boxes; convert bounding boxes format; replace <boxes> placeholder
192
+ """
193
+ return self.process_func['target'](raw_conv, target, self.preprocessor, multimage_mode=multimage_mode)
194
+
195
+ def process_text(self, conv: Conversation) -> Dict[str, Any]:
196
+ """
197
+ convert Conversation object to torch.Tensor, e.g. input_ids, labels, attention_mask, etc.
198
+ self.tokenize_kwargs control something like padding/truncation behavior.
199
+ """
200
+ return self.process_func['text'](conv, self.preprocessor, self.mode, **self.tokenize_kwargs)
201
+
202
+ def process_image(self, image: Image.Image) -> Dict[str, Any]:
203
+ """
204
+ convert Image.Image object to torch.Tensor
205
+ """
206
+ return self.process_func['image'](image, self.preprocessor)
207
+
208
+ def _print_sample(self, ret_dict, raw_conv, conv):
209
+ if not hasattr(self, '_printed_sample'):
210
+ self._printed_sample = True
211
+ post_processed_labels = post_process_generate_ids(self.preprocessor['text'], ret_dict['labels'])
212
+ print(f"=================== {self.mode} sample ===================", flush=True)
213
+ print(f" input_ids: {self.preprocessor['text'].convert_ids_to_tokens(ret_dict['input_ids'])}")
214
+ print(f" labels: {self.preprocessor['text'].convert_ids_to_tokens(post_processed_labels)}")
215
+ print(f"decoded input_ids: {self.preprocessor['text'].decode(ret_dict['input_ids'])}")
216
+ print(f"decoded labels: {self.preprocessor['text'].decode(post_processed_labels)}")
217
+ if 'image' in ret_dict and ret_dict['image'] is not None:
218
+ image = ret_dict['image']
219
+ if isinstance(image, torch.Tensor):
220
+ print(f" image: {image.shape}")
221
+ elif isinstance(image, dict):
222
+ print(f" image: {image.keys()}")
223
+ elif isinstance(image, list) and len(image) > 0:
224
+ print(f" image: {len(image)}, {type(image[0])}")
225
+ else:
226
+ print(f" image: {type(image)}")
227
+ print("====================================================", flush=True)
228
+ try:
229
+ if self.training_args is not None:
230
+ _save_obj = {
231
+ 'ret_dict': ret_dict,
232
+ 'raw_conv': raw_conv,
233
+ 'conv': conv.get_prompt(),
234
+ }
235
+ from pathlib import Path
236
+ output_dir = Path(self.training_args.output_dir)
237
+ output_dir.mkdir(exist_ok=True, parents=True)
238
+ _local_rank = self.training_args.local_rank
239
+ _world_size = self.training_args.world_size
240
+ _file_path = str(output_dir / f'sample_check_{self.mode}_{_local_rank}_{_world_size}.pt')
241
+ print(f'saving some sample to {_file_path} for check.')
242
+ torch.save(_save_obj, _file_path)
243
+ except Exception as e:
244
+ warnings.warn(f'try to save samples but get exception: {e.args}. ignored.')
245
+
246
+
247
+ class SingleImageConvDataset(SingleImageConvDatasetMixin, Dataset):
248
+ _repr_indent = 4
249
+
250
+ def __init__(self, *args, dataset_generator: Type[Dataset], **kwargs):
251
+ super().__init__(*args, **kwargs)
252
+ self.dataset_generator = dataset_generator
253
+ self.dataset = None
254
+
255
+ def initialize_if_needed(self):
256
+ """
257
+ lazy initialize for big in-memory python object due to python 'copy-on-read' behavior
258
+ when num_worker > 0. refer: https://github.com/pytorch/pytorch/issues/13246
259
+ """
260
+ if self.dataset is None:
261
+ # warnings.warn("it's highly recommended that set persistent_workers=True, "
262
+ # "otherwise this initialize code will run in every epoch beginning."
263
+ # "(ignore me if set)")
264
+ self.dataset = self.dataset_generator()
265
+
266
+ def __len__(self):
267
+ self.initialize_if_needed()
268
+ return len(self.dataset)
269
+
270
+ def get_raw_item(self, index) -> Dict[str, Any]:
271
+ self.initialize_if_needed()
272
+ return self.dataset[index]
273
+
274
+ def __repr__(self) -> str:
275
+ head = "Dataset " + self.__class__.__name__
276
+ body = [
277
+ f"Number of datapoints: {self.__len__()}",
278
+ ]
279
+ body += self.dataset.__repr__().splitlines()
280
+ lines = [head] + [" " * self._repr_indent + line for line in body]
281
+ return "\n".join(lines)
282
+
283
+
284
+ __all__ = ['SingleImageConvDatasetMixin', 'SingleImageConvDataset']
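A self-contained sketch of the lazy-initialization idea behind initialize_if_needed above: the wrapper stores only a cheap generator callable, and the heavy in-memory dataset is built on first access inside each dataloader worker. The class below is illustrative, not part of this repo.

from torch.utils.data import Dataset

class LazyListDataset(Dataset):
    def __init__(self, generator):
        self.generator = generator  # cheap to pickle into dataloader workers
        self.data = None            # heavy object, created lazily

    def _initialize_if_needed(self):
        if self.data is None:
            self.data = self.generator()

    def __len__(self):
        self._initialize_if_needed()
        return len(self.data)

    def __getitem__(self, index):
        self._initialize_if_needed()
        return self.data[index]

ds = LazyListDataset(lambda: list(range(1000)))
print(len(ds), ds[42])  # 1000 42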
mllm/dataset/single_image_dataset/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .flickr import FlickrParser, FlickrDataset
2
+ from .rec import RECDataset, RECComputeMetrics
3
+ from .reg import REGDataset, GCDataset
4
+ from .caption import CaptionDataset
5
+ from .instr import InstructDataset
6
+ from .gqa import GQADataset, GQAComputeMetrics
7
+ from .clevr import ClevrDataset
8
+ from .point_qa import Point_QA_local, Point_QA_twice, V7W_POINT, PointQAComputeMetrics
9
+ from .gpt_gen import GPT4Gen
10
+ from .vcr import VCRDataset, VCRPredDataset
11
+ from .vqav2 import VQAv2Dataset
12
+ from .vqaex import VQAEXDataset
13
+ from .pope import POPEVQADataset
mllm/dataset/single_image_dataset/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (909 Bytes). View file
 
mllm/dataset/single_image_dataset/__pycache__/caption.cpython-310.pyc ADDED
Binary file (1.07 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/clevr.cpython-310.pyc ADDED
Binary file (3.89 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/flickr.cpython-310.pyc ADDED
Binary file (2.65 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/gpt_gen.cpython-310.pyc ADDED
Binary file (1.64 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/gqa.cpython-310.pyc ADDED
Binary file (6.73 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/instr.cpython-310.pyc ADDED
Binary file (1.08 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/point_qa.cpython-310.pyc ADDED
Binary file (5.11 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/pope.cpython-310.pyc ADDED
Binary file (1.2 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/rec.cpython-310.pyc ADDED
Binary file (3.73 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/reg.cpython-310.pyc ADDED
Binary file (1.39 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/vcr.cpython-310.pyc ADDED
Binary file (5.39 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/vqaex.cpython-310.pyc ADDED
Binary file (1.48 kB). View file
 
mllm/dataset/single_image_dataset/__pycache__/vqav2.cpython-310.pyc ADDED
Binary file (1.29 kB). View file
 
mllm/dataset/single_image_dataset/caption.py ADDED
@@ -0,0 +1,31 @@
1
+ from ..root import DATASETS, IMAGE_PLACEHOLDER
2
+ from ..utils import MInstrDataset
3
+
4
+
5
+ @DATASETS.register_module()
6
+ class CaptionDataset(MInstrDataset):
7
+ def __init__(self, *args, **kwargs):
8
+ super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER,))
9
+
10
+ def __getitem__(self, index):
11
+ item = self.get_raw_item(index)
12
+ img_path = item['img_path']
13
+ caption = item['caption']
14
+
15
+ image = self.get_image(img_path)
16
+ question = self.get_template()
17
+
18
+ ret = {
19
+ 'image': image,
20
+ 'conversations': [
21
+ {
22
+ 'from': 'human',
23
+ 'value': question,
24
+ },
25
+ {
26
+ 'from': 'gpt',
27
+ 'value': caption,
28
+ }
29
+ ]
30
+ }
31
+ return ret
mllm/dataset/single_image_dataset/clevr.py ADDED
@@ -0,0 +1,116 @@
1
+ import json
2
+
3
+ from ..root import DATASETS, IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER, POINTS_PLACEHOLDER
4
+ from ..utils import MInstrDataset
5
+
6
+
7
+ @DATASETS.register_module()
8
+ class ClevrDataset(MInstrDataset):
9
+ def __init__(self, *args, scene_graph_file, version, **kwargs):
10
+ super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
11
+ self.scene_graph_file = scene_graph_file
12
+ self.version = version
13
+ qtype, atype = version.split('-')
14
+ assert qtype in ['q']
15
+ assert atype in ['a', 's', 'bs']
16
+ self.qtype = qtype
17
+ self.atype = atype
18
+
19
+ if scene_graph_file is None:
20
+ self.scene_graph = None
21
+ else:
22
+ self.scene_graph = [line for line in open(scene_graph_file, 'r', encoding='utf8')]
23
+
24
+ def get_raw_item(self, index):
25
+ question = json.loads(self.data[index])
26
+ if self.scene_graph is None:
27
+ scene = None
28
+ else:
29
+ scene = json.loads(self.scene_graph[question['image_index']])
30
+ return question, scene
31
+
32
+ def __getitem__(self, index):
33
+ question, scene = self.get_raw_item(index)
34
+ img_path = question['image_filename']
35
+ image = self.get_image(img_path)
36
+
37
+ if self.atype == 'a':
38
+ boxes = []
39
+ answer = f"The answer is {question['answer']}."
40
+ answer_boxes_seq = []
41
+ elif self.atype == 's':
42
+ answer, boxes, answer_boxes_seq = clevr_ss_cot(obj=question, scene=scene, add_ref=False)
43
+ answer += f" The answer is {question['answer']}."
44
+ elif self.atype == 'bs':
45
+ answer, boxes, answer_boxes_seq = clevr_ss_cot(obj=question, scene=scene, add_ref=True)
46
+ answer += f" The answer is {question['answer']}."
47
+ else:
48
+ assert False
49
+
50
+ if self.qtype == 'q':
51
+ query_boxes_seq = []
52
+ final_query = self.get_template().replace(QUESTION_PLACEHOLDER, question['question'])
53
+ else:
54
+ assert False
55
+
56
+ ret = {
57
+ 'image': image,
58
+ 'target': {'points': boxes},
59
+ 'conversations': [
60
+ {
61
+ 'from': 'human',
62
+ 'value': final_query,
63
+ 'points_seq': query_boxes_seq,
64
+ },
65
+ {
66
+ 'from': 'gpt',
67
+ 'value': answer,
68
+ 'points_seq': answer_boxes_seq,
69
+ }
70
+ ]
71
+ }
72
+ return ret
73
+
74
+
75
+ def get_boxes_idx(boxes_list, refs):
76
+ def get_idx(boxes_list, box):
77
+ if box in boxes_list:
78
+ return boxes_list.index(box)
79
+ else:
80
+ boxes_list.append(box)
81
+ return len(boxes_list) - 1
82
+
83
+ idx = [get_idx(boxes_list, box) for box in refs]
84
+ return idx
85
+
86
+
87
+ def clevr_ss_cot(obj, scene, add_ref=False):
88
+ cot = []
89
+ boxes = []
90
+ seq = []
91
+
92
+ def can_add_ref():
93
+ if p['function'] in ['unique', 'union', 'intersect', 'relate', 'same_size', 'same_shape', 'same_material', 'same_color']:
94
+ return True
95
+ if p['function'] in ['scene', 'filter_color', 'filter_material', 'filter_shape', 'filter_size']:
96
+ if idx + 1 < len(obj['program']) and obj['program'][idx + 1]['function'] in ['exist', 'count']:
97
+ return True
98
+ return False
99
+
100
+ for idx, p in enumerate(obj['program']):
101
+ func = f"{p['function']}:{p['value_inputs'][0]}" if 'value_inputs' in p and p['value_inputs'] else p['function']
102
+ inputs = f"[{','.join(map(str, p['inputs']))}]" if p['inputs'] else ""
103
+
104
+ if add_ref and can_add_ref():
105
+ if p['ans']:
106
+ objs = POINTS_PLACEHOLDER
107
+ ref_idx = get_boxes_idx(boxes_list=boxes, refs=[scene['objects'][_]['pixel_coords'][:2] for _ in p['ans']])
108
+ seq.append(ref_idx)
109
+ else:
110
+ objs = f" Found no object."
111
+ else:
112
+ objs = ""
113
+ cot.append(f"{func}{inputs}{objs}")
114
+
115
+ ret = " -> ".join(cot)
116
+ return ret, boxes, seq
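A small synthetic walk-through of clevr_ss_cot with add_ref=False. The toy program below is invented, and the import assumes this repository's mllm package (with its dependencies such as mmengine and torch) is on the Python path.

from mllm.dataset.single_image_dataset.clevr import clevr_ss_cot

toy_question = {
    'program': [
        {'function': 'scene', 'inputs': [], 'value_inputs': []},
        {'function': 'filter_color', 'inputs': [0], 'value_inputs': ['red']},
        {'function': 'count', 'inputs': [1], 'value_inputs': []},
    ]
}
cot, boxes, seq = clevr_ss_cot(toy_question, scene=None, add_ref=False)
print(cot)    # scene -> filter_color:red[0] -> count[1]
print(boxes)  # []  (no grounded references when add_ref is False)
print(seq)    # []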
mllm/dataset/single_image_dataset/flickr.py ADDED
@@ -0,0 +1,68 @@
1
+ from torch.utils.data import Dataset
2
+
3
+ from ..root import DATASETS, BOXES_PLACEHOLDER, IMAGE_PLACEHOLDER
4
+ from ..utils import MInstrDataset
5
+ from ..utils.flickr30k_entities_utils import (
6
+ flatten_annotation,
7
+ PHRASE_ED_PLACEHOLDER,
8
+ PHRASE_ST_PLACEHOLDER,
9
+ )
10
+
11
+
12
+ class FlickrParser(Dataset):
13
+ def __init__(self, filename, annotation_dir):
14
+ self.filename = filename
15
+ self.annotation_dir = annotation_dir
16
+
17
+ self.indexes = [line.strip() for line in open(filename, 'r', encoding='utf8')]
18
+ self.data = flatten_annotation(self.annotation_dir, self.indexes)
19
+
20
+ def __len__(self):
21
+ return len(self.data)
22
+
23
+ def __getitem__(self, index):
24
+ return self.data[index]
25
+
26
+ def dump(self, filename):
27
+ import json
28
+ with open(filename, 'w', encoding='utf8') as f:
29
+ for obj in self.data:
30
+ obj_str = json.dumps(obj)
31
+ f.write(obj_str)
32
+ f.write('\n')
33
+
34
+
35
+ @DATASETS.register_module()
36
+ class FlickrDataset(MInstrDataset):
37
+
38
+ def __init__(self, *args, **kwargs):
39
+ super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER,))
40
+
41
+ def __len__(self):
42
+ return len(self.data)
43
+
44
+ def __getitem__(self, index):
45
+ item = self.get_raw_item(index)
46
+ img_path = f"{item['image_id']}.jpg"
47
+ caption = item['sentence']
48
+
49
+ image = self.get_image(img_path)
50
+ caption = caption.replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
51
+ question = self.get_template()
52
+
53
+ ret = {
54
+ 'image': image,
55
+ 'target': {'boxes': item['boxes']},
56
+ 'conversations': [
57
+ {
58
+ 'from': 'human',
59
+ 'value': question,
60
+ },
61
+ {
62
+ 'from': 'gpt',
63
+ 'value': caption,
64
+ 'boxes_seq': item['boxes_seq'],
65
+ }
66
+ ]
67
+ }
68
+ return ret
mllm/dataset/single_image_dataset/gpt_gen.py ADDED
@@ -0,0 +1,58 @@
1
+ from ..root import (
2
+ DATASETS,
3
+ QUESTION_PLACEHOLDER,
4
+ IMAGE_PLACEHOLDER,
5
+ BOXES_PLACEHOLDER,
6
+ )
7
+ from ..utils import MInstrDataset
8
+ from ..utils.flickr30k_entities_utils import PHRASE_ST_PLACEHOLDER, PHRASE_ED_PLACEHOLDER
9
+
10
+
11
+ @DATASETS.register_module()
12
+ class GPT4Gen(MInstrDataset):
13
+ def __init__(self, *args, version, **kwargs):
14
+ super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
15
+ self.version = version
16
+ assert version in ['a', 'c', 'bc']
17
+
18
+ def __getitem__(self, item):
19
+ raw = self.get_raw_item(item)
20
+ #
21
+ image = self.get_image(raw['img_path'])
22
+ #
23
+ boxes = raw['boxes']
24
+ #
25
+ question = raw['question']
26
+ question = question.replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
27
+ final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
28
+ query_boxes_seq = raw['question_boxes_seq']
29
+
30
+ if self.version == 'a':
31
+ final_answer = raw['answer']
32
+ answer_boxes_seq = None
33
+ elif self.version == 'c':
34
+ final_answer = raw['cot_with_ans'].replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, '')
35
+ answer_boxes_seq = None
36
+ elif self.version == 'bc':
37
+ final_answer = raw['cot_with_ans'].replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
38
+ answer_boxes_seq = raw['answer_boxes_seq']
39
+ else:
40
+ assert False
41
+
42
+ ret = {
43
+ 'image': image,
44
+ 'target': {'boxes': boxes},
45
+ 'conversations': [
46
+ {
47
+ 'from': 'human',
48
+ 'value': final_question,
49
+ 'boxes_seq': query_boxes_seq,
50
+ },
51
+ {
52
+ 'from': 'gpt',
53
+ 'value': final_answer,
54
+ 'boxes_seq': answer_boxes_seq,
55
+ }
56
+ ]
57
+ }
58
+ return ret
mllm/dataset/single_image_dataset/gqa.py ADDED
@@ -0,0 +1,233 @@
1
+ import json
2
+ import re
3
+
4
+ from ..root import DATASETS, IMAGE_PLACEHOLDER, BOXES_PLACEHOLDER, QUESTION_PLACEHOLDER, METRICS
5
+ from ..utils.flickr30k_entities_utils import PHRASE_ST_PLACEHOLDER, PHRASE_ED_PLACEHOLDER
6
+ from ..utils import MInstrDataset, BaseComputeMetrics
7
+
8
+ REFID_PAT = re.compile(r'(\s\((?:(?:\d+(?:,\d+)*)|-)\)\s?)')
9
+ ANS_EXTRACT_PAT = re.compile(r'(?:(?:(?:(?:(?:So t)|(?:T)|(?:t))he answer is)|(?:Answer:)) (.+))')
10
+
11
+
12
+ @DATASETS.register_module()
13
+ class GQADataset(MInstrDataset):
14
+ def __init__(
15
+ self,
16
+ *args,
17
+ scene_graph_file,
18
+ scene_graph_index,
19
+ version,
20
+ question_box_prob=0.5,
21
+ **kwargs
22
+ ):
23
+ super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
24
+ self.scene_graph_file = scene_graph_file
25
+ self.scene_graph_index = scene_graph_index
26
+ self.version = version
27
+ self.question_box_prob = question_box_prob
28
+ qtype, atype = version.split('-')
29
+ assert qtype in ['q', 'qb', 'qbp']
30
+ assert atype in ['a', 'c', 'bc', 's', 'bs', 'l', 'bl']
31
+ self.qtype = qtype
32
+ self.atype = atype
33
+
34
+ assert bool(scene_graph_file) == bool(scene_graph_index)
35
+ if scene_graph_file is not None and scene_graph_index is not None:
36
+ self.scene_graph = [line for line in open(scene_graph_file, 'r', encoding='utf8')]
37
+ self.scene_index = json.load(open(scene_graph_index, 'r', encoding='utf8'))
38
+ else:
39
+ self.scene_graph = None
40
+ self.scene_index = None
41
+
42
+ def get_raw_item(self, index):
43
+ question = json.loads(self.data[index])
44
+ if self.scene_graph is None:
45
+ return question, None
46
+ scene = json.loads(self.scene_graph[self.scene_index[question['imageId']]])
47
+ return question, scene
48
+
49
+ def __getitem__(self, index):
50
+ question, scene = self.get_raw_item(index)
51
+ img_path = f"{question['imageId']}.jpg"
52
+ image = self.get_image(img_path)
53
+
54
+ # answer
55
+ if self.atype == 'bc':
56
+ boxes = question['cot']['boxes']
57
+ answer = question['cot']['value'].replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
58
+ answer_boxes_seq = question['cot']['seq']
59
+ elif self.atype == 'c':
60
+ boxes = []
61
+ answer = question['cot']['value'].replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, "")
62
+ answer_boxes_seq = []
63
+ elif self.atype == 'bs':
64
+ boxes, bss, answer_boxes_seq = get_bss_example(question, scene)
65
+ answer = f"{bss}. The answer is {question['answer']}."
66
+ elif self.atype == 's':
67
+ boxes = []
68
+ ss = REFID_PAT.sub('', question['semanticStr'])
69
+ answer = f"{ss}. The answer is {question['answer']}."
70
+ answer_boxes_seq = []
71
+ elif self.atype == 'bl':
72
+ boxes, answer, answer_boxes_seq = get_bl_example(question, scene)
73
+ elif self.atype == 'l':
74
+ boxes = []
75
+ _, answer, _ = get_bl_example(question, scene)
76
+ answer = answer.replace(BOXES_PLACEHOLDER, "")
77
+ answer_boxes_seq = []
78
+ elif self.atype == 'a':
79
+ boxes = []
80
+ answer = f"The answer is {question['answer']}."
81
+ answer_boxes_seq = []
82
+ else:
83
+ assert False
84
+
85
+ # question
86
+ if self.qtype == 'q':
87
+ boxes, query, query_boxes_seq = prepare_query_dummy(boxes, question, scene)
88
+ elif self.qtype == 'qb':
89
+ boxes, query, query_boxes_seq = prepare_query_box(boxes, question, scene)
90
+ elif self.qtype == 'qbp':
91
+ if self.rng.uniform() > self.question_box_prob:
92
+ boxes, query, query_boxes_seq = prepare_query_dummy(boxes, question, scene)
93
+ else:
94
+ boxes, query, query_boxes_seq = prepare_query_box(boxes, question, scene)
95
+ else:
96
+ assert False
97
+
98
+ final_query = self.get_template().replace(QUESTION_PLACEHOLDER, query)
99
+
100
+ ret = {
101
+ 'image': image,
102
+ 'target': {'boxes': boxes},
103
+ 'conversations': [
104
+ {
105
+ 'from': 'human',
106
+ 'value': final_query,
107
+ 'boxes_seq': query_boxes_seq,
108
+ },
109
+ {
110
+ 'from': 'gpt',
111
+ 'value': answer,
112
+ 'boxes_seq': answer_boxes_seq,
113
+ }
114
+ ]
115
+ }
116
+ return ret
117
+
118
+
119
+ def prepare_query_dummy(boxes_list, q, scene):
120
+ return boxes_list, q['question'], []
121
+
122
+
123
+ def prepare_query_box(boxes_list, q, scene):
124
+ def get_boxes_idx(box):
125
+ if box in boxes_list:
126
+ return boxes_list.index(box)
127
+ else:
128
+ boxes_list.append(box)
129
+ return len(boxes_list) - 1
130
+
131
+ def add_boxes_by_rids(rids):
132
+ def get_box_xyxy(obj):
133
+ x, y, w, h = obj['x'], obj['y'], obj['w'], obj['h']
134
+ return x, y, x + w, y + h
135
+
136
+ boxes_idx = []
137
+ for rid in rids:
138
+ ref = scene['objects'][rid]
139
+ ref_box = list(get_box_xyxy(ref))
140
+ boxes_idx.append(get_boxes_idx(ref_box))
141
+ return boxes_idx
142
+
143
+ sent = list(q['question'].split())
144
+ query_boxes_seq = []
145
+ for span, rids_str in q['annotations']['question'].items():
146
+ span = tuple(map(int, span.split(':')))
147
+ if len(span) == 1:
148
+ span = [span[0], span[0] + 1]
149
+ sent[span[1] - 1] = f"{sent[span[1] - 1]}{BOXES_PLACEHOLDER}"
150
+ boxes_idx = add_boxes_by_rids(rids_str.split(','))
151
+ query_boxes_seq.append(boxes_idx)
152
+ sent_converted = " ".join(sent).strip()
153
+ return boxes_list, sent_converted, query_boxes_seq
154
+
155
+
156
+ def add_boxes_by_rids(boxes_list, rids, scene):
157
+ def get_boxes_idx(boxes_list, box):
158
+ if box in boxes_list:
159
+ return boxes_list.index(box)
160
+ else:
161
+ boxes_list.append(box)
162
+ return len(boxes_list) - 1
163
+
164
+ def get_box_xyxy(obj):
165
+ x, y, w, h = obj['x'], obj['y'], obj['w'], obj['h']
166
+ return x, y, x + w, y + h
167
+
168
+ boxes_idx = []
169
+ for rid in rids:
170
+ ref = scene['objects'][rid]
171
+ ref_box = list(get_box_xyxy(ref))
172
+ boxes_idx.append(get_boxes_idx(boxes_list, ref_box))
173
+ return boxes_idx
174
+
175
+
176
+ def get_bss_example(question, scene):
177
+ def format_refids(item):
178
+ item = item.strip()[1:-1]
179
+ return item.split(',')
180
+
181
+ s = question['semanticStr']
182
+ print(REFID_PAT.findall(s))
183
+ formats = []
184
+ boxes = []
185
+ seqs = []
186
+
187
+ for item in REFID_PAT.findall(s):
188
+ if '-' in item:
189
+ formats.append('')
190
+ else:
191
+ formats.append('<boxes>')
192
+ refids = format_refids(item)
193
+ idx = add_boxes_by_rids(boxes, refids, scene)
194
+ seqs.append(idx)
195
+ answer = REFID_PAT.sub('{}', s).format(*formats)
196
+
197
+ print(answer)
198
+ print(boxes)
199
+ print(seqs)
200
+ return boxes, answer, seqs
201
+
202
+
203
+ def get_bl_example(ann, scene):
204
+ boxes = []
205
+ boxes_seq = []
206
+
207
+ origin_sent = ann['fullAnswer']
208
+ origin_sent = re.sub('(?:^Yes,)|(?:^No,)', '', origin_sent).strip()
209
+ sent = list(origin_sent.split())
210
+ for span, rids_str in ann['annotations']['fullAnswer'].items():
211
+ span = tuple(map(int, span.split(':')))
212
+ if len(span) == 1:
213
+ span = [span[0], span[0] + 1]
214
+ sent[span[1] - 1] = f"{sent[span[1] - 1]}{BOXES_PLACEHOLDER}"
215
+ rids = rids_str.split(',')
216
+ boxes_idx = add_boxes_by_rids(boxes, rids, scene)
217
+ boxes_seq.append(boxes_idx)
218
+
219
+ answer = "".join(sent)
220
+ answer += f"The answer is {ann['answer']}."
221
+ return boxes, answer, boxes_seq
222
+
223
+
224
+ @METRICS.register_module()
225
+ class GQAComputeMetrics(BaseComputeMetrics):
226
+ def extract_ans(self, string: str):
227
+ try:
228
+ found = ANS_EXTRACT_PAT.findall(string.strip())
229
+ if len(found) != 1:
230
+ return None
231
+ return found[0].strip().rstrip('.').strip()
232
+ except (IndexError, AttributeError):
233
+ return None
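A short illustration of the two regular expressions used above. The semanticStr-style string is made up and only meant to show the matching behavior.

import re

# copies of the patterns defined at the top of gqa.py
REFID_PAT = re.compile(r'(\s\((?:(?:\d+(?:,\d+)*)|-)\)\s?)')
ANS_EXTRACT_PAT = re.compile(r'(?:(?:(?:(?:(?:So t)|(?:T)|(?:t))he answer is)|(?:Answer:)) (.+))')

s = "select: dog (3) ->relate: left,of (5) ->query: color (-) "
print(REFID_PAT.findall(s))   # [' (3) ', ' (5) ', ' (-) ']
print(REFID_PAT.sub('', s))   # select: dog->relate: left,of->query: color

print(ANS_EXTRACT_PAT.findall("So the answer is brown."))  # ['brown.']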
mllm/dataset/single_image_dataset/instr.py ADDED
@@ -0,0 +1,24 @@
1
+ from ..root import DATASETS
2
+ from ..utils import MInstrDataset
3
+
4
+
5
+ @DATASETS.register_module()
6
+ class InstructDataset(MInstrDataset):
7
+ def __init__(self, *args, add_coco_prefix=False, **kwargs):
8
+ super().__init__(*args, **kwargs, placeholders=(), template_string='', template_file=None)
9
+ self.add_coco_prefix = add_coco_prefix
10
+
11
+ def __getitem__(self, index):
12
+ item = self.get_raw_item(index)
13
+ if self.add_coco_prefix:
14
+ img_path = f"COCO_train2014_{item['image']}"
15
+ else:
16
+ img_path = item['image']
17
+ conversations = item['conversations']
18
+
19
+ image = self.get_image(img_path)
20
+ ret = {
21
+ 'image': image,
22
+ 'conversations': conversations,
23
+ }
24
+ return ret
mllm/dataset/single_image_dataset/point_qa.py ADDED
@@ -0,0 +1,247 @@
1
+ import re
2
+
3
+ from .. import BaseComputeMetrics
4
+ from ..root import (
5
+ DATASETS,
6
+ METRICS,
7
+ QUESTION_PLACEHOLDER,
8
+ IMAGE_PLACEHOLDER,
9
+ BOXES_PLACEHOLDER,
10
+ POINTS_PLACEHOLDER,
11
+ )
12
+ from ..utils import MInstrDataset
13
+
14
+
15
+ # noinspection PyPep8Naming
16
+ @DATASETS.register_module()
17
+ class Point_QA_local(MInstrDataset):
18
+ def __init__(self, *args, version='p', qbp_p_prob=0.5, **kwargs):
19
+ super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
20
+ assert version in ['b', 'p', 'bp']
21
+ self.version = version
22
+ self.qbp_p_prob = qbp_p_prob
23
+
24
+ def __getitem__(self, index):
25
+ item = self.get_raw_item(index)
26
+ # image
27
+ img_path = item['file_path']
28
+ image = self.get_image(img_path)
29
+ # answer
30
+ answer = item['answer']
31
+ # question
32
+ question = item['question']
33
+ bbox = item['bbox']
34
+ point = item['point']
35
+
36
+ version = self.version
37
+ if version == 'bp':
38
+ version = 'p' if self.rng.random() < self.qbp_p_prob else 'b'
39
+ if version == 'b':
40
+ question = question + BOXES_PLACEHOLDER
41
+ query_boxes_seq = [[0]]
42
+ query_points_seq = None
43
+ elif version == 'p':
44
+ question = question + POINTS_PLACEHOLDER
45
+ query_boxes_seq = None
46
+ query_points_seq = [[0]]
47
+ else:
48
+ assert False
49
+ final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
50
+
51
+ ret = {
52
+ 'image': image,
53
+ 'target': {
54
+ 'boxes': [bbox],
55
+ 'points': [point],
56
+ },
57
+ 'conversations': [
58
+ {
59
+ 'from': 'human',
60
+ 'value': final_question,
61
+ 'boxes_seq': query_boxes_seq,
62
+ 'points_seq': query_points_seq,
63
+ },
64
+ {
65
+ 'from': 'gpt',
66
+ 'value': f'The answer is {answer} .',
67
+ }
68
+ ]
69
+ }
70
+ return ret
71
+
72
+
73
+ # noinspection PyPep8Naming
74
+ @DATASETS.register_module()
75
+ class Point_QA_twice(MInstrDataset):
76
+ def __init__(self, *args, version='gq-p', bp_p_prob=0.5, **kwargs):
77
+ super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
78
+ self.version = version
79
+ self.bp_p_prob = bp_p_prob
80
+ qtype, rtype = version.split('-')
81
+ assert qtype in ['oq', 'sq', 'gq']
82
+ assert rtype in ['b', 'p', 'bp']
83
+ self.qtype = qtype
84
+ self.rtype = rtype
85
+
86
+ def __getitem__(self, index):
87
+ item = self.get_raw_item(index)
88
+ # image
89
+ img_path = item['file_path']
90
+ image = self.get_image(img_path)
91
+ # answer
92
+ answer = item['answer']
93
+ # question
94
+ bbox = item['bbox']
95
+ point = item['point']
96
+ if self.qtype == 'oq':
97
+ question = item['obj_question']
98
+ elif self.qtype == 'sq':
99
+ question = item['super_question']
100
+ elif self.qtype == 'gq':
101
+ question = item['general_question']
102
+ else:
103
+ assert False
104
+ rtype = self.rtype
105
+ if rtype == 'bp':
106
+ rtype = 'p' if self.rng.random() < self.bp_p_prob else 'b'
107
+ if rtype == 'p':
108
+ question = question + POINTS_PLACEHOLDER
109
+ query_boxes_seq = None
110
+ query_points_seq = [[0]]
111
+ elif rtype == 'b':
112
+ question = question + BOXES_PLACEHOLDER
113
+ query_boxes_seq = [[0]]
114
+ query_points_seq = None
115
+ else:
116
+ assert False
117
+ final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
118
+
119
+ ret = {
120
+ 'image': image,
121
+ 'target': {
122
+ 'boxes': [bbox],
123
+ 'points': [point],
124
+ },
125
+ 'conversations': [
126
+ {
127
+ 'from': 'human',
128
+ 'value': final_question,
129
+ 'boxes_seq': query_boxes_seq,
130
+ 'points_seq': query_points_seq,
131
+ },
132
+ {
133
+ 'from': 'gpt',
134
+ 'value': f'The answer is {answer} .',
135
+ }
136
+ ]
137
+ }
138
+ return ret
139
+
140
+
141
+ # noinspection PyPep8Naming
142
+ @DATASETS.register_module()
143
+ class V7W_POINT(MInstrDataset):
144
+ def __init__(self, *args, version, do_shuffle_choice=True, **kwargs):
145
+ super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
146
+ self.version = version
147
+ self.do_shuffle_choice = do_shuffle_choice
148
+ assert version in ['p', 'b']
149
+
150
+ def __len__(self):
151
+ return len(self.data)
152
+
153
+ def __getitem__(self, index):
154
+ item = self.get_raw_item(index)
155
+ # image
156
+ img_path = item['file_path']
157
+ image = self.get_image(img_path)
158
+ # question
159
+ bboxes = item['candidates']
160
+ points = []
161
+ final_question = item['question'] + ' Candidates: ' + " ".join([BOXES_PLACEHOLDER for _ in range(len(bboxes))])
162
+ query_boxes_seq = []
163
+ for _ in range(len(bboxes)):
164
+ query_boxes_seq.append([_])
165
+ # answer
166
+ if self.version == 'p':
167
+ final_question += f" answer in point format."
168
+ points.append(item['point'])
169
+ final_answer = f"The answer is {POINTS_PLACEHOLDER} ."
170
+ answer_boxes_seq = None
171
+ answer_points_seq = [[0]]
172
+ elif self.version == 'b':
173
+ final_question += f" answer in box format."
174
+ idx = bboxes.index(item['answer'])
175
+ final_answer = f"The answer is {BOXES_PLACEHOLDER} ."
176
+ answer_boxes_seq = [[idx]]
177
+ answer_points_seq = None
178
+ else:
179
+ assert False
180
+ final_question = self.get_template().replace(QUESTION_PLACEHOLDER, final_question)
181
+ if self.do_shuffle_choice:
182
+ self.rng.shuffle(query_boxes_seq)
183
+ # bboxes, query_boxes_seq, answer_boxes_seq = self.shuffle_boxes(bboxes, query_boxes_seq, answer_boxes_seq)
184
+
185
+ ret = {
186
+ 'image': image,
187
+ 'target': {
188
+ 'boxes': bboxes,
189
+ 'points': points,
190
+ },
191
+ 'conversations': [
192
+ {
193
+ 'from': 'human',
194
+ 'value': final_question,
195
+ 'boxes_seq': query_boxes_seq,
196
+ },
197
+ {
198
+ 'from': 'gpt',
199
+ 'value': final_answer,
200
+ 'boxes_seq': answer_boxes_seq,
201
+ 'points_seq': answer_points_seq,
202
+
203
+ }
204
+ ]
205
+ }
206
+ return ret
207
+
208
+ # def shuffle_boxes(self, bboxes, query_boxes_seq, answer_boxes_seq):
209
+ # idx_mapping = list(range(len(bboxes)))
210
+ # self.rng.shuffle(idx_mapping)
211
+ #
212
+ # new_bboxes = [None for _ in range(len(bboxes))]
213
+ # for idx_old, idx_new in enumerate(idx_mapping):
214
+ # new_bboxes[idx_new] = bboxes[idx_old]
215
+ #
216
+ # if query_boxes_seq is None:
217
+ # new_query_boxes_seq = None
218
+ # else:
219
+ # new_query_boxes_seq = []
220
+ # for boxes in query_boxes_seq:
221
+ # new_boxes = [idx_mapping[box_idx] for box_idx in boxes]
222
+ # new_query_boxes_seq.append(new_boxes)
223
+ #
224
+ # if answer_boxes_seq is None:
225
+ # new_answer_boxes_seq = None
226
+ # else:
227
+ # new_answer_boxes_seq = []
228
+ # for boxes in answer_boxes_seq:
229
+ # new_boxes = [idx_mapping[box_idx] for box_idx in boxes]
230
+ # new_answer_boxes_seq.append(new_boxes)
231
+ #
232
+ # return new_bboxes, new_query_boxes_seq, new_answer_boxes_seq
233
+
234
+
235
+ ANS_EXTRACT_PAT = re.compile(r'(?:The answer is (.+?)\.)')
236
+
237
+
238
+ @METRICS.register_module()
239
+ class PointQAComputeMetrics(BaseComputeMetrics):
240
+ def extract_ans(self, string: str):
241
+ try:
242
+ found = ANS_EXTRACT_PAT.findall(string.strip())
243
+ if len(found) != 1:
244
+ return None
245
+ return found[0].strip()
246
+ except (IndexError, AttributeError):
247
+ return None
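The datasets above emit answers in the form 'The answer is X .'; this is how the extraction pattern reads the value back. The sample strings are illustrative.

import re

ANS_EXTRACT_PAT = re.compile(r'(?:The answer is (.+?)\.)')  # same pattern as above

pred = 'The answer is 2 .'
found = ANS_EXTRACT_PAT.findall(pred.strip())
print(found)             # ['2 ']
print(found[0].strip())  # '2'  (what extract_ans returns)

print(ANS_EXTRACT_PAT.findall('no answer sentence here'))  # []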
mllm/dataset/single_image_dataset/pope.py ADDED
@@ -0,0 +1,36 @@
1
+ from ..root import (
2
+ DATASETS,
3
+ QUESTION_PLACEHOLDER,
4
+ IMAGE_PLACEHOLDER,
5
+ )
6
+ from ..utils import MInstrDataset
7
+
8
+
9
+ @DATASETS.register_module()
10
+ class POPEVQADataset(MInstrDataset):
11
+ def __init__(self, *args, **kwargs):
12
+ super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
13
+
14
+ def __getitem__(self, index):
15
+ item = self.get_raw_item(index)
16
+ image = self.get_image(image_path=item['image'])
17
+
18
+ question = item['text']
19
+ final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
20
+
21
+ label = str(item['label']).lower()
22
+
23
+ ret = {
24
+ 'image': image,
25
+ 'conversations': [
26
+ {
27
+ 'from': 'human',
28
+ 'value': final_question,
29
+ },
30
+ {
31
+ 'from': 'gpt',
32
+ 'value': f"The answer is {label} .",
33
+ },
34
+ ]
35
+ }
36
+ return ret
mllm/dataset/single_image_dataset/rec.py ADDED
@@ -0,0 +1,128 @@
1
+ import sys
2
+ import logging
3
+ import warnings
4
+ from typing import Dict, Any, Sequence
5
+
6
+ import torch
7
+ from torchvision.ops import box_iou
8
+
9
+ from ..utils import (
10
+ MInstrDataset,
11
+ BaseComputeMetrics,
12
+ )
13
+
14
+ from ..process_function import (
15
+ BoxFormatter,
16
+ )
17
+
18
+ from ..root import (
19
+ DATASETS,
20
+ METRICS,
21
+ IMAGE_PLACEHOLDER,
22
+ BOXES_PLACEHOLDER,
23
+ EXPR_PLACEHOLDER,
24
+ )
25
+
26
+ logger = logging.getLogger(__name__)
27
+ logger.setLevel(logging.INFO)
28
+ logging.basicConfig(
29
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
30
+ datefmt="%m/%d/%Y %H:%M:%S",
31
+ handlers=[logging.StreamHandler(sys.stdout), ],
32
+ )
33
+
34
+
35
+ @DATASETS.register_module()
36
+ class RECDataset(MInstrDataset):
37
+ def __init__(self, *args, **kwargs):
38
+ super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, EXPR_PLACEHOLDER))
39
+
40
+ def __getitem__(self, index):
41
+ item = self.get_raw_item(index)
42
+ img_path = item['img_path']
43
+ expr = item['expression']
44
+ bbox = item['bbox']
45
+
46
+ image = self.get_image(img_path)
47
+ question = self.get_template().replace(EXPR_PLACEHOLDER, expr)
48
+
49
+ ret = {
50
+ 'image': image,
51
+ 'target': {
52
+ 'boxes': [bbox],
53
+ },
54
+ 'conversations': [
55
+ {
56
+ 'from': 'human',
57
+ 'value': question,
58
+ },
59
+ {
60
+ 'from': 'gpt',
61
+ 'value': f'Answer: {BOXES_PLACEHOLDER} .',
62
+ 'boxes_seq': [[0]],
63
+ }
64
+ ]
65
+ }
66
+ return ret
67
+
68
+
69
+ @METRICS.register_module()
70
+ class RECComputeMetrics(BaseComputeMetrics):
71
+ def __init__(self, *args, **kwargs):
72
+ super().__init__(*args, **kwargs)
73
+ self.box_formatter: BoxFormatter = self.preprocessor['target']['boxes']
74
+
75
+ def calculate_metric(self, preds: Sequence[str], targets: Sequence[str]) -> Dict[str, Any]:
76
+ failed = 0
77
+ target_failed = 0
78
+
79
+ pred_boxes, target_boxes = [], []
80
+ for pred, target in zip(preds, targets):
81
+ extract_pred = self.extract_ans(pred)
82
+ extract_target = self.extract_ans(target)
83
+ if extract_target is None:
84
+ target_failed += 1
85
+ logger.warning(f"failed to extract ans for target: {target}")
86
+ continue
87
+ if extract_pred is None:
88
+ failed += 1
89
+ logger.warning(f"failed to extract ans for pred: {pred}")
90
+ extract_pred = [0, 0, 0, 0]
91
+ target_boxes.append(extract_target)
92
+ pred_boxes.append(extract_pred)
93
+
94
+ with torch.no_grad():
95
+ target_boxes = torch.tensor(target_boxes)
96
+ pred_boxes = torch.tensor(pred_boxes)
97
+ # normalized box values are too small, so the computed area can degenerate to 0; scale up before IoU.
98
+ ious = box_iou(pred_boxes * 1000, target_boxes * 1000)
99
+ ious = torch.einsum('i i -> i', ious) # take diag elem
100
+ # NOTE: please note iou only calculate for success target
101
+ iou = ious.mean().item()
102
+ correct = (ious > 0.5).sum().item()
103
+
104
+ # HACK: currently we expand image to square. so this iou is the real iou.
105
+ warn_message = "this iou is calculate on normalized box. just for non-rigorous training progress checking." \
106
+ "the value is consistent with real iou only if image.width == image.height."
107
+ warnings.warn(warn_message)
108
+
109
+ return {
110
+ 'accuracy': 1.0 * correct / len(targets),
111
+ 'target_failed': target_failed,
112
+ 'failed': failed,
113
+ 'iou': iou,
114
+ 'warning': warn_message,
115
+ }
116
+
117
+ def extract_ans(self, string: str):
118
+ try:
119
+ list_of_boxes = self.box_formatter.extract(string)
120
+ if len(list_of_boxes) != 1 or len(list_of_boxes[0]) != 1:
121
+ return None
122
+ box = list_of_boxes[0][0]
123
+ if len(box) != 4:
124
+ return None
125
+ return box
126
+ except Exception as e:
127
+ logger.warning(f"extract_ans for {string} but get exception: {e}")
128
+ return None
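A self-contained sketch of the IoU bookkeeping in calculate_metric above: normalized boxes are scaled by 1000, box_iou returns the full pairwise matrix, and only the diagonal (pred_i vs. target_i) is kept. The box values below are made up.

import torch
from torchvision.ops import box_iou

pred_boxes = torch.tensor([[0.10, 0.10, 0.50, 0.50],
                           [0.00, 0.00, 0.20, 0.20]])
target_boxes = torch.tensor([[0.10, 0.10, 0.60, 0.60],
                             [0.50, 0.50, 0.90, 0.90]])

ious = box_iou(pred_boxes * 1000, target_boxes * 1000)  # (2, 2) pairwise matrix
ious = torch.einsum('i i -> i', ious)                   # keep matched pairs only
print(ious)                       # tensor([0.6400, 0.0000])
print((ious > 0.5).sum().item())  # 1 prediction counted correct at IoU@0.5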