fffiloni committed
Commit d59f323 · verified · 1 Parent(s): 1281541

Migrated from GitHub

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +7 -0
  2. LICENSE +201 -0
  3. ORIGINAL_README.md +166 -0
  4. assets/images/teaser.jpg +0 -0
  5. assets/videos/apt_exp_1_all.gif +3 -0
  6. assets/videos/apt_exp_2_all.gif +3 -0
  7. assets/videos/baodao_exp_1_all.gif +3 -0
  8. assets/videos/exp_1.gif +3 -0
  9. assets/videos/exp_2.gif +3 -0
  10. assets/videos/gf_exp1.gif +3 -0
  11. assets/videos/gf_exp1.mp4 +3 -0
  12. demo.ipynb +0 -0
  13. demo.py +98 -0
  14. demo/demo.py +98 -0
  15. demo/requirements.txt +10 -0
  16. projects/glamm/datasets/__init__.py +7 -0
  17. projects/glamm/datasets/collate_fns/glamm_collate_fn.py +136 -0
  18. projects/glamm/datasets/gcg_dataset.py +349 -0
  19. projects/glamm/datasets/refcoco_segm_dataset.py +195 -0
  20. projects/glamm/datasets/region_level_dataset.py +297 -0
  21. projects/glamm/datasets/semantic_seg_dataset.py +424 -0
  22. projects/glamm/datasets/utils/ade20k_classes.json +30 -0
  23. projects/glamm/datasets/utils/cocostuff_classes.txt +183 -0
  24. projects/glamm/datasets/utils/utils.py +131 -0
  25. projects/glamm/models/glamm.py +183 -0
  26. projects/glamm/models/region_encoder.py +359 -0
  27. projects/glamm/utils.py +280 -0
  28. projects/llava_sam2/configs/sa2va_4b.py +548 -0
  29. projects/llava_sam2/datasets/ChatUniVi_Dataset.py +389 -0
  30. projects/llava_sam2/datasets/GCG_Dataset.py +375 -0
  31. projects/llava_sam2/datasets/Grand_Dataset.py +241 -0
  32. projects/llava_sam2/datasets/MeVIS_Dataset.py +5 -0
  33. projects/llava_sam2/datasets/Osprey_Dataset.py +463 -0
  34. projects/llava_sam2/datasets/ReSAM2_Dataset.py +489 -0
  35. projects/llava_sam2/datasets/ReVOS_Dataset.py +602 -0
  36. projects/llava_sam2/datasets/RefCOCO_Dataset.py +338 -0
  37. projects/llava_sam2/datasets/RefYoutubeVOS_Dataset.py +47 -0
  38. projects/llava_sam2/datasets/__init__.py +15 -0
  39. projects/llava_sam2/datasets/collect_fns.py +206 -0
  40. projects/llava_sam2/datasets/encode_fn.py +144 -0
  41. projects/llava_sam2/datasets/gcg_process.py +297 -0
  42. projects/llava_sam2/datasets/grand_process.py +110 -0
  43. projects/llava_sam2/datasets/utils.py +58 -0
  44. projects/llava_sam2/datasets/vqa_dataset.py +509 -0
  45. projects/llava_sam2/deepspeed_zero2_sam2.json +24 -0
  46. projects/llava_sam2/gradio/app.py +151 -0
  47. projects/llava_sam2/gradio/app_utils.py +293 -0
  48. projects/llava_sam2/models/__init__.py +3 -0
  49. projects/llava_sam2/models/extension/__init__.py +1 -0
  50. projects/llava_sam2/models/extension/sam2_base.py +281 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/videos/apt_exp_1_all.gif filter=lfs diff=lfs merge=lfs -text
+ assets/videos/apt_exp_2_all.gif filter=lfs diff=lfs merge=lfs -text
+ assets/videos/baodao_exp_1_all.gif filter=lfs diff=lfs merge=lfs -text
+ assets/videos/exp_1.gif filter=lfs diff=lfs merge=lfs -text
+ assets/videos/exp_2.gif filter=lfs diff=lfs merge=lfs -text
+ assets/videos/gf_exp1.gif filter=lfs diff=lfs merge=lfs -text
+ assets/videos/gf_exp1.mp4 filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
ORIGINAL_README.md ADDED
@@ -0,0 +1,166 @@
+ # Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos
+
+ [\[🏠 Sa2VA\]](https://lxtgh.github.io/project/sa2va) [\[📜 arXiv\]](https://arxiv.org/abs/2501.04001) [\[🤗 HuggingFace\]](https://huggingface.co/collections/ByteDance/sa2va-model-zoo-677e3084d71b5f108d00e093) [\[🎥 Introduction\]]() [\[🧑‍💻 GitHub\]](https://github.com/magic-research/Sa2VA) [\[Online Demo (Sa2VA-4B)\]](https://5512470799b6b35fbc.gradio.live/)
+
+ [**Haobo Yuan**](https://yuanhaobo.me/)<sup>1*</sup> · [**Xiangtai Li**](https://scholar.google.com/citations?user=NmHgX-wAAAAJ)<sup>2*&dagger;</sup> · [**Tao Zhang**](https://zhang-tao-whu.github.io/)<sup>2,3*</sup> · [**Zilong Huang**](http://speedinghzl.github.io/)<sup>2</sup> · [**Shilin Xu**](https://xushilin1.github.io/)<sup>4</sup> · [**Shunping Ji**](https://scholar.google.com/citations?user=FjoRmF4AAAAJ&hl=en)<sup>3</sup> · [**Yunhai Tong**](https://scholar.google.com/citations?user=T4gqdPkAAAAJ&hl=zh-CN)<sup>4</sup> ·
+
+ [**Lu Qi**](https://luqi.info/)<sup>2</sup> · [**Jiashi Feng**](https://sites.google.com/site/jshfeng/)<sup>2</sup> · [**Ming-Hsuan Yang**](https://faculty.ucmerced.edu/mhyang/)<sup>1</sup>
+
+ <sup>1</sup>UC Merced&emsp;&emsp;&emsp;&emsp;<sup>2</sup>ByteDance Seed&emsp;&emsp;&emsp;&emsp;<sup>3</sup>WHU&emsp;&emsp;&emsp;&emsp;<sup>4</sup>PKU
+
+ &dagger; project lead&emsp;* the first three authors contributed equally to this work
+
+ ![Teaser](assets/images/teaser.jpg)
+
+ ## Overview
+ This repository contains the code for the paper "Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos".
+
+ Sa2VA is the first unified model for dense grounded understanding of both images and videos. Unlike existing multi-modal large language models, which are often limited to specific modalities and tasks, Sa2VA supports a wide range of image and video tasks, including referring segmentation and conversation, with minimal one-shot instruction tuning. Sa2VA combines SAM-2, a foundation video segmentation model, with LLaVA, an advanced vision-language model, and unifies text, image, and video into a shared LLM token space.
+
+ ## Model Zoo
+ We provide the following models:
+ | Model Name | Base MLLM | Language Part | HF Link |
+ |:----------:|:-----------------------------------------------------------------:|:-----------------------------------------------------------------------------:|:----------------------------------------------------:|
+ | Sa2VA-1B | [InternVL2.0-1B](https://huggingface.co/OpenGVLab/InternVL2-1B) | [Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-1B) |
+ | Sa2VA-4B | [InternVL2.5-4B](https://huggingface.co/OpenGVLab/InternVL2_5-4B) | [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-4B) |
+ | Sa2VA-8B | [InternVL2.5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) | [internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-8B) |
+
+ ## Gradio Demos
+
+ We provide a script that implements an interactive chat interface with Gradio (it requires `gradio==4.42.0`). You can use it to quickly build a chat interface locally.
+ ```shell
+ PYTHONPATH=. python projects/llava_sam2/gradio/app.py ByteDance/Sa2VA-4B
+ ```
+
+ ## Quick Start
+
+ Our Sa2VA models are available on 🤗 HuggingFace. With just a few steps, you can try them on your own data. You can install the packages in `demo/requirements.txt` to avoid installing the training-only dependencies.
+
+ **Option 1 - scripts:**
+
+ Suppose you have a folder (`PATH_TO_FOLDER`) that contains the frames of a video. You can use the following script to chat with the Sa2VA model or segment objects in the video.
+
+ ```bash
+ > cd scripts
+ > python demo.py PATH_TO_FOLDER --model_path ByteDance/Sa2VA-8B --work-dir OUTPUT_DIR --text "<image>Please describe the video content."
+ ```
+
+ If the output contains segmentation results, they will be saved to `OUTPUT_DIR`.
+
+ **Option 2 - Jupyter Notebook:**
+
+ Please refer to `demo.ipynb`.
+
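For reference, the inference path exercised by `demo.py` in this commit reduces to the following minimal sketch for a single image. The checkpoint name is taken from the Model Zoo above; the image path and prompt are placeholders.

```python
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load a released checkpoint with its custom code, as demo.py does.
model = AutoModelForCausalLM.from_pretrained(
    "ByteDance/Sa2VA-4B", torch_dtype="auto", device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ByteDance/Sa2VA-4B", trust_remote_code=True)

# Single-image query; "frame.jpg" is a placeholder path.
image = Image.open("frame.jpg").convert("RGB")
result = model.predict_forward(
    image=image,
    text="<image>Please segment the person on the left.",
    tokenizer=tokenizer,
)

print(result["prediction"])                 # text answer, may contain [SEG] tokens
if "[SEG]" in result["prediction"]:
    masks = result["prediction_masks"]      # per-object binary masks
```

For a whole video, pass the frame list via `video=` instead of `image=`, as shown in `demo.py` below.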
+ ## Demo
+
+ <details open>
+ <summary>Demo 1</summary>
+ Input Video (Source: La La Land 2016):
+
+ ![Error](assets/videos/exp_1.gif)
+
+ Instruction: "Please segment the girl wearing the yellow dress."
+ </details>
+
+ <details open>
+ <summary>Demo 2</summary>
+ Input Video (Source: La La Land 2016):
+
+ ![Error](assets/videos/exp_2.gif)
+
+ Instruction: "Please segment the main character."
+ </details>
+
+ <details open>
+ <summary>Demo 3</summary>
+ Input Video (Source: Internet):
+
+ ![Error](assets/videos/apt_exp_1_all.gif)
+
+ Instruction: "Please segment the person wearing sunglasses."
+ </details>
+
+ <details open>
+ <summary>Demo 4</summary>
+ Input Video (Source: Internet):
+
+ ![Error](assets/videos/apt_exp_2_all.gif)
+
+ Instruction: "Please segment the singing girl."
+ </details>
+
+ <details open>
+ <summary>Demo 5</summary>
+ Input Video:
+
+ ![Error](assets/videos/gf_exp1.gif)
+
+ Instruction: "What is the atmosphere of the scene?"
+
+ Answer: "The scene has a dark and mysterious atmosphere, with the men dressed in suits and ties, and the dimly lit room."
+ </details>
+
+ ## Training
+ <details open>
+ <summary>Installation</summary>
+
+ 1. Install Python and PyTorch first:
+ ```bash
+ > conda create -n vlm python=3.10
+ > conda activate vlm
+ > conda install pytorch==2.3.1 torchvision==0.18.1 pytorch-cuda=12.1 cuda -c pytorch -c "nvidia/label/cuda-12.1.0" -c "nvidia/label/cuda-12.1.1"
+ ```
+
+ 2. Install mmcv:
+ ```bash
+ > pip install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3/index.html
+ ```
+
+ 3. Install the other dependencies:
+ ```bash
+ > pip install -r requirements.txt
+ ```
+ </details>
+
+ <details open>
+ <summary>Pretrained Model Preparation</summary>
+
+ You are expected to download the following pretrained models and place them in the `./pretrained` directory:
+ - [sam2_hiera_large.pt](https://huggingface.co/facebook/sam2-hiera-large)
+ - [InternVL2_5-4B](https://huggingface.co/OpenGVLab/InternVL2_5-4B)
+
+ </details>
+
+ <details open>
+ <summary>Data Preparation</summary>
+
+ (TODO) Please download the training datasets and place them in the `data` directory. The download link is [here](https://huggingface.co/datasets/Dense-World/Sa2VA-Training).
+
+ </details>
+
+ <details open>
+ <summary>Training Script</summary>
+
+ Please run the following script to train:
+ ```bash
+ > bash tools/dist.sh train projects/llava_sam2/configs/sa2va_4b.py 8
+ ```
+ </details>
+
+ ## References
+ If you find this repository useful, please consider citing the following paper:
+ ```
+ @article{sa2va,
+   title={Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos},
+   author={Yuan, Haobo and Li, Xiangtai and Zhang, Tao and Huang, Zilong and Xu, Shilin and Ji, Shunping and Tong, Yunhai and Qi, Lu and Feng, Jiashi and Yang, Ming-Hsuan},
+   journal={arXiv},
+   year={2025}
+ }
+ ```
assets/images/teaser.jpg ADDED
assets/videos/apt_exp_1_all.gif ADDED

Git LFS Details

  • SHA256: ddf6e915c5f5f00e11136b4342c63b601fd446f714967333db4995c6ee4b797c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
assets/videos/apt_exp_2_all.gif ADDED

Git LFS Details

  • SHA256: eb9a946270dd9d3a1f1f0b30ff55d70abea9cf54bc52499cb07813e80a8f1e33
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
assets/videos/baodao_exp_1_all.gif ADDED

Git LFS Details

  • SHA256: e762e253dafb71ecf90d48144422bcd6fdcdf9c6a3c67571ee1a9d0232e32f03
  • Pointer size: 132 Bytes
  • Size of remote file: 2.95 MB
assets/videos/exp_1.gif ADDED

Git LFS Details

  • SHA256: 7b63b1465808dbe658761936b61a10f3e72bfc04f0b144a9e9103fcfaa810147
  • Pointer size: 132 Bytes
  • Size of remote file: 4.26 MB
assets/videos/exp_2.gif ADDED

Git LFS Details

  • SHA256: fad52f51a9f4238106923217e1d60c3ebc563c77117c49988496a67699ead397
  • Pointer size: 132 Bytes
  • Size of remote file: 3.84 MB
assets/videos/gf_exp1.gif ADDED

Git LFS Details

  • SHA256: 2cb7962fa6d20f4535b07e526c8a65edfcee55d5c2ec79308f98dde24c209842
  • Pointer size: 132 Bytes
  • Size of remote file: 4.82 MB
assets/videos/gf_exp1.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:272f4246fbb62aa690811e01d5f8aecaac3d157cc01a9859de79675ee5d4f7cf
+ size 15332128
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
demo.py ADDED
@@ -0,0 +1,98 @@
+ import argparse
+ import os
+
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ import cv2
+ try:
+     from mmengine.visualization import Visualizer
+ except ImportError:
+     Visualizer = None
+     print("Warning: mmengine is not installed, visualization is disabled.")
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Video Reasoning Segmentation')
+     parser.add_argument('image_folder', help='Path to a folder of video frames')
+     parser.add_argument('--model_path', default="ByteDance/Sa2VA-8B")
+     parser.add_argument('--work-dir', default=None, help='The dir to save results.')
+     parser.add_argument('--text', type=str, default="<image>Please describe the video content.")
+     parser.add_argument('--select', type=int, default=-1)
+     args = parser.parse_args()
+     return args
+
+
+ def visualize(pred_mask, image_path, work_dir):
+     visualizer = Visualizer()
+     img = cv2.imread(image_path)
+     visualizer.set_image(img)
+     visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
+     visual_result = visualizer.get_image()
+
+     output_path = os.path.join(work_dir, os.path.basename(image_path))
+     cv2.imwrite(output_path, visual_result)
+
+
+ if __name__ == "__main__":
+     cfg = parse_args()
+     model_path = cfg.model_path
+     model = AutoModelForCausalLM.from_pretrained(
+         model_path,
+         torch_dtype="auto",
+         device_map="auto",
+         trust_remote_code=True
+     )
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_path,
+         trust_remote_code=True
+     )
+
+     # Collect the frame paths in sorted order.
+     image_files = []
+     image_paths = []
+     image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}
+     for filename in sorted(list(os.listdir(cfg.image_folder))):
+         if os.path.splitext(filename)[1].lower() in image_extensions:
+             image_files.append(filename)
+             image_paths.append(os.path.join(cfg.image_folder, filename))
+
+     vid_frames = []
+     for img_path in image_paths:
+         img = Image.open(img_path).convert('RGB')
+         vid_frames.append(img)
+
+     if cfg.select > 0:
+         # Single-image mode: query only the selected frame (1-indexed).
+         img_frame = vid_frames[cfg.select - 1]
+
+         print(f"Selected frame {cfg.select}")
+         print(f"The input is:\n{cfg.text}")
+         result = model.predict_forward(
+             image=img_frame,
+             text=cfg.text,
+             tokenizer=tokenizer,
+         )
+     else:
+         # Video mode: pass all frames at once.
+         print(f"The input is:\n{cfg.text}")
+         result = model.predict_forward(
+             video=vid_frames,
+             text=cfg.text,
+             tokenizer=tokenizer,
+         )
+
+     prediction = result['prediction']
+     print(f"The output is:\n{prediction}")
+
+     # If the answer contains [SEG] tokens, overlay the predicted masks on each frame.
+     if '[SEG]' in prediction and Visualizer is not None:
+         _seg_idx = 0
+         pred_masks = result['prediction_masks'][_seg_idx]
+         for frame_idx in range(len(vid_frames)):
+             pred_mask = pred_masks[frame_idx]
+             if cfg.work_dir:
+                 os.makedirs(cfg.work_dir, exist_ok=True)
+                 visualize(pred_mask, image_paths[frame_idx], cfg.work_dir)
+             else:
+                 os.makedirs('./temp_visualize_results', exist_ok=True)
+                 visualize(pred_mask, image_paths[frame_idx], './temp_visualize_results')
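Note that `demo.py` skips visualization entirely when mmengine is unavailable. A dependency-light fallback (a sketch, not part of this commit) could overlay a predicted binary mask using only OpenCV, which is already pinned in `demo/requirements.txt`:

```python
import cv2
import numpy as np

def overlay_mask(image_path: str, pred_mask: np.ndarray, out_path: str, alpha: float = 0.4) -> None:
    """Blend a green overlay onto the masked region and write the result to disk."""
    img = cv2.imread(image_path)
    mask = pred_mask.astype(bool)
    overlay = img.copy()
    overlay[mask] = (0, 255, 0)  # BGR green for the segmented region
    blended = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)
    cv2.imwrite(out_path, blended)
```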
demo/demo.py ADDED
@@ -0,0 +1,98 @@
1
+ import argparse
2
+ import os
3
+
4
+ from PIL import Image
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
+ import cv2
8
+ try:
9
+ from mmengine.visualization import Visualizer
10
+ except ImportError:
11
+ Visualizer = None
12
+ print("Warning: mmengine is not installed, visualization is disabled.")
13
+
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser(description='Video Reasoning Segmentation')
17
+ parser.add_argument('image_folder', help='Path to image file')
18
+ parser.add_argument('--model_path', default="ByteDance/Sa2VA-8B")
19
+ parser.add_argument('--work-dir', default=None, help='The dir to save results.')
20
+ parser.add_argument('--text', type=str, default="<image>Please describe the video content.")
21
+ parser.add_argument('--select', type=int, default=-1)
22
+ args = parser.parse_args()
23
+ return args
24
+
25
+
26
+ def visualize(pred_mask, image_path, work_dir):
27
+ visualizer = Visualizer()
28
+ img = cv2.imread(image_path)
29
+ visualizer.set_image(img)
30
+ visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
31
+ visual_result = visualizer.get_image()
32
+
33
+ output_path = os.path.join(work_dir, os.path.basename(image_path))
34
+ cv2.imwrite(output_path, visual_result)
35
+
36
+ if __name__ == "__main__":
37
+ cfg = parse_args()
38
+ model_path = cfg.model_path
39
+ model = AutoModelForCausalLM.from_pretrained(
40
+ model_path,
41
+ torch_dtype="auto",
42
+ device_map="auto",
43
+ trust_remote_code=True
44
+ )
45
+
46
+ tokenizer = AutoTokenizer.from_pretrained(
47
+ model_path,
48
+ trust_remote_code=True
49
+ )
50
+
51
+ image_files = []
52
+ image_paths = []
53
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}
54
+ for filename in sorted(list(os.listdir(cfg.image_folder))):
55
+ if os.path.splitext(filename)[1].lower() in image_extensions:
56
+ image_files.append(filename)
57
+ image_paths.append(os.path.join(cfg.image_folder, filename))
58
+
59
+ vid_frames = []
60
+ for img_path in image_paths:
61
+ img = Image.open(img_path).convert('RGB')
62
+ vid_frames.append(img)
63
+
64
+
65
+ if cfg.select > 0:
66
+ img_frame = vid_frames[cfg.select - 1]
67
+
68
+ print(f"Selected frame {cfg.select}")
69
+ print(f"The input is:\n{cfg.text}")
70
+ result = model.predict_forward(
71
+ image=img_frame,
72
+ text=cfg.text,
73
+ tokenizer=tokenizer,
74
+ )
75
+ else:
76
+ print(f"The input is:\n{cfg.text}")
77
+ result = model.predict_forward(
78
+ video=vid_frames,
79
+ text=cfg.text,
80
+ tokenizer=tokenizer,
81
+ )
82
+
83
+ prediction = result['prediction']
84
+ print(f"The output is:\n{prediction}")
85
+
86
+ if '[SEG]' in prediction and Visualizer is not None:
87
+ _seg_idx = 0
88
+ pred_masks = result['prediction_masks'][_seg_idx]
89
+ for frame_idx in range(len(vid_frames)):
90
+ pred_mask = pred_masks[frame_idx]
91
+ if cfg.work_dir:
92
+ os.makedirs(cfg.work_dir, exist_ok=True)
93
+ visualize(pred_mask, image_paths[frame_idx], cfg.work_dir)
94
+ else:
95
+ os.makedirs('./temp_visualize_results', exist_ok=True)
96
+ visualize(pred_mask, image_paths[frame_idx], './temp_visualize_results')
97
+ else:
98
+ pass
demo/requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch==2.3.1
+ torchvision==0.18.1
+ transformers==4.42.3
+ opencv-python-headless<4.10
+ peft<0.14.0
+ timm==1.0.9
+ einops==0.8.0
+ flash_attn
+ sentencepiece==0.2.0
+ mmengine<1
projects/glamm/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .semantic_seg_dataset import SemanticSegDataset, ADE20kSemanticSegDataset, \
+     COCOStuffSemanticSegDataset, PascalPartSemanticSegDataset, PacoSemanticSegDataset
+ from .gcg_dataset import GCGDataset, GranDfGCGDataset, RefCOCOgGCGDataset, OpenPsgGCGDataset, Flickr30kGCGDataset
+ from .region_level_dataset import RefCocoGRegionDataset, VisualGenomeRegionDataset
+ from .refcoco_segm_dataset import ReferSegmDataset
+ from .utils.utils import *
+ from .collate_fns.glamm_collate_fn import glamm_collate_fn
projects/glamm/datasets/collate_fns/glamm_collate_fn.py ADDED
@@ -0,0 +1,136 @@
1
+ from typing import Dict, Sequence
2
+
3
+ import torch
4
+ from torch.nn.utils.rnn import pad_sequence
5
+
6
+ from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
7
+ pad_for_sequence_parallel)
8
+ from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
9
+
10
+
11
+ def glamm_collate_fn(instances: Sequence[Dict],
12
+ pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
13
+ return_hf_format: bool = False,
14
+ use_varlen_attn: bool = False):
15
+ seq_parallel_world_size = get_sequence_parallel_world_size()
16
+
17
+ input_ids, labels = [], []
18
+ has_image = any(inst.get('pixel_values') is not None for inst in instances)
19
+ has_grounding_image = any(inst.get('g_pixel_values') is not None for inst in instances)
20
+ has_mask = any(inst.get('masks') is not None for inst in instances)
21
+ has_bboxes = any(inst.get('bboxes') is not None for inst in instances)
22
+ has_points = any(inst.get('points') is not None for inst in instances)
23
+
24
+ if use_varlen_attn:
25
+ position_ids, cumulative_len = [], []
26
+ assert len(instances) == 1, (
27
+ f'If utilizing varlen attention, the batch size should be'
28
+ f' set to 1, but got {len(instances)}')
29
+ assert not has_image, 'Currently, it is not configured to '
30
+ 'accommodate the use of varlen Attention in multimodal training'
31
+
32
+ if has_image:
33
+ pixel_values = []
34
+ if has_grounding_image:
35
+ grounding_pixel_values = []
36
+ if has_mask:
37
+ object_masks = []
38
+ if has_bboxes:
39
+ object_bboxes = []
40
+ if has_points:
41
+ prompt_points = []
42
+
43
+ for example in instances:
44
+ input_ids.append(torch.LongTensor(example['input_ids']))
45
+ labels.append(torch.LongTensor(example['labels']))
46
+ if use_varlen_attn:
47
+ cumulative_len.append(torch.IntTensor(example['cumulative_len']))
48
+ position_ids.append(torch.LongTensor(example['position_ids']))
49
+
50
+ if has_image:
51
+ pixel_values.append(example['pixel_values'])
52
+ if has_grounding_image:
53
+ grounding_pixel_values.append(example['g_pixel_values'])
54
+ if has_mask:
55
+ if 'masks' in example.keys() and example['masks'] is not None:
56
+ object_masks.append(example['masks'])
57
+ if has_bboxes:
58
+ if 'bboxes' in example.keys() and example['bboxes'] is not None:
59
+ object_bboxes.append(example['bboxes'])
60
+ if has_points:
61
+ if 'points' in example.keys() and example['points'] is not None:
62
+ prompt_points.append(example['points'])
63
+
64
+ ori_length = [len(ids) for ids in input_ids]
65
+ if len(instances) > 1:
66
+ input_ids = pad_sequence(
67
+ input_ids, batch_first=True, padding_value=pad_index)
68
+ labels = pad_sequence(
69
+ labels, batch_first=True, padding_value=IGNORE_INDEX)
70
+ else:
71
+ input_ids = torch.stack(input_ids)
72
+ labels = torch.stack(labels)
73
+
74
+ if use_varlen_attn:
75
+ assert input_ids.size(1) % seq_parallel_world_size == 0
76
+ attention_mask = None
77
+ position_ids = torch.stack(position_ids, dim=0)
78
+ else:
79
+ # Some tokenizers have the same eos token and pad token, so input_ids
80
+ # cannot be masked directly based on the pad token id.
81
+ attention_mask = torch.zeros_like(input_ids).bool()
82
+ for i, length in enumerate(ori_length):
83
+ attention_mask[i, :length] = True
84
+
85
+ bs, seq_len = input_ids.shape
86
+ position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
87
+
88
+ if seq_parallel_world_size > 1:
89
+ input_ids = pad_for_sequence_parallel(input_ids, pad_index)
90
+ labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
91
+ position_ids = pad_for_sequence_parallel(position_ids, 0)
92
+ if attention_mask is not None:
93
+ attention_mask = pad_for_sequence_parallel(attention_mask, 0)
94
+
95
+ if use_varlen_attn:
96
+ max_seqlen = (
97
+ cumulative_len[0][1:] - # noqa: W504
98
+ cumulative_len[0][:-1]).max().item()
99
+ data_dict = {
100
+ 'input_ids': input_ids,
101
+ 'cumulative_len': cumulative_len,
102
+ 'position_ids': position_ids,
103
+ 'labels': labels,
104
+ 'max_seqlen': max_seqlen
105
+ }
106
+ else:
107
+ data_dict = {
108
+ 'input_ids': input_ids,
109
+ 'attention_mask': attention_mask,
110
+ 'position_ids': position_ids,
111
+ 'labels': labels
112
+ }
113
+
114
+ if has_image:
115
+ if all(x.shape == pixel_values[0].shape for x in pixel_values):
116
+ pixel_values = torch.stack(pixel_values, dim=0)
117
+ data_dict['pixel_values'] = pixel_values
118
+
119
+ if has_grounding_image:
120
+ # if all(x.shape == grounding_pixel_values[0].shape for x in grounding_pixel_values):
121
+ # grounding_pixel_values = torch.stack(grounding_pixel_values, dim=0)
122
+ data_dict['g_pixel_values'] = grounding_pixel_values
123
+
124
+ if has_mask:
125
+ data_dict['masks'] = object_masks
126
+
127
+ if has_bboxes:
128
+ data_dict['bboxes'] = object_bboxes
129
+
130
+ if has_points:
131
+ data_dict['points'] = prompt_points
132
+
133
+ if return_hf_format:
134
+ return data_dict
135
+ else:
136
+ return {'data': data_dict, 'data_samples': None}
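Stripped of the image, mask, and sequence-parallel bookkeeping, the text-padding logic in `glamm_collate_fn` above reduces to the following self-contained sketch. The toy token IDs are made up; `-100` stands in for xtuner's `IGNORE_INDEX` and `0` for the pad token id.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_INDEX = -100  # label value ignored by the loss, as in xtuner
PAD_INDEX = 0        # placeholder pad token id

# Two variable-length examples, as produced by the dataset's encode step.
input_ids = [torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10])]
labels    = [torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10])]
ori_length = [len(ids) for ids in input_ids]

padded_ids = pad_sequence(input_ids, batch_first=True, padding_value=PAD_INDEX)
padded_labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

# The attention mask is built from the original lengths rather than the pad id,
# because some tokenizers reuse the eos token as the pad token.
attention_mask = torch.zeros_like(padded_ids).bool()
for i, length in enumerate(ori_length):
    attention_mask[i, :length] = True

bs, seq_len = padded_ids.shape
position_ids = torch.arange(seq_len).unsqueeze(0).repeat(bs, 1)
print(padded_ids, padded_labels, attention_mask, position_ids, sep="\n")
```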
projects/glamm/datasets/gcg_dataset.py ADDED
@@ -0,0 +1,349 @@
1
+ import copy
2
+ import random
3
+ import glob
4
+ import json
5
+ import logging
6
+ import os
7
+ import torch
8
+
9
+ from mmengine import print_log
10
+ from mmengine.config import Config, ConfigDict
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset
13
+ import numpy as np
14
+ import torch.nn.functional as F
15
+ from pycocotools.coco import COCO
16
+ from pycocotools import mask as mask_utils
17
+
18
+ from xtuner.registry import BUILDER
19
+
20
+ from xtuner.dataset.utils import encode_fn
21
+ from xtuner.dataset.map_fns import llava_map_fn
22
+
23
+ from projects.glamm.datasets.utils.utils import expand2square
24
+
25
+ from projects.glamm.datasets.utils.utils import GCG_QUESTIONS, ANSWER_LIST
26
+ from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
27
+ class GCGDataset(Dataset):
28
+ def __init__(self,
29
+ image_folder,
30
+ image_processor,
31
+ data_path=None,
32
+ tokenizer=None,
33
+ template_map_fn=None,
34
+ max_length=2048,
35
+ pad_image_to_square=False,
36
+ repeats=1,
37
+ num_classes_per_sample=3,
38
+ extra_image_processor=None):
39
+ super().__init__()
40
+ self.question_templates = GCG_QUESTIONS
41
+ if extra_image_processor is not None:
42
+ self.extra_image_processor = BUILDER.build(extra_image_processor)
43
+ self.num_classes_per_sample = num_classes_per_sample
44
+ self.tokenizer = BUILDER.build(tokenizer)
45
+
46
+ self.tokenizer.add_tokens(
47
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
48
+ )
49
+ reg_tokens = ['<bbox>', '<point>']
50
+ segmentation_tokens = ['[SEG]']
51
+ phrase_tokens = ['<p>', '</p>']
52
+ special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
53
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
54
+
55
+ self.max_length = max_length
56
+ self.template_map_fn = BUILDER.build(template_map_fn)
57
+
58
+ self.text_data = self.json_file_preprocess(data_path, image_folder)
59
+ self.image_folder = image_folder
60
+
61
+ self.image_processor = BUILDER.build(image_processor)
62
+ size = self.image_processor.crop_size
63
+
64
+ if isinstance(size, dict):
65
+ self.image_w, self.image_h = size['width'], size['height']
66
+ elif isinstance(size, int):
67
+ self.image_h, self.image_w = size, size
68
+ else:
69
+ self.image_w, self.image_h = size
70
+
71
+ self.pad_image_to_square = pad_image_to_square
72
+ self.repeats = repeats
73
+
74
+ def json_file_preprocess(self, data_path, image_folder=None):
75
+ with open(data_path, 'r') as f:
76
+ json_data = json.load(f)
77
+ return json_data
78
+
79
+ @property
80
+ def modality_length(self):
81
+ length_list = []
82
+ for data_dict in self.text_data:
83
+ cur_len = 100
84
+ length_list.append(cur_len)
85
+ return length_list * self.repeats
86
+
87
+ def __len__(self):
88
+ return len(self.text_data) * self.repeats
89
+
90
+ def real_len(self):
91
+ return len(self.text_data)
92
+
93
+ def _parse_annotations(self, ann_info):
94
+ image_path = os.path.join(self.image_folder, ann_info['file_name'])
95
+ image = Image.open(image_path).convert('RGB')
96
+ if hasattr(self, 'extra_image_processor'):
97
+ g_image = np.array(image) # for grounding
98
+ g_image = self.extra_image_processor.apply_image(g_image)
99
+ g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
100
+ ann_info['g_pixel_values'] = g_pixel_values
101
+
102
+ width, height = image.size
103
+ if self.pad_image_to_square:
104
+ image = expand2square(
105
+ image, tuple(int(x * 255) for x in self.image_processor.image_mean))
106
+ image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
107
+ ann_info['pixel_values'] = image
108
+
109
+ caption = ann_info['caption'].strip('"').strip()
110
+ masks, phrases, tokens_positive = [], [], []
111
+ for word, grounding in ann_info["groundings"].items():
112
+ phrases.append(word)
113
+ tokens_positive.append(grounding["token_positives"])
114
+
115
+ # Convert segmentation to binary mask
116
+ binary_mask = np.zeros((height, width), dtype=np.uint8)
117
+ for rle in grounding["rle_masks"]:
118
+ m = mask_utils.decode(rle).astype(np.uint8)
119
+ binary_mask += m.squeeze()
120
+ masks.append(binary_mask)
121
+
122
+ def sort_by_start_index(items, order):
123
+ return [items[i] for i in order]
124
+
125
+ phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
126
+ masks = sort_by_start_index(masks, phrase_order)
127
+ phrases = sort_by_start_index(phrases, phrase_order)
128
+ tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
129
+
130
+ ann_info.update({
131
+ 'image_path': image_path,
132
+ 'caption': caption,
133
+ 'masks': masks,
134
+ 'phrases': phrases,
135
+ 'tokens_positive': tokens_positive,
136
+ })
137
+ return ann_info
138
+
139
+ def create_conversation(self, caption, tokens_positive):
140
+ question = random.choice(self.question_templates).strip()
141
+
142
+ # Prepare caption with tags
143
+ def tag_caption(caption, tokens):
144
+ for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
145
+ caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
146
+ return caption
147
+
148
+ detailed_answer = tag_caption(caption, tokens_positive)
149
+
150
+ question = 'The <image> provides an overview of the picture.\n' + question
151
+ conversation = [{'input': question, 'output': detailed_answer}]
152
+ return conversation
153
+
154
+ def __getitem__(self, index):
155
+ index = index % self.real_len()
156
+ data_dict = {}
157
+ ann_info = copy.deepcopy(self.text_data[index])
158
+ ann_info = self._parse_annotations(ann_info)
159
+
160
+ data_dict['g_pixel_values'] = ann_info.pop('g_pixel_values')
161
+ data_dict['pixel_values'] = ann_info.pop('pixel_values')
162
+ if len(ann_info['masks']) == 0:
163
+ return self.__getitem__(0)
164
+ data_dict['masks'] = torch.from_numpy(np.stack(ann_info['masks'], axis=0))
165
+
166
+ conversation = self.create_conversation(ann_info['caption'], ann_info['tokens_positive'])
167
+ data_dict['conversation'] = conversation
168
+
169
+ result = self.template_map_fn(data_dict)
170
+ data_dict.update(result)
171
+
172
+ result = encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
173
+ data_dict.update(result)
174
+
175
+ return data_dict
176
+
177
+ class GranDfGCGDataset(GCGDataset):
178
+ pass
179
+ class RefCOCOgGCGDataset(GCGDataset):
180
+ def json_file_preprocess(self, data_path, image_folder=None):
181
+ with open(data_path, 'r') as f:
182
+ json_data = json.load(f)
183
+ return [list(line.values())[0] for line in json_data]
184
+
185
+ def _parse_annotations(self, ann_info):
186
+ image_path = os.path.join(self.image_folder, ann_info['img_file_name'])
187
+ image = Image.open(image_path).convert('RGB')
188
+ if hasattr(self, 'extra_image_processor'):
189
+ g_image = np.array(image) # for grounding
190
+ g_image = self.extra_image_processor.apply_image(g_image)
191
+ g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
192
+ ann_info['g_pixel_values'] = g_pixel_values
193
+
194
+ width, height = image.size
195
+ if self.pad_image_to_square:
196
+ image = expand2square(
197
+ image, tuple(int(x * 255) for x in self.image_processor.image_mean))
198
+ image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
199
+ ann_info['pixel_values'] = image
200
+
201
+ caption = ann_info['caption'].strip('"').strip().lower()
202
+ masks, phrases, tokens_positive = [], [], []
203
+ for detail in ann_info['refs']:
204
+ phrase = detail['sentence']
205
+ if phrase.lower() in caption:
206
+ phrases.append(phrase)
207
+ index = caption.find(phrase)
208
+ end_index = index + len(phrase) if index != -1 else -1
209
+ tokens_positive.append([index, end_index])
210
+
211
+ binary_mask = np.zeros((height, width), dtype=np.uint8)
212
+ for seg in detail["segmentation"]:
213
+ rles = mask_utils.frPyObjects([seg], height, width)
214
+ m = mask_utils.decode(rles)
215
+ m = m.astype(np.uint8)
216
+ binary_mask += m.squeeze()
217
+ masks.append(binary_mask)
218
+
219
+ def sort_by_start_index(items, order):
220
+ return [items[i] for i in order]
221
+
222
+ phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
223
+ masks = sort_by_start_index(masks, phrase_order)
224
+ phrases = sort_by_start_index(phrases, phrase_order)
225
+ tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
226
+
227
+ ann_info.update({
228
+ 'image_path': image_path,
229
+ 'caption': caption,
230
+ 'masks': masks,
231
+ 'phrases': phrases,
232
+ 'tokens_positive': tokens_positive,
233
+ })
234
+ return ann_info
235
+
236
+ class OpenPsgGCGDataset(GCGDataset):
237
+ pass
238
+
239
+ class Flickr30kGCGDataset(GCGDataset):
240
+
241
+ def json_file_preprocess(self, data_path, image_folder=None):
242
+ def filter_images(data_infos, min_size):
243
+ return [i for i, info in enumerate(data_infos) if min(info['width'], info['height']) >= min_size]
244
+
245
+ self.coco = COCO(data_path)
246
+ self.image_ids = self.coco.getImgIds()
247
+ data_infos = []
248
+ total_ann_ids = []
249
+ removed_img_count = 0
250
+ for img_id in self.image_ids:
251
+ info = self.coco.loadImgs([img_id])[0]
252
+ if len(info['caption'].split(' ')) < 3:
253
+ removed_img_count += 1
254
+ continue
255
+ info['filename'] = info['file_name'].split('_')[-1]
256
+ info['height'] = int(info['height'])
257
+ info['width'] = int(info['width'])
258
+ data_infos.append(info)
259
+ ann_ids = self.coco.getAnnIds(imgIds=[img_id])
260
+ total_ann_ids.extend(ann_ids)
261
+ assert len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{data_path}'!"
262
+ print(f'Removed {removed_img_count} images.')
263
+ data_infos = [data_infos[i] for i in filter_images(data_infos, min_size=32)]
264
+
265
+ return data_infos
266
+
267
+ def _parse_annotations(self, img_info):
268
+ ann_ids = self.coco.getAnnIds(imgIds=img_info['id'])
269
+ ann_info = self.coco.loadAnns(ann_ids)
270
+
271
+ annotations = {'phrases': [], 'caption': img_info['caption'], 'masks': [], 'tokens_positive': []}
272
+ image_path = os.path.join(self.image_folder, img_info['file_name'])
273
+ image = Image.open(image_path).convert('RGB')
274
+ if hasattr(self, 'extra_image_processor'):
275
+ g_image = np.array(image) # for grounding
276
+ g_image = self.extra_image_processor.apply_image(g_image)
277
+ g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
278
+ annotations['g_pixel_values'] = g_pixel_values
279
+
280
+ width, height = image.size
281
+ if self.pad_image_to_square:
282
+ image = expand2square(
283
+ image, tuple(int(x * 255) for x in self.image_processor.image_mean))
284
+ image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
285
+ annotations['pixel_values'] = image
286
+
287
+ for ann in ann_info:
288
+ if ann.get('ignore', False):
289
+ continue
290
+ x1, y1, w, h = ann['bbox']
291
+ inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
292
+ inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
293
+ if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1:
294
+ continue
295
+ bbox = [x1, y1, x1 + w, y1 + h]
296
+ tokens_positive = ann['tokens_positive']
297
+ phrase = [img_info['caption'][span[0]:span[1]] for span in tokens_positive]
298
+ annotations['phrases'].append(phrase[0])
299
+ annotations['tokens_positive'].append(tokens_positive[0])
300
+
301
+ rle = ann['sam_mask']
302
+ mask_decoded = mask_utils.decode(rle).astype(np.uint8)
303
+ annotations['masks'].append(mask_decoded)
304
+
305
+ def sort_by_start_index(items, order):
306
+ return [items[i] for i in order]
307
+
308
+ phrase_order = sorted(range(len(annotations['tokens_positive'])), key=lambda x: annotations['tokens_positive'][x][0])
309
+ annotations['masks'] = sort_by_start_index(annotations['masks'], phrase_order)
310
+ annotations['phrases'] = sort_by_start_index(annotations['phrases'], phrase_order)
311
+ annotations['tokens_positive'] = sort_by_start_index(annotations['tokens_positive'], phrase_order)
312
+
313
+ return annotations
314
+
315
+ if __name__ == '__main__':
316
+ from transformers import CLIPImageProcessor, AutoTokenizer
317
+ from third_parts.segment_anything.utils.transforms import ResizeLongestSide
318
+ pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
319
+ llm_name_or_path = 'lmsys/vicuna-7b-v1.5'
320
+
321
+ tokenizer = dict(
322
+ type=AutoTokenizer.from_pretrained,
323
+ pretrained_model_name_or_path=llm_name_or_path)
324
+ image_processor = dict(
325
+ type=CLIPImageProcessor.from_pretrained,
326
+ pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
327
+ extra_image_processor = dict(
328
+ type=ResizeLongestSide,
329
+ target_length=1024,
330
+ )
331
+ from xtuner.utils.templates import PROMPT_TEMPLATE
332
+ prompt_template = PROMPT_TEMPLATE.vicuna
333
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
334
+ from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
335
+ dataset = Flickr30kGCGDataset(
336
+ image_folder='data/flickr30k/flickr30k-images/',
337
+ image_processor=image_processor,
338
+ data_path='./data/GranDf/annotations/train/flickr_mergedGT_GCG_train.json',
339
+ tokenizer=tokenizer,
340
+ template_map_fn=dict(
341
+ type=template_map_fn_factory, template=prompt_template),
342
+ max_length=2048,
343
+ pad_image_to_square=True,
344
+ repeats=1,
345
+ num_classes_per_sample=3,
346
+ extra_image_processor=extra_image_processor)
347
+
348
+ for i in range(1000):
349
+ print(dataset[i])
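The `create_conversation` method in `GCGDataset` above inserts `<p> ... </p> [SEG]` markers from right to left so that earlier character offsets remain valid. A standalone illustration of that tagging step, with a toy caption and spans:

```python
def tag_caption(caption: str, tokens: list) -> str:
    # Insert tags from right to left so earlier character offsets stay valid.
    for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
        caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
    return caption

caption = "a man rides a horse"
tokens_positive = [[0, 5], [14, 19]]  # character spans for "a man" and "horse"
print(tag_caption(caption, tokens_positive))
# -> <p> a man </p> [SEG] rides a <p> horse </p> [SEG]
```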
projects/glamm/datasets/refcoco_segm_dataset.py ADDED
@@ -0,0 +1,195 @@
1
+ import copy
2
+ import random
3
+ import glob
4
+ import json
5
+ import logging
6
+ import os
7
+ import torch
8
+
9
+ from mmengine import print_log
10
+ from mmengine.config import Config, ConfigDict
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset
13
+ import numpy as np
14
+ import torch.nn.functional as F
15
+ from pycocotools.coco import COCO
16
+ from pycocotools import mask as mask_utils
17
+
18
+ from xtuner.registry import BUILDER
19
+
20
+ from xtuner.dataset.utils import encode_fn
21
+ from xtuner.dataset.map_fns import llava_map_fn
22
+
23
+ from projects.glamm.datasets.utils.utils import expand2square
24
+
25
+ from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
26
+ from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
27
+
28
+ from third_parts.mmdet.datasets.refcoco import RefCocoDataset
29
+
30
+
31
+ class ReferSegmDataset(RefCocoDataset):
32
+ def __init__(self,
33
+ data_root,
34
+ ann_file=None,
35
+ split_file=None,
36
+ image_processor=None,
37
+ extra_image_processor=None,
38
+ data_prefix=dict(img_path='train2014/'),
39
+ tokenizer=None,
40
+ template_map_fn=None,
41
+ max_length=2048,
42
+ pad_image_to_square=False,
43
+ num_classes_per_sample=3):
44
+ super().__init__(
45
+ data_root=data_root,
46
+ data_prefix=data_prefix,
47
+ pipeline=None,
48
+ ann_file=ann_file,
49
+ split_file=split_file,
50
+ )
51
+ self.begin_str = f"""{DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
52
+
53
+ self.question_templates = SEG_QUESTIONS
54
+ if extra_image_processor is not None:
55
+ self.extra_image_processor = BUILDER.build(extra_image_processor)
56
+ self.num_classes_per_sample = num_classes_per_sample
57
+ self.tokenizer = BUILDER.build(tokenizer)
58
+
59
+ self.tokenizer.add_tokens(
60
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
61
+ )
62
+ reg_tokens = ['<bbox>', '<point>']
63
+ segmentation_tokens = ['[SEG]']
64
+ phrase_tokens = ['<p>', '</p>']
65
+ special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
66
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
67
+
68
+ self.max_length = max_length
69
+ self.template_map_fn = BUILDER.build(template_map_fn)
70
+
71
+ self.image_processor = BUILDER.build(image_processor)
72
+ size = self.image_processor.crop_size
73
+ if isinstance(size, dict):
74
+ self.image_w, self.image_h = size['width'], size['height']
75
+ self.pad_image_to_square = pad_image_to_square
76
+
77
+ @property
78
+ def modality_length(self):
79
+ import pickle
80
+ length_list = []
81
+ for idx in range(len(self)):
82
+ length_list.append(100)
83
+ # for idx in range(len(self)):
84
+ # if self.serialize_data:
85
+ # start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
86
+ # end_addr = self.data_address[idx].item()
87
+ # bytes = memoryview(
88
+ # self.data_bytes[start_addr:end_addr]) # type: ignore
89
+ # data_dict = pickle.loads(bytes)
90
+ # else:
91
+ # data_dict = copy.deepcopy(self.data_list[idx])
92
+ return length_list
93
+
94
+ def _parse_annotations(self, ann_info):
95
+ image_path = ann_info['img_path']
96
+ image = Image.open(image_path).convert('RGB')
97
+ if hasattr(self, 'extra_image_processor'):
98
+ g_image = np.array(image) # for grounding
99
+ g_image = self.extra_image_processor.apply_image(g_image)
100
+ g_pixel_values = torch.from_numpy(
101
+ g_image).permute(2, 0, 1).contiguous()
102
+ ann_info['g_pixel_values'] = g_pixel_values
103
+
104
+ width, height = image.size
105
+ if self.pad_image_to_square:
106
+ image = expand2square(
107
+ image, tuple(int(x * 255) for x in self.image_processor.image_mean))
108
+ image = self.image_processor.preprocess(
109
+ image, return_tensors='pt')['pixel_values'][0]
110
+ ann_info['pixel_values'] = image
111
+
112
+ masks, phrases = [], []
113
+ instances, text = ann_info['instances'], ann_info['text']
114
+ index = np.random.choice(range(len(instances)), min(
115
+ len(instances), self.num_classes_per_sample))
116
+ for idx in index:
117
+ inst = instances[idx]
118
+ phrase = text[idx].lower()
119
+ phrases.append(phrase)
120
+ binary_mask = np.zeros((height, width), dtype=np.uint8)
121
+ for seg in inst["mask"]:
122
+ rles = mask_utils.frPyObjects([seg], height, width)
123
+ m = mask_utils.decode(rles)
124
+ m = m.astype(np.uint8)
125
+ binary_mask += m.squeeze()
126
+ masks.append(binary_mask)
127
+
128
+ ann_info.update({
129
+ 'masks': masks,
130
+ 'phrases': phrases,
131
+ })
132
+ return ann_info
133
+
134
+ def __getitem__(self, idx):
135
+ data_dict = {}
136
+ ann_info = super().__getitem__(idx)
137
+ ann_info = self._parse_annotations(ann_info)
138
+
139
+ data_dict['g_pixel_values'] = ann_info.pop('g_pixel_values')
140
+ data_dict['pixel_values'] = ann_info.pop('pixel_values')
141
+ if len(ann_info['masks']) == 0:
142
+ return self.__getitem__(0)
143
+ data_dict['masks'] = torch.from_numpy(
144
+ np.stack(ann_info['masks'], axis=0))
145
+
146
+ conversation = []
147
+ for i, phrase in enumerate(ann_info['phrases']):
148
+ question = random.choice(SEG_QUESTIONS).format(class_name=phrase)
+ if i == 0:
+ # prepend the image token so encode_fn(with_image_token=True) can insert visual features
+ question = self.begin_str + question
149
+ conversation.append(
150
+ {'input': question, 'output': random.choice(ANSWER_LIST)})
151
+
152
+ data_dict['conversation'] = conversation
153
+ result = self.template_map_fn(data_dict)
154
+ data_dict.update(result)
155
+
156
+ result = encode_fn(data_dict, tokenizer=self.tokenizer,
157
+ max_length=self.max_length, with_image_token=True)
158
+ data_dict.update(result)
159
+
160
+ return data_dict
161
+
162
+ if __name__ == '__main__':
163
+ from transformers import CLIPImageProcessor, AutoTokenizer
164
+ from third_parts.segment_anything.utils.transforms import ResizeLongestSide
165
+ pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
166
+ llm_name_or_path = 'lmsys/vicuna-7b-v1.5'
167
+
168
+ tokenizer = dict(
169
+ type=AutoTokenizer.from_pretrained,
170
+ pretrained_model_name_or_path=llm_name_or_path)
171
+ image_processor = dict(
172
+ type=CLIPImageProcessor.from_pretrained,
173
+ pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
174
+ extra_image_processor = dict(
175
+ type=ResizeLongestSide,
176
+ target_length=1024,
177
+ )
178
+ from xtuner.utils.templates import PROMPT_TEMPLATE
179
+ prompt_template = PROMPT_TEMPLATE.vicuna
180
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
181
+ from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
182
+
183
+ dataset = ReferSegmDataset(
184
+ tokenizer=tokenizer,
185
+ image_processor=image_processor,
186
+ template_map_fn=dict(
187
+ type=template_map_fn_factory, template=prompt_template),
188
+ extra_image_processor=extra_image_processor,
189
+ data_root='data/coco/',
190
+ data_prefix=dict(img_path='train2014/'),
191
+ ann_file='refcoco+/instances.json',
192
+ split_file='refcoco+/refs(unc).p',
193
+ )
194
+ for i in range(1000):
195
+ dataset[i]
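For reference, the referring-expression turns built in __getitem__ above are just template formatting; a minimal standalone sketch (the phrase is a made-up example, and the templates are imported from the utils module added later in this commit):

import random

from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST

def build_turn(phrase: str) -> dict:
    # one conversation turn per sampled phrase, mirroring ReferSegmDataset.__getitem__;
    # every "[SEG]" in the answer later corresponds to one predicted mask
    question = random.choice(SEG_QUESTIONS).format(class_name=phrase)
    return {'input': question, 'output': random.choice(ANSWER_LIST)}

print(build_turn('the dog on the left'))  # hypothetical phrase, for illustration only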
projects/glamm/datasets/region_level_dataset.py ADDED
@@ -0,0 +1,297 @@
1
+ import copy
2
+ import random
3
+ import glob
4
+ import json
5
+ import logging
6
+ import os
7
+ import torch
8
+
9
+ from mmengine import print_log
10
+ from mmengine.config import Config, ConfigDict
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset
13
+ import numpy as np
14
+ import torch.nn.functional as F
15
+ from pycocotools.coco import COCO
16
+ from pycocotools import mask as mask_utils
17
+
18
+ from xtuner.registry import BUILDER
19
+
20
+ from xtuner.dataset.utils import encode_fn
21
+ from xtuner.dataset.map_fns import llava_map_fn
22
+
23
+ from projects.glamm.datasets.utils.utils import expand2square
24
+
25
+ from projects.glamm.datasets.utils.utils import ANSWER_LIST, REGION_QUESTIONS
26
+ from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
27
+
28
+
29
+ class RegionDataset(Dataset):
30
+ def __init__(self,
31
+ image_folder,
32
+ image_processor,
33
+ data_path=None,
34
+ tokenizer=None,
35
+ template_map_fn=None,
36
+ max_length=2048,
37
+ pad_image_to_square=False,
38
+ repeats=1,
39
+ num_classes_per_sample=3,
40
+ extra_image_processor=None):
41
+ super().__init__()
42
+
43
+ self.begin_str = f"""{DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
44
+ self.question_templates = REGION_QUESTIONS
45
+
46
+ if extra_image_processor is not None:
47
+ self.extra_image_processor = BUILDER.build(extra_image_processor)
48
+ self.num_classes_per_sample = num_classes_per_sample
49
+ self.tokenizer = BUILDER.build(tokenizer)
50
+
51
+ self.tokenizer.add_tokens(
52
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
53
+ )
54
+ reg_tokens = ['<bbox>', '<point>']
55
+ segmentation_tokens = ['[SEG]']
56
+ phrase_tokens = ['<p>', '</p>']
57
+ special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
58
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
59
+
60
+ self.max_length = max_length
61
+ self.template_map_fn = BUILDER.build(template_map_fn)
62
+
63
+ self.text_data = self._load_annotations(data_path, image_folder)
64
+ self.image_folder = image_folder
65
+
66
+ self.image_processor = BUILDER.build(image_processor)
67
+ size = self.image_processor.crop_size
68
+
69
+ if isinstance(size, dict):
70
+ self.image_w, self.image_h = size['width'], size['height']
71
+ elif isinstance(size, int):
72
+ self.image_h, self.image_w = size, size
73
+ else:
74
+ self.image_w, self.image_h = size
75
+
76
+ self.pad_image_to_square = pad_image_to_square
77
+ self.repeats = repeats
78
+
79
+ def _load_annotations(self, data_path, image_folder=None):
80
+ self.coco = COCO(data_path)
81
+ img_ids = self.coco.getImgIds()
82
+ data_infos = []
83
+ for img_id in img_ids:
84
+ info = self.coco.loadImgs([img_id])[0]
85
+ info['filename'] = info['file_name'].split('_')[-1]
86
+ info['height'] = int(info['height'])
87
+ info['width'] = int(info['width'])
88
+ if min(info['height'], info['width']) < 32:
89
+ continue
90
+ data_infos.append(info)
91
+ return data_infos
92
+
93
+ @property
94
+ def modality_length(self):
95
+ length_list = []
96
+ for data_dict in self.text_data:
97
+ cur_len = 100
98
+ length_list.append(cur_len)
99
+ return length_list * self.repeats
100
+
101
+ def __len__(self):
102
+ return len(self.text_data) * self.repeats
103
+
104
+ def real_len(self):
105
+ return len(self.text_data)
106
+
107
+ def region_processor(self, orig_size, post_size, bboxes, labels):
108
+ orig_h, orig_w = orig_size
109
+ post_h, post_w = post_size
110
+ y_scale = post_h / orig_h
111
+ x_scale = post_w / orig_w
112
+ shuffle_ids = torch.randperm(len(labels))[:self.num_classes_per_sample]
113
+ selected_bboxes = bboxes[shuffle_ids]
114
+
115
+ # Ensure selected_bboxes is two-dimensional
116
+ if len(selected_bboxes.shape) == 1:
117
+ selected_bboxes = np.expand_dims(selected_bboxes, axis=0)
118
+
119
+ selected_labels = [labels[i] for i in shuffle_ids]
120
+ selected_bboxes[:, [0, 2]] *= x_scale
121
+ selected_bboxes[:, [1, 3]] *= y_scale
122
+ selected_bboxes = torch.tensor(
123
+ selected_bboxes, dtype=torch.float32) / post_h
124
+ return selected_bboxes, selected_labels
125
+
126
+ def _parse_annotations(self, img_info):
127
+ data_dict = {}
128
+ bboxes, captions = [], []
129
+ ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
130
+ image_path = os.path.join(self.image_folder, img_info['file_name'])
131
+ image = Image.open(image_path).convert('RGB')
132
+ if hasattr(self, 'extra_image_processor'):
133
+ g_image = np.array(image) # for grounding
134
+ g_image = self.extra_image_processor.apply_image(g_image)
135
+ g_pixel_values = torch.from_numpy(
136
+ g_image).permute(2, 0, 1).contiguous()
137
+ data_dict['g_pixel_values'] = g_pixel_values
138
+
139
+ orig_w, orig_h = image.size
140
+ if self.pad_image_to_square:
141
+ image = expand2square(
142
+ image, tuple(int(x * 255) for x in self.image_processor.image_mean))
143
+ image = self.image_processor.preprocess(
144
+ image, return_tensors='pt')['pixel_values'][0]
145
+ post_h, post_w = image.shape[1:3]
146
+ data_dict['pixel_values'] = image
147
+
148
+ for ann in ann_info:
149
+ if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
150
+ continue
151
+ x1, y1, w, h = ann['bbox']
152
+ inter_w = max(0, min(x1 + w, orig_w) - max(x1, 0))
153
+ inter_h = max(0, min(y1 + h, orig_h) - max(y1, 0))
154
+ if inter_w * inter_h == 0:
155
+ continue
156
+ bbox = [x1, y1, x1 + w, y1 + h]
157
+
158
+ if bbox:
159
+ bboxes.append(bbox)
160
+ captions.append(img_info['caption'])
161
+
162
+ if len(bboxes) == 0:
163
+ # no valid boxes: fall back to the annotations of the first sample
+ return self._parse_annotations(copy.deepcopy(self.text_data[0]))
164
+
165
+ bboxes = np.array(bboxes, dtype=np.float32)
166
+ seg_map = img_info['file_name'].replace('jpg', 'png')
167
+ bboxes, captions = self.region_processor((orig_h, orig_w), (post_h, post_w), bboxes, captions)
168
+
169
+ data_dict['bboxes'] = bboxes
170
+ data_dict['captions'] = captions
171
+ data_dict['seg_map'] = seg_map
172
+ return data_dict
173
+
174
+ def create_conversation(self, captions):
175
+ questions = []
176
+ answers = []
177
+ for i, label in enumerate(captions):
178
+ question = random.choice(self.question_templates).strip().replace('<region>', f'region{i + 1} <bbox>')
179
+ questions.append(question)
180
+ answers.append(label)
181
+
182
+ conversation = []
183
+ for i, (question, answer) in enumerate(zip(questions, answers)):
184
+ if i == 0:
185
+ question = self.begin_str + question
186
+ conversation.append({'input': question, 'output': answer})
187
+ return conversation
188
+
189
+ def __getitem__(self, index):
190
+ index = index % self.real_len()
191
+ data_dict = {}
192
+ ann_info = copy.deepcopy(self.text_data[index])
193
+ ann_info = self._parse_annotations(ann_info)
194
+
195
+ data_dict['g_pixel_values'] = ann_info.pop('g_pixel_values', None)
196
+ data_dict['pixel_values'] = ann_info.pop('pixel_values')
197
+ data_dict['bboxes'] = ann_info.pop('bboxes', None)
198
+
199
+ conversation = self.create_conversation(ann_info['captions'])
200
+ data_dict['conversation'] = conversation
201
+
202
+ result = self.template_map_fn(data_dict)
203
+ data_dict.update(result)
204
+
205
+ result = encode_fn(data_dict, tokenizer=self.tokenizer,
206
+ max_length=self.max_length, with_image_token=True)
207
+ data_dict.update(result)
208
+
209
+ return data_dict
210
+
211
+ class RefCocoGRegionDataset(RegionDataset):
212
+ pass
213
+
214
+ class VisualGenomeRegionDataset(RegionDataset):
215
+ def _parse_annotations(self, img_info):
216
+ data_dict = {}
217
+ bboxes, captions = [], []
218
+ ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
219
+ image_path = os.path.join(self.image_folder, img_info['file_name'])
220
+ image = Image.open(image_path).convert('RGB')
221
+ if hasattr(self, 'extra_image_processor'):
222
+ g_image = np.array(image) # for grounding
223
+ g_image = self.extra_image_processor.apply_image(g_image)
224
+ g_pixel_values = torch.from_numpy(
225
+ g_image).permute(2, 0, 1).contiguous()
226
+ data_dict['g_pixel_values'] = g_pixel_values
227
+
228
+ orig_w, orig_h = image.size
229
+ if self.pad_image_to_square:
230
+ image = expand2square(
231
+ image, tuple(int(x * 255) for x in self.image_processor.image_mean))
232
+ image = self.image_processor.preprocess(
233
+ image, return_tensors='pt')['pixel_values'][0]
234
+ post_h, post_w = image.shape[1:3]
235
+ data_dict['pixel_values'] = image
236
+
237
+ for ann in ann_info:
238
+ if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
239
+ continue
240
+ x1, y1, w, h = ann['bbox']
241
+ inter_w = max(0, min(x1 + w, orig_w) - max(x1, 0))
242
+ inter_h = max(0, min(y1 + h, orig_h) - max(y1, 0))
243
+ if inter_w * inter_h == 0:
244
+ continue
245
+ bbox = [x1, y1, x1 + w, y1 + h]
246
+
247
+ if bbox:
248
+ bboxes.append(bbox)
249
+ captions.append(ann['caption'].strip())
250
+
251
+ if len(bboxes) == 0:
252
+ # no valid boxes: fall back to the annotations of the first sample
+ return self._parse_annotations(copy.deepcopy(self.text_data[0]))
253
+
254
+ bboxes = np.array(bboxes, dtype=np.float32)
255
+ seg_map = img_info['file_name'].replace('jpg', 'png')
256
+ bboxes, captions = self.region_processor((orig_h, orig_w), (post_h, post_w), bboxes, captions)
257
+
258
+ data_dict['bboxes'] = bboxes
259
+ data_dict['captions'] = captions
260
+ data_dict['seg_map'] = seg_map
261
+ return data_dict
262
+
263
+ if __name__ == '__main__':
264
+ from transformers import CLIPImageProcessor, AutoTokenizer
265
+ from third_parts.segment_anything.utils.transforms import ResizeLongestSide
266
+ pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
267
+ llm_name_or_path = 'lmsys/vicuna-7b-v1.5'
268
+
269
+ tokenizer = dict(
270
+ type=AutoTokenizer.from_pretrained,
271
+ pretrained_model_name_or_path=llm_name_or_path)
272
+ image_processor = dict(
273
+ type=CLIPImageProcessor.from_pretrained,
274
+ pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
275
+ extra_image_processor = dict(
276
+ type=ResizeLongestSide,
277
+ target_length=1024,
278
+ )
279
+ from xtuner.utils.templates import PROMPT_TEMPLATE
280
+ prompt_template = PROMPT_TEMPLATE.vicuna
281
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
282
+ from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
283
+ dataset = VisualGenomeRegionDataset(
284
+ image_folder='./data/visual_genome/images',
285
+ image_processor=image_processor,
286
+ data_path='data/visual_genome/train.json',
287
+ tokenizer=tokenizer,
288
+ template_map_fn=dict(
289
+ type=template_map_fn_factory, template=prompt_template),
290
+ max_length=2048,
291
+ pad_image_to_square=False,
292
+ repeats=1,
293
+ num_classes_per_sample=3,
294
+ extra_image_processor=None)
295
+
296
+ for i in range(1000):
297
+ print(dataset[i])
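The box handling in region_processor above reduces to rescaling xyxy boxes from the original image size to the preprocessed size and then normalising; a small numeric sketch of that arithmetic (function name and values are illustrative, not part of the dataset API):

import numpy as np
import torch

def rescale_and_normalize(bboxes, orig_size, post_size):
    # bboxes: (N, 4) [x1, y1, x2, y2] in original-image pixels
    orig_h, orig_w = orig_size
    post_h, post_w = post_size
    boxes = np.asarray(bboxes, dtype=np.float32).copy()
    boxes[:, [0, 2]] *= post_w / orig_w  # scale x coordinates
    boxes[:, [1, 3]] *= post_h / orig_h  # scale y coordinates
    # region_processor divides by post_h, so values end up roughly in [0, 1]
    return torch.from_numpy(boxes) / post_h

print(rescale_and_normalize([[10, 20, 110, 220]], orig_size=(480, 640), post_size=(336, 336)))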
projects/glamm/datasets/semantic_seg_dataset.py ADDED
@@ -0,0 +1,424 @@
1
+ import copy
2
+ import random
3
+ import glob
4
+ import json
5
+ import logging
6
+ import os
7
+ import torch
8
+
9
+ from mmengine import print_log
10
+ from mmengine.config import Config, ConfigDict
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset
13
+ import numpy as np
14
+ import torch.nn.functional as F
15
+ from pycocotools.coco import COCO
16
+
17
+ from xtuner.registry import BUILDER
18
+
19
+ from xtuner.dataset.utils import encode_fn
20
+ from xtuner.dataset.map_fns import llava_map_fn
21
+
22
+ from projects.glamm.datasets.utils.utils import expand2square
23
+
24
+ from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
25
+ from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
26
+
27
+
28
+ class SemanticSegDataset(Dataset):
29
+ def __init__(self,
30
+ image_folder,
31
+ image_processor,
32
+ data_path=None,
33
+ tokenizer=None,
34
+ offline_processed_text_folder=None,
35
+ max_dataset_length=None,
36
+ dataset_map_fn=None,
37
+ template_map_fn=None,
38
+ max_length=2048,
39
+ pad_image_to_square=False,
40
+ num_proc=8,
41
+ lazy=False,
42
+ repeats=1,
43
+ gcg_format=False,
44
+ num_classes_per_sample=3,
45
+ extra_image_processor=None):
46
+ super().__init__()
47
+ self.gcg_format = gcg_format
48
+ if extra_image_processor is not None:
49
+ self.extra_image_processor = BUILDER.build(extra_image_processor)
50
+ self.num_classes_per_sample = num_classes_per_sample
51
+ self.tokenizer = BUILDER.build(tokenizer)
52
+
53
+ self.tokenizer.add_tokens(
54
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
55
+ )
56
+ reg_tokens = ['<bbox>', '<point>']
57
+ segmentation_tokens = ['[SEG]']
58
+ phrase_tokens = ['<p>', '</p>']
59
+ special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
60
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
61
+
62
+ assert offline_processed_text_folder or (data_path and tokenizer)
63
+ self.lazy = lazy
64
+
65
+ self.max_length = max_length
66
+ self.dataset_map_fn = dataset_map_fn
67
+ self.template_map_fn = template_map_fn
68
+ if isinstance(self.template_map_fn, dict) and self.lazy:
69
+ _type = self.template_map_fn['type']
70
+ del self.template_map_fn['type']
71
+ self.template_map_fn = _type(**self.template_map_fn)
72
+
73
+ if offline_processed_text_folder and data_path:
74
+ print_log(
75
+ 'Both `offline_processed_text_folder` and '
76
+ '`data_path` are set, and we load dataset from '
77
+ '`offline_processed_text_folder` '
78
+ f'({offline_processed_text_folder})',
79
+ logger='current',
80
+ level=logging.WARNING)
81
+
82
+ if offline_processed_text_folder is not None:
83
+ raise NotImplementedError
84
+ else:
85
+ self.image_label_datas = self.json_file_preprocess(data_path, image_folder)
86
+
87
+ self.image_folder = image_folder
88
+
89
+ if isinstance(image_processor, dict) or isinstance(image_processor, Config) or isinstance(image_processor, ConfigDict):
90
+ self.image_processor = BUILDER.build(image_processor)
91
+ else:
92
+ self.image_processor = image_processor
93
+
94
+ size = self.image_processor.crop_size
95
+
96
+ if isinstance(size, dict):
97
+ self.image_w, self.image_h = size['width'], size['height']
98
+ elif isinstance(size, int):
99
+ self.image_h, self.image_w = size, size
100
+ else:
101
+ self.image_w, self.image_h = size
102
+
103
+ self.pad_image_to_square = pad_image_to_square
104
+ self.down_ratio = 1
105
+ self.repeats = repeats
106
+
107
+ def json_file_preprocess(self, data_path, image_folder):
108
+ # ade20k
109
+ with open(data_path, 'r') as file:
110
+ ade20k_classes = json.load(file)
111
+ ade20k_image_dir = image_folder
112
+ ade20k_images = [os.path.join(ade20k_image_dir, img) for img in os.listdir(ade20k_image_dir) if
113
+ img.endswith('.jpg')]
114
+ ade20k_labels = [img.replace(".jpg", ".png").replace(
115
+ "images", "annotations") for img in ade20k_images]
116
+ self.classes = np.array(ade20k_classes)
117
+
118
+ ret = []
119
+ for image, label in zip(ade20k_images, ade20k_labels):
120
+ ret.append({"image": image, "label": label})
121
+ return ret
122
+
123
+ def __len__(self):
124
+ return len(self.image_label_datas) * self.repeats
125
+
126
+ @property
127
+ def modality_length(self):
128
+ length_list = []
129
+ for data_dict in self.image_label_datas:
130
+ length_list.append(100)
131
+ length_list = length_list * self.repeats
132
+ return length_list
133
+
134
+ def real_len(self):
135
+ return len(self.image_label_datas)
136
+
137
+ def decode_mask(self, label_path):
138
+ label = np.array(Image.open(label_path))
139
+
140
+ # ade20k
141
+ label = np.where(label == 0, 255, label - 1)
142
+ unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
143
+ if not unique_labels:
144
+ return None, None
145
+
146
+ selected_labels = np.random.choice(unique_labels, min(
147
+ len(unique_labels), self.num_classes_per_sample), replace=False)
148
+ label = torch.from_numpy(label).long()
149
+ masks = torch.stack([label == class_id for class_id in selected_labels], dim=0)
150
+ return masks, selected_labels
151
+
152
+ def __getitem__(self, index):
153
+ index = index % self.real_len()
154
+ data_dict = copy.deepcopy(self.image_label_datas[index])
155
+
156
+ assert 'image' in data_dict.keys()
157
+ if data_dict.get('image', None) is not None:
158
+ image_file = data_dict['image']
159
+ image = Image.open(image_file).convert('RGB')
160
+ if hasattr(self, 'extra_image_processor'):
161
+ g_image = np.array(image) # for grounding
162
+ g_image = self.extra_image_processor.apply_image(g_image)
163
+ g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
164
+ data_dict['g_pixel_values'] = g_pixel_values
165
+
166
+ ori_width, ori_height = image.size
167
+ if self.pad_image_to_square:
168
+ image = expand2square(image, tuple(int(x * 255)
169
+ for x in self.image_processor.image_mean))
170
+ image = self.image_processor.preprocess(
171
+ image, return_tensors='pt')['pixel_values'][0]
172
+ data_dict['pixel_values'] = image
173
+
174
+ # process and get masks
175
+ data_dict['masks'], class_id = self.decode_mask(data_dict['label'])
176
+ if class_id is None:
177
+ return self.__getitem__(0)
178
+
179
+ if self.gcg_format:
180
+ pass
181
+ else:
182
+ conversation = []
183
+ for i, c_id in enumerate(class_id):
184
+ question = random.choice(SEG_QUESTIONS).format(
185
+ class_name=self.classes[c_id].lower())
186
+ if i == 0:
187
+ question = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n""" + question
188
+ conversation.append(
189
+ {'input': question, 'output': random.choice(ANSWER_LIST)})
190
+
191
+ data_dict.update({'conversation': conversation})
192
+ else:
193
+ if hasattr(self.image_processor, 'crop_size'):
194
+ crop_size = self.image_processor.crop_size
195
+ else:
196
+ crop_size = self.image_processor.size
197
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
198
+ crop_size['width'])
199
+ data_dict['masks'] = None
200
+
201
+ if self.lazy:
202
+ result = self.template_map_fn(data_dict)
203
+ data_dict.update(result)
204
+
205
+ result = encode_fn(data_dict, tokenizer=self.tokenizer,
206
+ max_length=self.max_length, with_image_token=True)
207
+ data_dict.update(result)
208
+
209
+ return data_dict
210
+
211
+ class ADE20kSemanticSegDataset(SemanticSegDataset):
212
+ def __init__(self,
213
+ image_folder,
214
+ image_processor,
215
+ data_path=None,
216
+ tokenizer=None,
217
+ offline_processed_text_folder=None,
218
+ max_dataset_length=None,
219
+ dataset_map_fn=None,
220
+ template_map_fn=None,
221
+ max_length=2048,
222
+ pad_image_to_square=False,
223
+ num_proc=8,
224
+ lazy=False,
225
+ repeats=1,
226
+ gcg_format=False,
227
+ num_classes_per_sample=3,
228
+ extra_image_processor=None):
229
+ super().__init__(
230
+ image_folder=image_folder,
231
+ image_processor=image_processor,
232
+ data_path=data_path,
233
+ tokenizer=tokenizer,
234
+ offline_processed_text_folder=offline_processed_text_folder,
235
+ max_dataset_length=max_dataset_length,
236
+ dataset_map_fn=dataset_map_fn,
237
+ template_map_fn=template_map_fn,
238
+ max_length=max_length,
239
+ pad_image_to_square=pad_image_to_square,
240
+ num_proc=num_proc,
241
+ lazy=lazy,
242
+ repeats=repeats,
243
+ gcg_format=gcg_format,
244
+ num_classes_per_sample=num_classes_per_sample,
245
+ extra_image_processor=extra_image_processor,
246
+ )
247
+
248
+ class COCOStuffSemanticSegDataset(SemanticSegDataset):
249
+ def __init__(self,
250
+ image_folder,
251
+ image_processor,
252
+ data_path=None,
253
+ tokenizer=None,
254
+ offline_processed_text_folder=None,
255
+ max_dataset_length=None,
256
+ dataset_map_fn=None,
257
+ template_map_fn=None,
258
+ max_length=2048,
259
+ pad_image_to_square=False,
260
+ num_proc=8,
261
+ lazy=False,
262
+ repeats=1,
263
+ label_path=None,
264
+ gcg_format=False,
265
+ num_classes_per_sample=3,
266
+ extra_image_processor=None):
267
+ self.label_path = label_path
268
+ super().__init__(
269
+ image_folder=image_folder,
270
+ image_processor=image_processor,
271
+ data_path=data_path,
272
+ tokenizer=tokenizer,
273
+ offline_processed_text_folder=offline_processed_text_folder,
274
+ max_dataset_length=max_dataset_length,
275
+ dataset_map_fn=dataset_map_fn,
276
+ template_map_fn=template_map_fn,
277
+ max_length=max_length,
278
+ pad_image_to_square=pad_image_to_square,
279
+ num_proc=num_proc,
280
+ lazy=lazy,
281
+ repeats=repeats,
282
+ gcg_format=gcg_format,
283
+ num_classes_per_sample=num_classes_per_sample,
284
+ extra_image_processor=extra_image_processor,
285
+ )
286
+ self.cocostuff_class2index = {c: i for i, c in enumerate(self.classes)}
287
+
288
+ def json_file_preprocess(self, data_path, image_folder):
289
+ # coco stuff
290
+ assert self.label_path is not None
291
+ with open(data_path, 'r') as file:
292
+ cocostuff_classes = [line.strip().split(": ")[-1]
293
+ for line in file.readlines()[1:]]
294
+ coco_stuff_image_dir = image_folder
295
+ coco_stuff_label_dir = self.label_path
296
+ coco_stuff_labels = glob.glob(
297
+ os.path.join(coco_stuff_label_dir, "*.png"))
298
+
299
+ coco_stuff_images = [label.replace(".png", ".jpg").replace(coco_stuff_label_dir, coco_stuff_image_dir)
300
+ for label in coco_stuff_labels]
301
+
302
+ self.classes = np.array(cocostuff_classes)
303
+
304
+ ret = []
305
+ for image, label in zip(coco_stuff_images, coco_stuff_labels):
306
+ ret.append({"image": image, "label": label})
307
+ return ret
308
+
309
+ def decode_mask(self, label_path):
310
+ label = np.array(Image.open(label_path))
311
+
312
+ # coco stuff
313
+ ignored_classes = [index for class_name,
314
+ index in self.cocostuff_class2index.items() if "-" in class_name]
315
+ label = np.where(np.isin(label, ignored_classes), 255, label)
316
+
317
+ unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
318
+ if not unique_labels:
319
+ print("No valid label !!!")
320
+ return None, None
321
+
322
+ # sample up to num_classes_per_sample classes
323
+ selected_labels = np.random.choice(unique_labels, min(
324
+ len(unique_labels), self.num_classes_per_sample), replace=False)
325
+
326
+ label = torch.from_numpy(label).long()
327
+ masks = torch.stack(
328
+ [label == class_id for class_id in selected_labels], dim=0)
329
+ return masks, selected_labels
330
+
331
+ class PascalPartSemanticSegDataset(SemanticSegDataset):
332
+
333
+ def json_file_preprocess(self, data_path, image_folder):
334
+ self.coco_api = COCO(data_path)
335
+ img_ids = self.coco_api.getImgIds()
336
+ all_classes = self.coco_api.loadCats(self.coco_api.getCatIds())
337
+ class_map_pascal_part = {}
338
+ for cat in all_classes:
339
+ cat_main, cat_part = cat["name"].strip().split(":")
340
+ name = (cat_main, cat_part)
341
+ class_map_pascal_part[cat["id"]] = name
342
+ self.classes = class_map_pascal_part
343
+ return img_ids
344
+
345
+ def __getitem__(self, index):
346
+ index = index % self.real_len()
347
+ img_id = self.image_label_datas[index]
348
+ img_info = self.coco_api.loadImgs([img_id])[0]
349
+ file_name = img_info["file_name"]
350
+ data_dict = {}
351
+
352
+ image_file = os.path.join(self.image_folder, file_name)
353
+ image = Image.open(image_file).convert('RGB')
354
+
355
+ if hasattr(self, 'extra_image_processor'):
356
+ g_image = np.array(image) # for grounding
357
+ g_image = self.extra_image_processor.apply_image(g_image)
358
+ g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
359
+ data_dict['g_pixel_values'] = g_pixel_values
360
+
361
+ if self.pad_image_to_square:
362
+ image = expand2square(
363
+ image, tuple(int(x * 255) for x in self.image_processor.image_mean))
364
+ image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
365
+ data_dict['pixel_values'] = image
366
+
367
+ annotation_ids = self.coco_api.getAnnIds(imgIds=img_info["id"])
368
+ annotations = self.coco_api.loadAnns(annotation_ids)
369
+
370
+ if not annotations:
371
+ return self.__getitem__(0)
372
+
373
+ sampled_anns = np.random.choice(annotations, min(
374
+ len(annotations), self.num_classes_per_sample), replace=False)
375
+
376
+ conversation = []
377
+ for i, ann in enumerate(sampled_anns):
378
+ cat_id = ann['category_id']
379
+ sampled_cls = self.classes[cat_id]
380
+ if isinstance(sampled_cls, tuple):
381
+ obj, part = sampled_cls
382
+ name = f"{obj} {part}" if random.random() < 0.5 else f"the {part} of the {obj}"
383
+ else:
384
+ name = sampled_cls
385
+ question = random.choice(SEG_QUESTIONS).format(class_name=name)
386
+ if i == 0:
387
+ question = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n""" + question
388
+ conversation.append(
389
+ {'input': question, 'output': random.choice(ANSWER_LIST)})
390
+
391
+ masks = [self.coco_api.annToMask(ann) for ann in sampled_anns]
392
+ masks = np.stack(masks, axis=0)
393
+ masks = torch.from_numpy(masks)
394
+
395
+ data_dict['masks'] = masks
396
+ data_dict['conversation'] = conversation
397
+
398
+ if self.lazy:
399
+ result = self.template_map_fn(data_dict)
400
+ data_dict.update(result)
401
+
402
+ result = encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
403
+ data_dict.update(result)
404
+
405
+ return data_dict
406
+
407
+ class PacoSemanticSegDataset(PascalPartSemanticSegDataset):
408
+ def json_file_preprocess(self, data_path, image_folder):
409
+ self.coco_api = COCO(data_path)
410
+ all_classes = self.coco_api.loadCats(self.coco_api.getCatIds())
411
+ class_map_paco = {}
412
+ for cat in all_classes:
413
+ cat_split = cat["name"].strip().split(":")
414
+ if len(cat_split) == 1:
415
+ name = cat_split[0].split("_(")[0]
416
+ else:
417
+ assert len(cat_split) == 2
418
+ obj, part = cat_split
419
+ obj = obj.split("_(")[0]
420
+ part = part.split("_(")[0]
421
+ name = (obj, part)
422
+ class_map_paco[cat["id"]] = name
423
+ self.classes = class_map_paco
424
+ return self.coco_api.getImgIds()
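decode_mask in the ADE20K branch above relies on labels being 1-based with 0 meaning "unlabeled"; a compact sketch of the remapping and per-class mask stacking on a synthetic label map:

import numpy as np
import torch

# synthetic 3x3 ADE20K-style label map: 0 = unlabeled, class ids are 1-based
label = np.array([[0, 1, 1],
                  [5, 5, 0],
                  [1, 5, 5]])

# shift to 0-based ids and mark unlabeled pixels with the ignore value 255
label = np.where(label == 0, 255, label - 1)
unique_labels = [int(lbl) for lbl in np.unique(label) if lbl != 255]

label = torch.from_numpy(label).long()
# one boolean mask per class id, stacked along dim 0
masks = torch.stack([label == class_id for class_id in unique_labels], dim=0)
print(unique_labels, masks.shape)  # [0, 4] torch.Size([2, 3, 3])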
projects/glamm/datasets/utils/ade20k_classes.json ADDED
@@ -0,0 +1,30 @@
1
+ [
2
+ "wall", "building", "sky", "floor", "tree", "ceiling", "road",
3
+ "bed", "windowpane", "grass", "cabinet", "sidewalk",
4
+ "person", "earth", "door", "table", "mountain", "plant",
5
+ "curtain", "chair", "car", "water", "painting", "sofa",
6
+ "shelf", "house", "sea", "mirror", "rug", "field", "armchair",
7
+ "seat", "fence", "desk", "rock", "wardrobe", "lamp",
8
+ "bathtub", "railing", "cushion", "base", "box", "column",
9
+ "signboard", "chest of drawers", "counter", "sand", "sink",
10
+ "skyscraper", "fireplace", "refrigerator", "grandstand",
11
+ "path", "stairs", "runway", "case", "pool table", "pillow",
12
+ "screen door", "stairway", "river", "bridge", "bookcase",
13
+ "blind", "coffee table", "toilet", "flower", "book", "hill",
14
+ "bench", "countertop", "stove", "palm", "kitchen island",
15
+ "computer", "swivel chair", "boat", "bar", "arcade machine",
16
+ "hovel", "bus", "towel", "light", "truck", "tower",
17
+ "chandelier", "awning", "streetlight", "booth",
18
+ "television receiver", "airplane", "dirt track", "apparel",
19
+ "pole", "land", "bannister", "escalator", "ottoman", "bottle",
20
+ "buffet", "poster", "stage", "van", "ship", "fountain",
21
+ "conveyer belt", "canopy", "washer", "plaything",
22
+ "swimming pool", "stool", "barrel", "basket", "waterfall",
23
+ "tent", "bag", "minibike", "cradle", "oven", "ball", "food",
24
+ "step", "tank", "trade name", "microwave", "pot", "animal",
25
+ "bicycle", "lake", "dishwasher", "screen", "blanket",
26
+ "sculpture", "hood", "sconce", "vase", "traffic light",
27
+ "tray", "ashcan", "fan", "pier", "crt screen", "plate",
28
+ "monitor", "bulletin board", "shower", "radiator", "glass",
29
+ "clock", "flag"
30
+ ]
projects/glamm/datasets/utils/cocostuff_classes.txt ADDED
@@ -0,0 +1,183 @@
1
+ 0: unlabeled
2
+ 1: person
3
+ 2: bicycle
4
+ 3: car
5
+ 4: motorcycle
6
+ 5: airplane
7
+ 6: bus
8
+ 7: train
9
+ 8: truck
10
+ 9: boat
11
+ 10: traffic light
12
+ 11: fire hydrant
13
+ 12: street sign
14
+ 13: stop sign
15
+ 14: parking meter
16
+ 15: bench
17
+ 16: bird
18
+ 17: cat
19
+ 18: dog
20
+ 19: horse
21
+ 20: sheep
22
+ 21: cow
23
+ 22: elephant
24
+ 23: bear
25
+ 24: zebra
26
+ 25: giraffe
27
+ 26: hat
28
+ 27: backpack
29
+ 28: umbrella
30
+ 29: shoe
31
+ 30: eye glasses
32
+ 31: handbag
33
+ 32: tie
34
+ 33: suitcase
35
+ 34: frisbee
36
+ 35: skis
37
+ 36: snowboard
38
+ 37: sports ball
39
+ 38: kite
40
+ 39: baseball bat
41
+ 40: baseball glove
42
+ 41: skateboard
43
+ 42: surfboard
44
+ 43: tennis racket
45
+ 44: bottle
46
+ 45: plate
47
+ 46: wine glass
48
+ 47: cup
49
+ 48: fork
50
+ 49: knife
51
+ 50: spoon
52
+ 51: bowl
53
+ 52: banana
54
+ 53: apple
55
+ 54: sandwich
56
+ 55: orange
57
+ 56: broccoli
58
+ 57: carrot
59
+ 58: hot dog
60
+ 59: pizza
61
+ 60: donut
62
+ 61: cake
63
+ 62: chair
64
+ 63: couch
65
+ 64: potted plant
66
+ 65: bed
67
+ 66: mirror
68
+ 67: dining table
69
+ 68: window
70
+ 69: desk
71
+ 70: toilet
72
+ 71: door
73
+ 72: tv
74
+ 73: laptop
75
+ 74: mouse
76
+ 75: remote
77
+ 76: keyboard
78
+ 77: cell phone
79
+ 78: microwave
80
+ 79: oven
81
+ 80: toaster
82
+ 81: sink
83
+ 82: refrigerator
84
+ 83: blender
85
+ 84: book
86
+ 85: clock
87
+ 86: vase
88
+ 87: scissors
89
+ 88: teddy bear
90
+ 89: hair drier
91
+ 90: toothbrush
92
+ 91: hair brush
93
+ 92: banner
94
+ 93: blanket
95
+ 94: branch
96
+ 95: bridge
97
+ 96: building-other
98
+ 97: bush
99
+ 98: cabinet
100
+ 99: cage
101
+ 100: cardboard
102
+ 101: carpet
103
+ 102: ceiling-other
104
+ 103: ceiling-tile
105
+ 104: cloth
106
+ 105: clothes
107
+ 106: clouds
108
+ 107: counter
109
+ 108: cupboard
110
+ 109: curtain
111
+ 110: desk-stuff
112
+ 111: dirt
113
+ 112: door-stuff
114
+ 113: fence
115
+ 114: floor-marble
116
+ 115: floor-other
117
+ 116: floor-stone
118
+ 117: floor-tile
119
+ 118: floor-wood
120
+ 119: flower
121
+ 120: fog
122
+ 121: food-other
123
+ 122: fruit
124
+ 123: furniture-other
125
+ 124: grass
126
+ 125: gravel
127
+ 126: ground-other
128
+ 127: hill
129
+ 128: house
130
+ 129: leaves
131
+ 130: light
132
+ 131: mat
133
+ 132: metal
134
+ 133: mirror-stuff
135
+ 134: moss
136
+ 135: mountain
137
+ 136: mud
138
+ 137: napkin
139
+ 138: net
140
+ 139: paper
141
+ 140: pavement
142
+ 141: pillow
143
+ 142: plant-other
144
+ 143: plastic
145
+ 144: platform
146
+ 145: playingfield
147
+ 146: railing
148
+ 147: railroad
149
+ 148: river
150
+ 149: road
151
+ 150: rock
152
+ 151: roof
153
+ 152: rug
154
+ 153: salad
155
+ 154: sand
156
+ 155: sea
157
+ 156: shelf
158
+ 157: sky
159
+ 158: skyscraper
160
+ 159: snow
161
+ 160: solid-other
162
+ 161: stairs
163
+ 162: stone
164
+ 163: straw
165
+ 164: structural-other
166
+ 165: table
167
+ 166: tent
168
+ 167: textile-other
169
+ 168: towel
170
+ 169: tree
171
+ 170: vegetable
172
+ 171: wall-brick
173
+ 172: wall-concrete
174
+ 173: wall-other
175
+ 174: wall-panel
176
+ 175: wall-stone
177
+ 176: wall-tile
178
+ 177: wall-wood
179
+ 178: water-other
180
+ 179: waterdrops
181
+ 180: window-blind
182
+ 181: window-other
183
+ 182: wood
projects/glamm/datasets/utils/utils.py ADDED
@@ -0,0 +1,131 @@
1
+ from PIL import Image
2
+
3
+
4
+
5
+ def expand2square(pil_img, background_color):
6
+ width, height = pil_img.size
7
+ if width == height:
8
+ return pil_img
9
+ elif width > height:
10
+ result = Image.new(pil_img.mode, (width, width), background_color)
11
+ result.paste(pil_img, (0, (width - height) // 2))
12
+ return result
13
+ else:
14
+ result = Image.new(pil_img.mode, (height, height), background_color)
15
+ result.paste(pil_img, ((height - width) // 2, 0))
16
+ return result
17
+
18
+ CAPTION_QUESTIONS = [
19
+ 'Could you please give me a detailed description of the image?',
20
+ 'Can you provide a thorough description of this image?',
21
+ 'Please provide a thorough description of this image',
22
+ 'Please provide a thorough description of this image.',
23
+ 'Please describe in detail the contents of the image.',
24
+ 'Please describe in detail the contents of the image',
25
+ 'Could you give a comprehensive explanation of what can be found within this picture?',
26
+ 'Could you give me an elaborate explanation of this picture?',
27
+ 'Could you provide me with a detailed analysis of this photo?',
28
+ 'Could you please give me a detailed description of the image?',
29
+ 'Can you provide a thorough description of this image?',
30
+ 'Please describe in detail the contents of the image',
31
+ 'Please describe in detail the contents of the image.',
32
+ 'Can you give a comprehensive explanation of this photo',
33
+ 'Please provide an elaborate explanation of this picture.',
34
+ 'Please provide an elaborate explanation of this picture',
35
+ 'Could you provide me with a detailed analysis of this photo',
36
+ ]
37
+
38
+ REGION_QUESTIONS = [
39
+ 'Can you provide me with a detailed description of the region in the picture marked by <region>?',
40
+ "I'm curious about the region represented by <region> in the picture. Could you describe it in detail?",
41
+ 'What can you tell me about the region indicated by <region> in the image?',
42
+ "I'd like to know more about the area in the photo labeled <region>. Can you give me a detailed description?",
43
+ 'Could you describe the region shown as <region> in the picture in great detail?',
44
+ 'What details can you give me about the region outlined by <region> in the photo?',
45
+ 'Please provide me with a comprehensive description of the region marked with <region> in the image.',
46
+ 'Can you give me a detailed account of the region labeled as <region> in the picture?',
47
+ "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail?",
48
+ 'What is the region outlined by <region> in the picture like? Could you give me a detailed description?',
49
+ 'Can you provide me with a detailed description of the region in the picture marked by <region>, please?',
50
+ "I'm curious about the region represented by <region> in the picture. Could you describe it in detail, please?",
51
+ 'What can you tell me about the region indicated by <region> in the image, exactly?',
52
+ "I'd like to know more about the area in the photo labeled <region>, please. Can you give me a detailed description?",
53
+ 'Could you describe the region shown as <region> in the picture in great detail, please?',
54
+ 'What details can you give me about the region outlined by <region> in the photo, please?',
55
+ 'Please provide me with a comprehensive description of the region marked with <region> in the image, please.',
56
+ 'Can you give me a detailed account of the region labeled as <region> in the picture, please?',
57
+ "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail, please?",
58
+ 'What is the region outlined by <region> in the picture like, please? Could you give me a detailed description?',
59
+ ]
60
+
61
+ REGION_GROUP_QUESTIONS = [
62
+ 'Could you please give me a detailed description of these areas <region>?',
63
+ 'Can you provide a thorough description of the regions <region> in this image?',
64
+ 'Please describe in detail the contents of the boxed areas <region>.',
65
+ 'Could you give a comprehensive explanation of what can be found within <region> in the picture?',
66
+ 'Could you give me an elaborate explanation of the <region> regions in this picture?',
67
+ 'Can you provide a comprehensive description of the areas identified by <region> in this photo?',
68
+ 'Help me understand the specific locations labeled <region> in this picture in detail, please.',
69
+ 'What is the detailed information about the areas marked by <region> in this image?',
70
+ 'Could you provide me with a detailed analysis of the regions designated <region> in this photo?',
71
+ 'What are the specific features of the areas marked <region> in this picture that you can describe in detail?',
72
+ 'Could you elaborate on the regions identified by <region> in this image?',
73
+ 'What can you tell me about the areas labeled <region> in this picture?',
74
+ 'Can you provide a thorough analysis of the specific locations designated <region> in this photo?',
75
+ 'I am interested in learning more about the regions marked <region> in this image. Can you provide me with more information?',
76
+ 'Could you please provide a detailed description of the areas identified by <region> in this photo?',
77
+ 'What is the significance of the regions labeled <region> in this picture?',
78
+ 'I would like to know more about the specific locations designated <region> in this image. Can you provide me with more information?',
79
+ 'Can you provide a detailed breakdown of the regions marked <region> in this photo?',
80
+ 'What specific features can you tell me about the areas identified by <region> in this picture?',
81
+ 'Could you please provide a comprehensive explanation of the locations labeled <region> in this image?',
82
+ 'Can you provide a detailed account of the regions designated <region> in this photo?',
83
+ 'I am curious about the areas marked <region> in this picture. Can you provide me with a detailed analysis?',
84
+ 'What important details can you tell me about the specific locations identified by <region> in this image?',
85
+ 'Could you please provide a detailed description of the regions labeled <region> in this photo?',
86
+ 'What can you tell me about the features of the areas designated <region> in this picture?',
87
+ 'Can you provide a comprehensive overview of the regions marked <region> in this image?',
88
+ 'I would like to know more about the specific locations identified by <region> in this photo. Can you provide me with more information?',
89
+ 'What is the detailed information you have on the areas labeled <region> in this picture?',
90
+ 'Could you provide me with a thorough analysis of the regions designated <region> in this image?',
91
+ 'Can you provide a detailed explanation of the specific locations marked by <region> in this photo?'
92
+ ]
93
+
94
+ GCG_QUESTIONS = [
95
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
96
+ 'Can you provide a thorough description of this image? Please output with interleaved segmentation masks for the corresponding phrases.',
97
+ 'Please describe in detail the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
98
+ 'Could you give a comprehensive explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
99
+ 'Could you give me an elaborate explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
100
+ 'Could you provide me with a detailed analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
101
+ ]
102
+
103
+ SEG_QUESTIONS = [
104
+ "Can you segment the {class_name} in this image?",
105
+ "Please segment {class_name} in this image.",
106
+ "What is {class_name} in this image? Please respond with segmentation mask.",
107
+ "What is {class_name} in this image? Please output segmentation mask.",
108
+
109
+ "Can you segment the {class_name} in this image",
110
+ "Please segment {class_name} in this image",
111
+ "What is {class_name} in this image? Please respond with segmentation mask",
112
+ "What is {class_name} in this image? Please output segmentation mask",
113
+
114
+ "Could you provide a segmentation mask for the {class_name} in this image?",
115
+ "Please identify and segment the {class_name} in this image.",
116
+ "Where is the {class_name} in this picture? Please respond with a segmentation mask.",
117
+ "Can you highlight the {class_name} in this image with a segmentation mask?",
118
+
119
+ "Could you provide a segmentation mask for the {class_name} in this image",
120
+ "Please identify and segment the {class_name} in this image",
121
+ "Where is the {class_name} in this picture? Please respond with a segmentation mask",
122
+ "Can you highlight the {class_name} in this image with a segmentation mask",
123
+ ]
124
+
125
+ ANSWER_LIST = [
126
+ "It is [SEG].",
127
+ "Sure, [SEG].",
128
+ "Sure, it is [SEG].",
129
+ "Sure, the segmentation result is [SEG].",
130
+ "[SEG].",
131
+ ]
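expand2square above pads the shorter side with a solid colour so CLIP sees a square image; a quick usage sketch (the mean values are the standard OpenAI CLIP ones and are hard-coded here purely for illustration):

from PIL import Image

from projects.glamm.datasets.utils.utils import expand2square

# the datasets pass tuple(int(x * 255) for x in image_processor.image_mean);
# the usual CLIP means are used here only for the example
clip_mean = (0.48145466, 0.4578275, 0.40821073)
background = tuple(int(x * 255) for x in clip_mean)

img = Image.new('RGB', (640, 480), (255, 0, 0))  # dummy landscape image
padded = expand2square(img, background)
print(padded.size)  # (640, 640): the height is padded up to the width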
projects/glamm/models/glamm.py ADDED
@@ -0,0 +1,183 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from xtuner.registry import BUILDER
5
+ from xtuner.model.utils import LoadWoInit, guess_load_checkpoint
6
+ from xtuner.model.llava import LLaVAModel
7
+
8
+ from mmengine.model import BaseModel
9
+ from mmengine import print_log
10
+
11
+ from projects.glamm.utils import prepare_inputs_labels_for_multimodal
12
+ from projects.glamm.utils import DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
13
+
14
+
15
+ class GLaMM(LLaVAModel):
16
+ def __init__(self,
17
+ use_activation_checkpointing=True,
18
+ tokenizer=None,
19
+ grounding_encoder=None,
20
+ region_encoder=None,
21
+ loss_mask=None,
22
+ loss_dice=None,
23
+ *args, **kwargs):
24
+ super(GLaMM, self).__init__(
25
+ *args, use_activation_checkpointing=use_activation_checkpointing, **kwargs)
26
+
27
+ self.use_activation_checkpointing = use_activation_checkpointing
28
+ self.tokenizer = BUILDER.build(tokenizer)
29
+ self._add_special_tokens()
30
+
31
+ self.grounding_encoder = BUILDER.build(grounding_encoder)
32
+ self.grounding_encoder.requires_grad_(False)
33
+ self.grounding_encoder.mask_decoder.requires_grad_(True)
34
+
35
+ if region_encoder is not None:
36
+ self.region_encoder = BUILDER.build(region_encoder)
37
+
38
+ in_dim = self.config.hidden_size
39
+ out_dim = self.grounding_encoder.mask_decoder.transformer_dim
40
+ self.text_hidden_fcs = nn.Sequential(
41
+ nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True),
42
+ nn.Linear(in_dim, out_dim), nn.Dropout(0.0)
43
+ )
44
+
45
+ self.loss_mask = BUILDER.build(loss_mask)
46
+ self.loss_dice = BUILDER.build(loss_dice)
47
+
48
+ def _add_special_tokens(self):
49
+ reg_tokens = ['<im_start>', '<im_end>', '<bbox>', '<point>']
50
+ segmentation_tokens = ['[SEG]']
51
+ phrase_tokens = ['<p>', '</p>']
52
+ special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
53
+ num_new_tokens = self.tokenizer.add_tokens(
54
+ special_tokens, special_tokens=True)
55
+ if num_new_tokens > 0:
56
+ self.llm.resize_token_embeddings(len(self.tokenizer))
57
+ input_embeddings = self.llm.get_input_embeddings().weight.data
58
+ output_embeddings = self.llm.get_output_embeddings().weight.data
59
+
60
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
61
+ dim=0, keepdim=True)
62
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
63
+ dim=0, keepdim=True)
64
+
65
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
66
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
67
+
68
+ self.seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
69
+ self.bop_token_idx = self.tokenizer("<p>", add_special_tokens=False).input_ids[0]
70
+ self.eop_token_idx = self.tokenizer("</p>", add_special_tokens=False).input_ids[0]
71
+ self.bbox_token_idx = self.tokenizer("<bbox>", add_special_tokens=False).input_ids[0]
72
+
73
+ if self.use_activation_checkpointing or self.use_llm_lora or not self.freeze_llm:
74
+ self.llm.enable_input_require_grads()
75
+
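The new-token initialisation in _add_special_tokens follows the common recipe of setting every freshly added row of the input/output embedding tables to the mean of the pre-existing rows; a toy standalone sketch of that idea:

import torch

num_new_tokens = 3
embeddings = torch.randn(10, 4)  # toy (vocab_size, hidden_size) embedding table
# new rows start at the mean of the original vocabulary, which keeps their
# statistics close to the trained embeddings instead of random noise
embeddings[-num_new_tokens:] = embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
print(embeddings[-num_new_tokens:])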
76
+ def forward(self, data, data_samples=None, mode='loss'):
77
+ if 'pixel_values' in data:
78
+ visual_outputs = self.visual_encoder(
79
+ data['pixel_values'].to(self.visual_encoder.dtype),
80
+ output_hidden_states=True)
81
+ pixel_values = self.projector(
82
+ visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
83
+ data['pixel_values'] = pixel_values
84
+ bboxes = data.pop('bboxes', None)
85
+ if bboxes is not None:
86
+ select_hidden_state_layer = -2
87
+ num_level_reg_features = 4
88
+ mlvl_reg_features = visual_outputs.hidden_states[select_hidden_state_layer::-3]
89
+ mlvl_reg_features = mlvl_reg_features[::-1]
90
+ mlvl_reg_features = mlvl_reg_features[-num_level_reg_features:]
91
+ mlvl_reg_features = [item[:, 1:] for item in mlvl_reg_features]
92
+ mlvl_reg_features = self.region_encoder(mlvl_reg_features, bboxes)
93
+ data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
94
+
95
+ if bboxes is not None:
96
+ inputs_embeds = data['inputs_embeds']
97
+ for i, reg_feat in enumerate(mlvl_reg_features):
98
+ reg_mask = data['new_input_ids'][i] == self.bbox_token_idx
99
+ inputs_embeds[i][reg_mask] = reg_feat
100
+ data['inputs_embeds'] = inputs_embeds
101
+
102
+ if mode == 'loss':
103
+ return self.compute_loss(data, data_samples)
104
+ elif mode == 'predict':
105
+ return self.predict(data, data_samples)
106
+ elif mode == 'tensor':
107
+ return self._forward(data, data_samples)
108
+ else:
109
+ raise NotImplementedError
110
+
111
+ def compute_loss(self, data, data_samples=None):
112
+ g_pixel_values = data.pop('g_pixel_values', None)
113
+ gt_masks = data.pop('masks', None)
114
+ new_input_ids = data.pop('new_input_ids', None)
115
+
116
+ output = self.llm(output_hidden_states=True, **data)
117
+ if gt_masks is None:
118
+ return {'llm_loss': output.loss}
119
+
120
+ resize_list = [pixel.shape[-2:] for pixel in g_pixel_values]
121
+ ori_size_list = [mask.shape[-2:] for mask in gt_masks]
122
+ g_pixel_values = torch.stack([
123
+ self.grounding_encoder.preprocess(pixel) for pixel in g_pixel_values
124
+ ])
125
+ image_embeddings = self.grounding_encoder.image_encoder(g_pixel_values)
126
+
127
+ seg_token_mask = new_input_ids == self.seg_token_idx
128
+ hidden_states = output.hidden_states
129
+ hidden_states = self.text_hidden_fcs(hidden_states[-1])
130
+ pred_embeddings = hidden_states[seg_token_mask]
131
+
132
+ seg_token_counts = seg_token_mask.int().sum(-1)
133
+ pred_embeddings_list = torch.split(pred_embeddings, seg_token_counts.tolist(), dim=0)
134
+
135
+ pred_masks = self._generate_and_postprocess_masks(
136
+ pred_embeddings_list, image_embeddings, resize_list, ori_size_list)
137
+
138
+ bs = len(pred_masks)
139
+ loss_mask, loss_dice = 0, 0
140
+ for i in range(bs):
141
+ pred_mask = pred_masks[i]
142
+ gt_mask = gt_masks[i]
143
+
144
+ sam_loss_mask = self.loss_mask(pred_mask, gt_mask)
145
+ sam_loss_dice = self.loss_dice(pred_mask, gt_mask)
146
+ accuracy = torch.eq((pred_mask.sigmoid() > 0.5), gt_mask).to(pred_mask).mean()
147
+ loss_mask += sam_loss_mask
148
+ loss_dice += sam_loss_dice
149
+
150
+
151
+ loss_dict = {
152
+ 'loss_mask': loss_mask / bs,
153
+ 'loss_dice': loss_dice / bs,
154
+ 'accuracy': accuracy,
155
+ 'llm_loss': output.loss,
156
+ }
157
+ return loss_dict
158
+
159
+
160
+ def _generate_and_postprocess_masks(self, pred_embeddings, image_embeddings, resize_list=None, orig_size_list=None, infer=False):
161
+ pred_masks = []
162
+ for i, pred_embedding in enumerate(pred_embeddings):
163
+ sparse_embeddings, dense_embeddings = self.grounding_encoder.prompt_encoder(
164
+ points=None, boxes=None, masks=None, text_embeds=pred_embedding.unsqueeze(1)
165
+ )
166
+ sparse_embeddings = sparse_embeddings.to(pred_embedding.dtype)
167
+ low_res_masks, _ = self.grounding_encoder.mask_decoder(
168
+ image_embeddings=image_embeddings[i].unsqueeze(0),
169
+ image_pe=self.grounding_encoder.prompt_encoder.get_dense_pe(),
170
+ sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings,
171
+ multimask_output=False, )
172
+
173
+ pred_mask = self.grounding_encoder.postprocess_masks(
174
+ low_res_masks, input_size=resize_list[i], original_size=orig_size_list[i], )
175
+ pred_masks.append(pred_mask[:, 0])
176
+ return pred_masks
177
+
178
+ def predict(self, data, data_samples=None):
179
+ pass
180
+
181
+ def _forward(self, data, data_samples=None):
182
+ outputs = self.llm(**data)
183
+ return outputs
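The central indexing step in compute_loss above is worth isolating: every position whose id equals the [SEG] token yields one mask embedding, and torch.split regroups them per sample. A shape-only sketch (token id 7 is a stand-in for the real [SEG] id):

import torch

hidden_dim, seg_token_idx = 8, 7  # 7 stands in for the real [SEG] token id here
new_input_ids = torch.tensor([[1, 7, 2, 7],   # sample 0 has two [SEG] tokens
                              [1, 2, 7, 3]])  # sample 1 has one
hidden_states = torch.randn(2, 4, hidden_dim)  # projected LLM last hidden states

seg_token_mask = new_input_ids == seg_token_idx
pred_embeddings = hidden_states[seg_token_mask]      # (3, 8), flattened over the batch
seg_token_counts = seg_token_mask.int().sum(-1)      # tensor([2, 1])
per_sample = torch.split(pred_embeddings, seg_token_counts.tolist(), dim=0)
print([e.shape for e in per_sample])  # [torch.Size([2, 8]), torch.Size([1, 8])]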
projects/glamm/models/region_encoder.py ADDED
@@ -0,0 +1,359 @@
1
+ from abc import ABCMeta, abstractmethod
2
+ from typing import List, Optional, Tuple
3
+ from torch import Tensor
4
+
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from mmcv import ops
11
+ from mmcv.cnn import ConvModule, Linear
12
+ from mmengine.model import BaseModule
13
+
14
+ class BaseRoIExtractor(BaseModule, metaclass=ABCMeta):
15
+ """Base class for RoI extractor.
16
+
17
+ Args:
18
+ roi_layer (:obj:`ConfigDict` or dict): Specify RoI layer type and
19
+ arguments.
20
+ out_channels (int): Output channels of RoI layers.
21
+ featmap_strides (list[int]): Strides of input feature maps.
22
+ init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
23
+ dict], optional): Initialization config dict. Defaults to None.
24
+ """
25
+
26
+ def __init__(self,
27
+ roi_layer,
28
+ out_channels: int,
29
+ featmap_strides: List[int],
30
+ init_cfg=None) -> None:
31
+ super().__init__(init_cfg=init_cfg)
32
+ self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
33
+ self.out_channels = out_channels
34
+ self.featmap_strides = featmap_strides
35
+
36
+ @property
37
+ def num_inputs(self) -> int:
38
+ """int: Number of input feature maps."""
39
+ return len(self.featmap_strides)
40
+
41
+ def build_roi_layers(self, layer_cfg,
42
+ featmap_strides: List[int]) -> nn.ModuleList:
43
+ """Build RoI operator to extract feature from each level feature map.
44
+
45
+ Args:
46
+ layer_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
47
+ config RoI layer operation. Options are modules under
48
+ ``mmcv/ops`` such as ``RoIAlign``.
49
+ featmap_strides (list[int]): The stride of each input feature map with respect
50
+ to the original image size, which would be used to scale RoI
51
+ coordinate (original image coordinate system) to feature
52
+ coordinate system.
53
+
54
+ Returns:
55
+ :obj:`nn.ModuleList`: The RoI extractor modules for each level
56
+ feature map.
57
+ """
58
+
59
+ cfg = layer_cfg.copy()
60
+ layer_type = cfg.pop('type')
61
+ if isinstance(layer_type, str):
62
+ assert hasattr(ops, layer_type)
63
+ layer_cls = getattr(ops, layer_type)
64
+ else:
65
+ layer_cls = layer_type
66
+ roi_layers = nn.ModuleList(
67
+ [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])
68
+ return roi_layers
69
+
70
+ def roi_rescale(self, rois: Tensor, scale_factor: float) -> Tensor:
71
+ """Scale RoI coordinates by scale factor.
72
+
73
+ Args:
74
+ rois (Tensor): RoI (Region of Interest), shape (n, 5)
75
+ scale_factor (float): Scale factor that RoI will be multiplied by.
76
+
77
+ Returns:
78
+ Tensor: Scaled RoI.
79
+ """
80
+
81
+ cx = (rois[:, 1] + rois[:, 3]) * 0.5
82
+ cy = (rois[:, 2] + rois[:, 4]) * 0.5
83
+ w = rois[:, 3] - rois[:, 1]
84
+ h = rois[:, 4] - rois[:, 2]
85
+ new_w = w * scale_factor
86
+ new_h = h * scale_factor
87
+ x1 = cx - new_w * 0.5
88
+ x2 = cx + new_w * 0.5
89
+ y1 = cy - new_h * 0.5
90
+ y2 = cy + new_h * 0.5
91
+ new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1)
92
+ return new_rois
93
+
94
+ @abstractmethod
95
+ def forward(self,
96
+ feats: Tuple[Tensor],
97
+ rois: Tensor,
98
+ roi_scale_factor: Optional[float] = None) -> Tensor:
99
+ """Extract RoI features.
100
+
101
+ Args:
102
+ feats (Tuple[Tensor]): Multi-scale features.
103
+ rois (Tensor): RoIs with the shape (n, 5) where the first
104
+ column indicates batch id of each RoI.
105
+ roi_scale_factor (Optional[float]): RoI scale factor.
106
+ Defaults to None.
107
+
108
+ Returns:
109
+ Tensor: RoI feature.
110
+ """
111
+ pass
112
+
113
+
114
+ class MLVLFuseModule(nn.Module):
115
+ def __init__(self, input_dims=1024, embed_dims=1024, num_levels=3, num_fuse=4):
116
+ super(MLVLFuseModule, self).__init__()
117
+ self.embed_dims = embed_dims
118
+ self.num_levels = num_levels
119
+ self.num_fuse = num_fuse
120
+ self.input_dims = input_dims
121
+ self.shuffle_channles = embed_dims // 4
122
+
123
+ # contains the tuple of level indices that will do the interaction
124
+ self.fuse_lvl_list = []
125
+ num_levels = self.num_levels
126
+ for lvl in range(num_levels):
127
+ top_lvl = min(lvl + 1, num_levels - 1)
128
+ dow_lvl = max(lvl - 1, 0)
129
+ tar_lvl = lvl
130
+ self.fuse_lvl_list.append((tar_lvl, top_lvl, dow_lvl))
131
+
132
+ self.remain_chs = self.embed_dims - self.shuffle_channles * 2
133
+ self._init_layers()
134
+
135
+ def generate_coordinate(self, featmap_sizes, device='cuda'):
136
+
137
+ x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
138
+ y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device)
139
+ y, x = torch.meshgrid(y_range, x_range)
140
+ y = y.expand([featmap_sizes[0], 1, -1, -1])
141
+ x = x.expand([featmap_sizes[0], 1, -1, -1])
142
+ coord_feat = torch.cat([x, y], 1)
143
+
144
+ return coord_feat
145
+
146
+ def _init_layers(self):
147
+ self.input_conv = nn.ModuleList([nn.Conv2d(self.input_dims + 2,
148
+ self.embed_dims, 1)
149
+ for _ in range(self.num_levels)])
150
+ self.fuse_convs = nn.ModuleList()
151
+ for i in range(self.num_fuse):
152
+ self.fuse_convs.append(
153
+ ConvModule(self.embed_dims,
154
+ self.embed_dims,
155
+ 3,
156
+ stride=1,
157
+ padding=3 // 2,
158
+ conv_cfg=None,
159
+ norm_cfg=dict(type='GN',
160
+ num_groups=64,
161
+ requires_grad=True)
162
+ ))
163
+
164
+ def init_weights(self):
165
+ pass
166
+
167
+ def _single_shuffle(self, inputs, conv_module):
168
+ if not isinstance(conv_module, (nn.ModuleList, list)):
169
+ conv_module = [conv_module]
170
+ for single_conv_m in conv_module:
171
+ fused_inputs = []
172
+ for fuse_lvl_tuple in self.fuse_lvl_list:
173
+ tar_lvl, top_lvl, dow_lvl = fuse_lvl_tuple
174
+ tar_input = inputs[tar_lvl]
175
+ top_input = inputs[top_lvl]
176
+ down_input = inputs[dow_lvl]
177
+ remain = tar_input[:, :self.remain_chs]
178
+ from_top = top_input[:, self.remain_chs:][:, self.shuffle_channles:]
179
+ from_top = F.interpolate(from_top.to(torch.float32),
180
+ size=tar_input.shape[-2:],
181
+ mode='bilinear',
182
+ align_corners=True)
183
+ from_down = down_input[:, self.remain_chs:][:, :self.shuffle_channles]
184
+ from_down = F.interpolate(from_down.to(torch.float32),
185
+ size=tar_input.shape[-2:],
186
+ mode='bilinear',
187
+ align_corners=True)
188
+ fused_inputs.append(
189
+ torch.cat([remain, from_top.to(remain.dtype), from_down.to(remain.dtype)], dim=1))
190
+ fused_inputs = [single_conv_m(item) for item in fused_inputs]
191
+ inputs = fused_inputs
192
+ return inputs
193
+
194
+ def forward(self, inputs, ):
195
+ feat_size = [item.shape for item in inputs]
196
+ new_inputs = []
197
+ for feat, single_feat_size in zip(inputs, feat_size):
198
+ coord_feat = self.generate_coordinate(
199
+ single_feat_size, device=inputs[0].device)
200
+ # feat = torch.cat([feat, coord_feat], dim=1)
201
+ feat = torch.cat([feat, coord_feat.to(feat.dtype)], dim=1)
202
+ new_inputs.append(feat)
203
+ inputs = new_inputs
204
+
205
+ inputs = [self.input_conv[lvl](item)
206
+ for lvl, item in enumerate(inputs)]
207
+
208
+ for conv_m in self.fuse_convs:
209
+ inputs = self._single_shuffle(inputs, [conv_m])
210
+ return inputs
211
+
212
+
213
+ class MlvlRoIExtractor(BaseRoIExtractor):
214
+ def __init__(self,
215
+ roi_layer,
216
+ out_channels,
217
+ featmap_strides,
218
+ embed_dims=1024,
219
+ stride=1,
220
+ norm_init=True,
221
+ fuse_level=3,
222
+ finest_scale=56,
223
+ init_cfg=None):
224
+ super(MlvlRoIExtractor, self).__init__(roi_layer, out_channels,
225
+ featmap_strides, init_cfg)
226
+ self.embed_dims = embed_dims
227
+ self.finest_scale = finest_scale
228
+ self.fuse_level = fuse_level
229
+ self.norm_init = norm_init
230
+
231
+ self.pconvs = nn.ModuleList(
232
+ nn.Conv2d(self.embed_dims, self.embed_dims, 3, stride=1, padding=1)
233
+ for _ in range(self.fuse_level))
234
+ self.pos_embedd = nn.Sequential(
235
+ nn.Linear(4, 256),
236
+ nn.ReLU(inplace=True),
237
+ nn.LayerNorm(256),
238
+ nn.Linear(256, 1024),
239
+ nn.ReLU(inplace=True),
240
+ nn.LayerNorm(1024),
241
+ )
242
+ self.updims = nn.Linear(1024, 4096)
243
+
244
+ self.flatten_linear = nn.Linear(
245
+ self.embed_dims * self.roi_layers[0].output_size[0] ** 2, 1024)
246
+
247
+ self.norm_init_weights()
248
+
249
+ # self.dtype = torch.float32
250
+ def norm_init_weights(self):
251
+ pass
252
+
253
+ def forward(self, feats, rois, roi_scale_factor=None):
254
+ """Forward function."""
255
+ num_imgs = len(rois)
256
+ # feats = [item for item in feats]
257
+ batch_rois = torch.cat(rois, dim=0).to(feats[0].dtype)
258
+ pos_embedd = self.pos_embedd(batch_rois)
259
+ out_size = self.roi_layers[0].output_size
260
+ num_levels = len(feats)
261
+ if feats[0].dim() == 3:
262
+ h = w = int(math.sqrt(feats[0].shape[1]))
263
+ assert h == 16
264
+ assert w == 16
265
+ b, c = feats[0].shape[0], feats[0].shape[-1]
266
+ feats = [item.reshape(b, h, w, c).permute(0, 3, 1, 2)
267
+ for item in feats]
268
+ new_rois = []
269
+ for img_id, single_img_roi in enumerate(rois):
270
+ # rescale to original img scale
271
+ single_img_roi = single_img_roi * 224
272
+
273
+ roi_img_id = single_img_roi.new_ones(len(single_img_roi)) * img_id
274
+ single_img_roi = torch.cat(
275
+ [roi_img_id[:, None], single_img_roi], dim=1)
276
+ new_rois.append(single_img_roi)
277
+ rois = torch.cat(new_rois)
278
+
279
+ roi_feats = feats[0].new_zeros(self.fuse_level,
280
+ rois.size(0), self.out_channels, *out_size)
281
+
282
+ for i in range(num_levels):
283
+ if len(rois) > 0:
284
+ rois_ = rois
285
+ ori_dtype = feats[i].dtype
286
+ roi_feats_t = self.roi_layers[i](feats[i].to(
287
+ torch.float32), rois_.to(torch.float32))
288
+
289
+ roi_feats[i] = roi_feats_t.to(ori_dtype)
290
+
291
+ else:
292
+ roi_feats += sum(
293
+ x.view(-1)[0]
294
+ for x in self.parameters()) * 0. + feats[i].sum() * 0.
295
+
296
+ fuse_roi_feats = []
297
+ for i in range(self.fuse_level):
298
+ fuse_roi_feats.append(self.pconvs[i](roi_feats[i]))
299
+
300
+ fuse_roi_feats = sum(fuse_roi_feats)
301
+ fuse_roi_feats = F.relu(fuse_roi_feats)
302
+ fuse_roi_feats = fuse_roi_feats.flatten(1, -1)
303
+ fuse_roi_feats = self.flatten_linear(fuse_roi_feats)
304
+ fuse_roi_feats = fuse_roi_feats + pos_embedd
305
+ fuse_roi_feats = self.updims(fuse_roi_feats)
306
+ query_feats = []
307
+ for i in range(num_imgs):
308
+ mask = rois[:, 0] == i
309
+ query_feats.append(fuse_roi_feats[mask])
310
+
311
+ return query_feats
312
+
313
+
314
+ class MLVLROIQueryModule(nn.Module):
315
+ def __init__(self, embed_dims=1024, out_dims=4096,
316
+ num_levels=3):
317
+ super(MLVLROIQueryModule, self).__init__()
318
+ self.mlvl_fuse = MLVLFuseModule(input_dims=embed_dims,
319
+ embed_dims=embed_dims,
320
+ num_levels=num_levels,
321
+ num_fuse=5)
322
+ strids = [14 / 8, 14 / 4, 14 / 2, 14]
323
+ assert len(strids) == num_levels
324
+ bbox_roi_extractor = dict(roi_layer=dict(type='RoIAlign',
325
+ output_size=14,
326
+ sampling_ratio=2),
327
+ out_channels=embed_dims,
328
+ embed_dims=embed_dims,
329
+ fuse_level=num_levels,
330
+ featmap_strides=strids)
331
+
332
+ self.roi_align = MlvlRoIExtractor(**bbox_roi_extractor)
333
+
334
+ def forward(self, mlvl_feats, bboxes):
335
+ if mlvl_feats[0].dim() == 3:
336
+ h = w = int(math.sqrt(mlvl_feats[0].shape[1]))
337
+ assert h == 24
338
+ assert w == 24
339
+ b, c = mlvl_feats[0].shape[0], mlvl_feats[0].shape[-1]
340
+ mlvl_feats = [item.reshape(b, h, w, c).permute(0, 3, 1, 2) for item in mlvl_feats]
341
+ base_shape = mlvl_feats[0].shape[-2:]
342
+ num_level = len(mlvl_feats)
343
+ to_shape = [(base_shape[0] * 2 ** level, base_shape[1] * 2 ** level)
344
+ for level in range(num_level)]
345
+ to_shape = to_shape[::-1]
346
+ for level in range(num_level):
347
+ feat = mlvl_feats[level]
348
+ shape = to_shape[level]
349
+ # feat = feat
350
+ # mlvl_feats[level] = F.interpolate(feat, size=shape, mode='bilinear', align_corners=True)
351
+ # todo: temporary fix for "upsample_bilinear2d_out_frame" not implemented for 'BFloat16'
352
+ feat = feat.to(torch.float32)
353
+ mlvl_feats[level] = F.interpolate(
354
+ feat, size=shape, mode='bilinear', align_corners=True)
355
+ mlvl_feats[level] = mlvl_feats[level].to(torch.bfloat16)
356
+
357
+ mlvl_feats = self.mlvl_fuse(mlvl_feats)
358
+
359
+ return self.roi_align(mlvl_feats, bboxes)
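For orientation, here is a minimal usage sketch of the region encoder defined above. It is illustrative only: the batch size, box values, CUDA placement and bfloat16 cast are assumptions, while the shapes follow the code's own asserts (four feature levels to match the four hard-coded strides, a 24x24 grid of 1024-d tokens per level, and boxes in normalized xyxy coordinates, which MlvlRoIExtractor rescales by 224).

import torch

# hypothetical setup; MLVLROIQueryModule.forward hard-casts features to bfloat16,
# so the module itself is moved to bfloat16 as well
encoder = MLVLROIQueryModule(embed_dims=1024, out_dims=4096, num_levels=4)
encoder = encoder.cuda().to(torch.bfloat16)

# four levels of ViT features, each (batch, 24*24 tokens, 1024 channels)
mlvl_feats = [torch.randn(2, 576, 1024, device='cuda', dtype=torch.bfloat16)
              for _ in range(4)]
# one box tensor per image, normalized xyxy in [0, 1]
bboxes = [torch.tensor([[0.1, 0.2, 0.6, 0.8]], device='cuda', dtype=torch.bfloat16),
          torch.tensor([[0.0, 0.0, 1.0, 1.0],
                        [0.3, 0.3, 0.9, 0.9]], device='cuda', dtype=torch.bfloat16)]

region_queries = encoder(mlvl_feats, bboxes)
# -> a list with one tensor per image, here (1, 4096) and (2, 4096):
#    a single 4096-d embedding per box, sized for the language model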
projects/glamm/utils.py ADDED
@@ -0,0 +1,280 @@
1
+ from enum import Enum
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+ from transformers import PreTrainedModel
8
+ from typing import List, Optional
9
+
10
+
11
+ IGNORE_INDEX = -100
12
+ IMAGE_TOKEN_INDEX = -200
13
+
14
+ DEFAULT_EOS_TOKEN = '</s>'
15
+ DEFAULT_BOS_TOKEN = '<s>'
16
+ DEFAULT_UNK_TOKEN = '<unk>'
17
+
18
+ DEFAULT_IMAGE_TOKEN = "<image>"
19
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
20
+ DEFAULT_IM_START_TOKEN = "<im_start>"
21
+ DEFAULT_IM_END_TOKEN = "<im_end>"
22
+ DEFAULT_BBOX_TOKEN = "<bbox>"
23
+
24
+
25
+
26
+ # Modified from https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/llava/model/llava_arch.py#L99 # noqa: E501
27
+ def prepare_inputs_labels_for_multimodal(
28
+ llm: PreTrainedModel,
29
+ input_ids: torch.LongTensor = None,
30
+ position_ids: Optional[torch.LongTensor] = None,
31
+ attention_mask: Optional[torch.Tensor] = None,
32
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
33
+ labels: Optional[torch.LongTensor] = None,
34
+ pixel_values: Optional[torch.FloatTensor] = None,
35
+ **kwargs):
36
+ if pixel_values is None:
37
+ kwargs.update({
38
+ 'input_ids': input_ids,
39
+ 'position_ids': position_ids,
40
+ 'attention_mask': attention_mask,
41
+ 'past_key_values': past_key_values,
42
+ 'inputs_embeds': None,
43
+ 'labels': labels
44
+ })
45
+ return kwargs
46
+
47
+ _labels = labels
48
+ _position_ids = position_ids
49
+ _attention_mask = attention_mask
50
+ if attention_mask is None:
51
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
52
+ else:
53
+ attention_mask = attention_mask.bool()
54
+ if position_ids is None:
55
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
56
+ if labels is None:
57
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
58
+
59
+ # remove the padding using attention_mask -- TODO: double check
60
+ input_ids = [
61
+ cur_input_ids[cur_attention_mask]
62
+ for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
63
+ ]
64
+ labels = [
65
+ cur_labels[cur_attention_mask]
66
+ for cur_labels, cur_attention_mask in zip(labels, attention_mask)
67
+ ]
68
+
69
+ new_inputs_embeds = []
70
+ new_labels = []
71
+ new_input_ids = []
72
+ cur_image_idx = 0
73
+ for batch_idx, cur_input_ids in enumerate(input_ids):
74
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
75
+ if num_images == 0:
76
+ cur_pixel_values = pixel_values[cur_image_idx]
77
+ cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
78
+ cur_inputs_embeds = torch.cat([cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
79
+ new_inputs_embeds.append(cur_inputs_embeds)
80
+ new_labels.append(labels[batch_idx])
81
+ new_input_ids.append(cur_input_ids)
82
+ cur_image_idx += 1
83
+ continue
84
+
85
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
86
+ cur_input_ids_noim = []
87
+ cur_labels = labels[batch_idx]
88
+ cur_labels_noim = []
89
+ for i in range(len(image_token_indices) - 1):
90
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
91
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
92
+
93
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
94
+ cur_inputs_embeds = llm.get_input_embeddings()(torch.cat(cur_input_ids_noim))
95
+ cur_inputs_embeds_no_im = torch.split(cur_inputs_embeds, split_sizes, dim=0)
96
+ cur_new_inputs_embeds = []
97
+ cur_new_labels = []
98
+ cur_new_input_ids = []
99
+
100
+ for i in range(num_images + 1):
101
+ cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
102
+ cur_new_labels.append(cur_labels_noim[i])
103
+ cur_new_input_ids.append(cur_input_ids_noim[i])
104
+ if i < num_images:
105
+ cur_pixel_values = pixel_values[cur_image_idx]
106
+ cur_image_idx += 1
107
+ cur_new_inputs_embeds.append(cur_pixel_values)
108
+ cur_new_labels.append(torch.full((cur_pixel_values.shape[0], ), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
109
+ cur_new_input_ids.append(torch.full((cur_pixel_values.shape[0], ), IMAGE_TOKEN_INDEX, device=cur_input_ids.device, dtype=cur_input_ids.dtype))
110
+
111
+ cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
112
+ cur_new_labels = torch.cat(cur_new_labels)
113
+ cur_new_input_ids = torch.cat(cur_new_input_ids)
114
+
115
+ new_inputs_embeds.append(cur_new_inputs_embeds)
116
+ new_labels.append(cur_new_labels)
117
+ new_input_ids.append(cur_new_input_ids)
118
+
119
+ # Combine them
120
+ max_len = max(x.shape[0] for x in new_inputs_embeds)
121
+ batch_size = len(new_inputs_embeds)
122
+
123
+ new_inputs_embeds_padded = []
124
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
125
+ new_input_ids_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device)
126
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
127
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
128
+
129
+ for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_inputs_embeds, new_labels, new_input_ids)):
130
+ cur_len = cur_new_embed.shape[0]
131
+ new_inputs_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
132
+ if cur_len > 0:
133
+ new_labels_padded[i, :cur_len] = cur_new_labels
134
+ new_input_ids_padded[i, :cur_len] = cur_new_input_ids
135
+ attention_mask[i, :cur_len] = True
136
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
137
+
138
+ new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)
139
+
140
+ if _labels is None:
141
+ new_labels = None
142
+ else:
143
+ new_labels = new_labels_padded
144
+
145
+ new_input_ids = new_input_ids_padded
146
+
147
+ if _attention_mask is None:
148
+ attention_mask = None
149
+ else:
150
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
151
+
152
+ if _position_ids is None:
153
+ position_ids = None
154
+
155
+ kwargs.update({
156
+ 'input_ids': None,
157
+ 'position_ids': position_ids,
158
+ 'attention_mask': attention_mask,
159
+ 'past_key_values': past_key_values,
160
+ 'inputs_embeds': new_inputs_embeds,
161
+ 'labels': new_labels,
162
+ 'new_input_ids': new_input_ids
163
+ })
164
+ return kwargs
165
+
166
+ class Summary(Enum):
167
+ NONE = 0
168
+ AVERAGE = 1
169
+ SUM = 2
170
+ COUNT = 3
171
+
172
+
173
+ class AverageMeter(object):
174
+ """Computes and stores the average and current value"""
175
+
176
+ def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
177
+ self.name = name
178
+ self.fmt = fmt
179
+ self.summary_type = summary_type
180
+ self.reset()
181
+
182
+ def reset(self):
183
+ self.val = 0
184
+ self.avg = 0
185
+ self.sum = 0
186
+ self.count = 0
187
+
188
+ def update(self, val, n=1):
189
+ self.val = val
190
+ self.sum += val * n
191
+ self.count += n
192
+ self.avg = self.sum / self.count
193
+
194
+ def all_reduce(self):
195
+ device = "cuda" if torch.cuda.is_available() else "cpu"
196
+ if isinstance(self.sum, np.ndarray):
197
+ total = torch.tensor(
198
+ self.sum.tolist()
199
+ + [
200
+ self.count,
201
+ ],
202
+ dtype=torch.float32,
203
+ device=device,
204
+ )
205
+ else:
206
+ total = torch.tensor(
207
+ [self.sum, self.count], dtype=torch.float32, device=device
208
+ )
209
+
210
+ dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
211
+ if total.shape[0] > 2:
212
+ self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
213
+ else:
214
+ self.sum, self.count = total.tolist()
215
+ self.avg = self.sum / (self.count + 1e-5)
216
+
217
+ def __str__(self):
218
+ fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
219
+ return fmtstr.format(**self.__dict__)
220
+
221
+ def summary(self):
222
+ fmtstr = ""
223
+ if self.summary_type is Summary.NONE:
224
+ fmtstr = ""
225
+ elif self.summary_type is Summary.AVERAGE:
226
+ fmtstr = "{name} {avg:.3f}"
227
+ elif self.summary_type is Summary.SUM:
228
+ fmtstr = "{name} {sum:.3f}"
229
+ elif self.summary_type is Summary.COUNT:
230
+ fmtstr = "{name} {count:.3f}"
231
+ else:
232
+ raise ValueError("invalid summary type %r" % self.summary_type)
233
+
234
+ return fmtstr.format(**self.__dict__)
235
+
236
+
237
+ def intersectionAndUnionGPU(output, target, K, ignore_index=255):
238
+ # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
239
+ assert output.dim() in [1, 2, 3]
240
+ assert output.shape == target.shape
241
+ output = output.view(-1)
242
+ target = target.view(-1)
243
+ output[target == ignore_index] = ignore_index
244
+ intersection = output[output == target]
245
+ area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
246
+ area_output = torch.histc(output, bins=K, min=0, max=K - 1)
247
+ area_target = torch.histc(target, bins=K, min=0, max=K - 1)
248
+ area_union = area_output + area_target - area_intersection
249
+ return area_intersection, area_union, area_target
250
+
251
+
252
+ class ProgressMeter(object):
253
+ def __init__(self, num_batches, meters, prefix=""):
254
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
255
+ self.meters = meters
256
+ self.prefix = prefix
257
+
258
+ def display(self, batch):
259
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
260
+ entries += [str(meter) for meter in self.meters]
261
+ print("\t".join(entries))
262
+
263
+ def display_summary(self):
264
+ entries = [" *"]
265
+ entries += [meter.summary() for meter in self.meters]
266
+ print(" ".join(entries))
267
+
268
+ def _get_batch_fmtstr(self, num_batches):
269
+ num_digits = len(str(num_batches // 1))
270
+ fmt = "{:" + str(num_digits) + "d}"
271
+ return "[" + fmt + "/" + fmt.format(num_batches) + "]"
272
+
273
+
274
+ def dict_to_cuda(input_dict):
275
+ for k, v in input_dict.items():
276
+ if isinstance(input_dict[k], torch.Tensor):
277
+ input_dict[k] = v.cuda(non_blocking=True)
278
+ elif isinstance(v, list) and len(v) > 0:
279
+ input_dict[k] = [ele.cuda(non_blocking=True) if isinstance(ele, torch.Tensor) else ele for ele in v]
280
+ return input_dict
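As a quick illustration of how the metric helpers above fit together, the standalone sketch below accumulates per-class intersection/union histograms over a few batches of random two-class masks (the tensors are made up; only the call pattern mirrors typical evaluation code).

import torch

inter_meter = AverageMeter('Intersec', ':6.3f', Summary.SUM)
union_meter = AverageMeter('Union', ':6.3f', Summary.SUM)

for _ in range(3):  # pretend these are evaluation batches
    # float labels in {0, 1}; torch.histc inside the helper expects floating point
    pred = torch.randint(0, 2, (4, 256, 256)).float().cuda()
    target = torch.randint(0, 2, (4, 256, 256)).float().cuda()
    # the helper writes ignore_index into its first argument in place, hence the clone
    inter, union, _ = intersectionAndUnionGPU(pred.clone(), target, K=2, ignore_index=255)
    inter_meter.update(inter.cpu().numpy())
    union_meter.update(union.cpu().numpy())

iou_per_class = inter_meter.sum / (union_meter.sum + 1e-10)
print('per-class IoU:', iou_per_class, 'mIoU:', iou_per_class.mean())

In a distributed run, AverageMeter.all_reduce() sums the accumulated sums and counts across ranks before this final division.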
projects/llava_sam2/configs/sa2va_4b.py ADDED
@@ -0,0 +1,548 @@
1
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
2
+ LoggerHook, ParamSchedulerHook)
3
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
4
+ from torch.optim import AdamW
5
+ from transformers import AutoTokenizer
6
+
7
+ from xtuner.dataset import ConcatDataset
8
+ from xtuner.dataset.samplers import LengthGroupedSampler
9
+ from xtuner.engine.hooks import DatasetInfoHook
10
+ from xtuner.engine.runner import TrainLoop
11
+ from xtuner.utils import PROMPT_TEMPLATE
12
+ from xtuner.dataset.map_fns import template_map_fn_factory
13
+
14
+ from third_parts.mmdet.models.losses import DiceLoss, CrossEntropyLoss
15
+ from peft import LoraConfig
16
+
17
+ from projects.llava_sam2.models.internvl import InternVL_Slowfast
18
+
19
+ from projects.llava_sam2.models import VideoLLaVASAMModel, SAM2TrainRunner, VideoLLaVASAMModel_zero3
20
+ from projects.llava_sam2.datasets import VideoReVOSDataset, VideoMeVISDataset, VideoRefYoutubeVOSDataset, video_lisa_collate_fn, VideoSAM2Dataset
21
+ from projects.llava_sam2.datasets import VideoChatUniViDataset
22
+ from projects.llava_sam2.datasets import RefCOCOgGCGDataset, OpenPsgGCGDataset, FlickrGCGDataset, GranDfGCGDataset, OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset
23
+ from projects.llava_sam2.datasets import LLaVADataset
24
+ from projects.llava_sam2.datasets import ReferSegmDataset
25
+ from projects.llava_sam2.models.preprocess.image_resize import DirectResize
26
+
27
+ #######################################################################
28
+ # PART 1 Settings #
29
+ #######################################################################
30
+ # Model
31
+ path = './pretrained/InternVL2_5-4B'
32
+ pretrained_pth = None
33
+
34
+ # Data
35
+ prompt_template = PROMPT_TEMPLATE.phi3_chat
36
+ max_length = 8192
37
+
38
+ # Scheduler & Optimizer
39
+ batch_size = 2 # per_device
40
+ accumulative_counts = 4
41
+ dataloader_num_workers = 4
42
+ max_epochs = 1
43
+ optim_type = AdamW
44
+ # official 1024 -> 4e-5
45
+ # lr = 1e-6
46
+ lr = 4e-5
47
+ betas = (0.9, 0.999)
48
+ weight_decay = 0.05
49
+ max_norm = 1 # grad clip
50
+ warmup_ratio = 0.05
51
+
52
+ # Save
53
+ save_steps = 1000
54
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
55
+
56
+ special_tokens = ['[SEG]', '<p>', '</p>', '<vp>', '</vp>']
57
+
58
+ tokenizer = dict(
59
+ type=AutoTokenizer.from_pretrained,
60
+ pretrained_model_name_or_path=path,
61
+ trust_remote_code=True,
62
+ padding_side='right')
63
+
64
+ extra_image_processor = dict(
65
+ type=DirectResize,
66
+ target_length=1024,
67
+ )
68
+ #######################################################################
69
+ # PART 2 Model & Tokenizer & Image Processor #
70
+ #######################################################################
71
+ model = dict(
72
+ type=VideoLLaVASAMModel_zero3,
73
+ special_tokens=special_tokens,
74
+ frozen_sam2_decoder=False,
75
+ mllm=dict(
76
+ type=InternVL_Slowfast,
77
+ model_path=path,
78
+ freeze_llm=True,
79
+ freeze_visual_encoder=True,
80
+ llm_lora=dict(
81
+ type=LoraConfig,
82
+ r=128,
83
+ lora_alpha=256,
84
+ lora_dropout=0.05,
85
+ bias='none',
86
+ task_type='CAUSAL_LM'),
87
+ special_tokens=special_tokens,
88
+ ),
89
+ tokenizer=tokenizer,
90
+ grounding_encoder=dict(
91
+ type=SAM2TrainRunner,
92
+ ),
93
+ loss_mask=dict(
94
+ type=CrossEntropyLoss,
95
+ use_sigmoid=True,
96
+ reduction='mean',
97
+ loss_weight=2.0),
98
+ loss_dice=dict(
99
+ type=DiceLoss,
100
+ use_sigmoid=True,
101
+ activate=True,
102
+ reduction='mean',
103
+ naive_dice=True,
104
+ eps=1.0,
105
+ loss_weight=0.5),
106
+ pretrained_pth=pretrained_pth,
107
+ loss_sample_points=True,
108
+ # loss_sample_points=False,
109
+ bs=batch_size,
110
+ )
111
+
112
+ #######################################################################
113
+ # PART 3 Dataset & Dataloader #
114
+ #######################################################################
115
+
116
+
117
+ VIDEO_DATAS = './data/video_datas/'
118
+ IMG_DATAS = './data/image_datas/'
119
+
120
+ ############### video res
121
+ data_root_revos = './data/video_datas/revos/'
122
+ video_revos_image_folder = data_root_revos
123
+ video_revos_expression_file = data_root_revos + 'meta_expressions_train_.json'
124
+ video_revos_mask_file = data_root_revos + 'mask_dict.json'
125
+
126
+ data_root_mevis = './data/video_datas/mevis/train/'
127
+ video_mevis_image_folder = data_root_mevis + 'JPEGImages'
128
+ video_mevis_expression_file = data_root_mevis + 'meta_expressions.json'
129
+ video_mevis_mask_file = data_root_mevis + 'mask_dict.json'
130
+
131
+ data_root_refytvos = './data/video_datas/rvos/'
132
+ video_refytvos_image_folder = data_root_refytvos + 'train/JPEGImages/'
133
+ video_refytvos_expression_file = data_root_refytvos + 'meta_expressions/train/meta_expressions.json'
134
+ video_refytvos_mask_file = data_root_refytvos + 'mask_dict.pkl'
135
+
136
+ video_revos_dataset = dict(
137
+ type=VideoReVOSDataset,
138
+ image_folder=video_revos_image_folder,
139
+ expression_file=video_revos_expression_file,
140
+ mask_file=video_revos_mask_file,
141
+ tokenizer=tokenizer,
142
+ template_map_fn=dict(
143
+ type=template_map_fn_factory, template=prompt_template),
144
+ max_length=max_length,
145
+ lazy=True,
146
+ repeats=10,
147
+ special_tokens=special_tokens,
148
+ extra_image_processor=extra_image_processor,
149
+ sampled_frames=5,
150
+ )
151
+
152
+ video_mevis_dataset = dict(
153
+ type=VideoMeVISDataset,
154
+ image_folder=video_mevis_image_folder,
155
+ expression_file=video_mevis_expression_file,
156
+ mask_file=video_mevis_mask_file,
157
+ tokenizer=tokenizer,
158
+ template_map_fn=dict(
159
+ type=template_map_fn_factory, template=prompt_template),
160
+ max_length=max_length,
161
+ lazy=True,
162
+ repeats=4,
163
+ special_tokens=special_tokens,
164
+ extra_image_processor=extra_image_processor,
165
+ sampled_frames=5,
166
+ )
167
+
168
+ video_refytvos_dataset = dict(
169
+ type=VideoRefYoutubeVOSDataset,
170
+ image_folder=video_refytvos_image_folder,
171
+ expression_file=video_refytvos_expression_file,
172
+ mask_file=video_refytvos_mask_file,
173
+ tokenizer=tokenizer,
174
+ template_map_fn=dict(
175
+ type=template_map_fn_factory, template=prompt_template),
176
+ max_length=max_length,
177
+ lazy=True,
178
+ repeats=4,
179
+ special_tokens=special_tokens,
180
+ extra_image_processor=extra_image_processor,
181
+ sampled_frames=5,
182
+ )
183
+
184
+ ################### Video chat
185
+ data_root_video_chatunivi = VIDEO_DATAS + 'video_vlm/video_chat/'
186
+ video_chatunivi_image_folder = data_root_video_chatunivi + 'Activity_Videos/'
187
+ video_chatunivi_json_file = data_root_video_chatunivi+ 'video_chat.json'
188
+
189
+ video_qa_dataset = dict(
190
+ type=VideoChatUniViDataset,
191
+ image_folder=video_chatunivi_image_folder,
192
+ json_file=video_chatunivi_json_file,
193
+ tokenizer=tokenizer,
194
+ template_map_fn=dict(
195
+ type=template_map_fn_factory, template=prompt_template),
196
+ max_length=max_length,
197
+ lazy=True,
198
+ repeats=1,
199
+ special_tokens=special_tokens,
200
+ extra_image_processor=extra_image_processor,
201
+ sampled_frames=5,
202
+ )
203
+
204
+ ################## image chat
205
+ llava_vqa_dataset = dict(
206
+ type=LLaVADataset,
207
+ tokenizer=tokenizer,
208
+ data_path='data/llava_data/LLaVA-Instruct-150K/llava_v1_5_mix665k.json',
209
+ prompt_template=prompt_template,
210
+ special_tokens=special_tokens,
211
+ image_folder='data/llava_data/llava_images/',
212
+ )
213
+
214
+ ################## image res
215
+ refcoco_segm_dataset=dict(
216
+ type=ReferSegmDataset,
217
+ tokenizer=tokenizer,
218
+ special_tokens=special_tokens,
219
+ extra_image_processor=extra_image_processor,
220
+ data_root='data/ref_seg/refcoco',
221
+ data_prefix=dict(img_path='coco2014/train2014/'),
222
+ ann_file='instances.json',
223
+ split_file='refs(unc).p',
224
+ prompt_template=prompt_template,
225
+ num_classes_per_sample=5,
226
+ max_length=max_length,
227
+ )
228
+ refcoco_plus_segm_dataset=dict(
229
+ type=ReferSegmDataset,
230
+ tokenizer=tokenizer,
231
+ special_tokens=special_tokens,
232
+ extra_image_processor=extra_image_processor,
233
+ data_root='data/ref_seg/refcoco+',
234
+ data_prefix=dict(img_path='coco2014/train2014/'),
235
+ ann_file='instances.json',
236
+ split_file='refs(unc).p',
237
+ prompt_template=prompt_template,
238
+ num_classes_per_sample=5,
239
+ max_length=max_length,
240
+ )
241
+ refcocog_segm_dataset=dict(
242
+ type=ReferSegmDataset,
243
+ tokenizer=tokenizer,
244
+ special_tokens=special_tokens,
245
+ extra_image_processor=extra_image_processor,
246
+ data_root='data/ref_seg/refcocog',
247
+ data_prefix=dict(img_path='coco2014/train2014/'),
248
+ ann_file='instances.json',
249
+ split_file='refs(umd).p',
250
+ prompt_template=prompt_template,
251
+ num_classes_per_sample=5,
252
+ max_length=max_length,
253
+ )
254
+
255
+ # image gcg datas
256
+ glamm_data_root = './data/glamm_data/'
257
+
258
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
259
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
260
+
261
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
262
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
263
+
264
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
265
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
266
+
267
+ psg_image_path = glamm_data_root + 'images/coco2017/'
268
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
269
+
270
+ glamm_refcocog_dataset = dict(
271
+ type=RefCOCOgGCGDataset,
272
+ image_folder=refcocog_image_path,
273
+ data_path=refcocog_ann_file,
274
+ tokenizer=tokenizer,
275
+ max_length=max_length,
276
+ special_tokens=special_tokens,
277
+ template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
278
+ extra_image_processor=extra_image_processor,
279
+ lazy=True,
280
+ repeats=1,
281
+ )
282
+
283
+ glamm_grandf_dataset = dict(
284
+ type=GranDfGCGDataset,
285
+ data_path=grandf_ann_file,
286
+ image_folder=grandf_image_path,
287
+ tokenizer=tokenizer,
288
+ max_length=max_length,
289
+ special_tokens=special_tokens,
290
+ template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
291
+ extra_image_processor=extra_image_processor,
292
+ lazy=True,
293
+ repeats=10,
294
+ )
295
+
296
+ glamm_psg_dataset = dict(
297
+ type=OpenPsgGCGDataset,
298
+ data_path=psg_ann_file,
299
+ image_folder=psg_image_path,
300
+ tokenizer=tokenizer,
301
+ max_length=max_length,
302
+ special_tokens=special_tokens,
303
+ template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
304
+ extra_image_processor=extra_image_processor,
305
+ lazy=True,
306
+ repeats=1,
307
+ )
308
+
309
+ glamm_flickr_dataset = dict(
310
+ type=FlickrGCGDataset,
311
+ data_path=flickr_ann_file,
312
+ image_folder=flickr_image_path,
313
+ tokenizer=tokenizer,
314
+ max_length=max_length,
315
+ special_tokens=special_tokens,
316
+ template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
317
+ extra_image_processor=extra_image_processor,
318
+ lazy=True,
319
+ repeats=1,
320
+ )
321
+
322
+ # sam2 data
323
+ data_sam2_folder = VIDEO_DATAS + 'segmentation_datasets/sam_v_full/'
324
+ data_sam2_expression_file = './whole_pesudo_cap_v3/sam_v_final_v3.json'
325
+
326
+ video_sam2_dataset = dict(
327
+ type=VideoSAM2Dataset,
328
+ sam2_folder=data_sam2_folder,
329
+ expression_file=data_sam2_expression_file,
330
+ tokenizer=tokenizer,
331
+ template_map_fn=dict(
332
+ type=template_map_fn_factory, template=prompt_template),
333
+ max_length=max_length,
334
+ lazy=True,
335
+ repeats=4,
336
+ special_tokens=special_tokens,
337
+ extra_image_processor=extra_image_processor,
338
+ sampled_frames=5,
339
+ select_number=5,
340
+ )
341
+
342
+ # osprey
343
+ data_osprey_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_conversation.json'
344
+ data_osprey_image_folders = [
345
+ IMG_DATAS+ 'coco/train2014/',
346
+ IMG_DATAS + 'coco/val2014/',
347
+ IMG_DATAS + 'coco/train2017/',
348
+ IMG_DATAS + 'coco/val2017/',
349
+ ]
350
+
351
+ image_osprey_dataset = dict(
352
+ type=OspreyDataset,
353
+ image_folder=data_osprey_image_folders,
354
+ data_path=data_osprey_file,
355
+ tokenizer=tokenizer,
356
+ template_map_fn=dict(
357
+ type=template_map_fn_factory, template=prompt_template),
358
+ max_length=max_length,
359
+ lazy=True,
360
+ repeats=1,
361
+ special_tokens=special_tokens,
362
+ )
363
+
364
+ data_osprey_detail_description_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_detail_description.json'
365
+ image_osprey_description_dataset = dict(
366
+ type=OspreyDescriptionDataset,
367
+ image_folder=data_osprey_image_folders,
368
+ data_path=data_osprey_detail_description_file,
369
+ tokenizer=tokenizer,
370
+ template_map_fn=dict(
371
+ type=template_map_fn_factory, template=prompt_template),
372
+ max_length=max_length,
373
+ lazy=True,
374
+ repeats=1,
375
+ special_tokens=special_tokens,
376
+ )
377
+
378
+ data_osprey_short_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_short_form.json'
379
+ image_osprey_short_dataset = dict(
380
+ type=OspreyShortDescriptionDataset,
381
+ image_folder=data_osprey_image_folders,
382
+ data_path=data_osprey_short_file,
383
+ tokenizer=tokenizer,
384
+ template_map_fn=dict(
385
+ type=template_map_fn_factory, template=prompt_template),
386
+ max_length=max_length,
387
+ lazy=True,
388
+ repeats=1,
389
+ special_tokens=special_tokens,
390
+ )
391
+
392
+ data_osprey_part_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_part_level.json'
393
+ image_osprey_part_dataset = dict(
394
+ type=OspreyDataset,
395
+ image_folder=data_osprey_image_folders,
396
+ data_path=data_osprey_part_file,
397
+ tokenizer=tokenizer,
398
+ template_map_fn=dict(
399
+ type=template_map_fn_factory, template=prompt_template),
400
+ max_length=max_length,
401
+ lazy=True,
402
+ repeats=1,
403
+ special_tokens=special_tokens,
404
+ )
405
+
406
+ data_osprey_positive_neg_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_lvis_positive_negative.json'
407
+ image_osprey_positive_neg_dataset = dict(
408
+ type=OspreyDataset,
409
+ image_folder=data_osprey_image_folders,
410
+ data_path=data_osprey_positive_neg_file,
411
+ tokenizer=tokenizer,
412
+ template_map_fn=dict(
413
+ type=template_map_fn_factory, template=prompt_template),
414
+ max_length=max_length,
415
+ lazy=True,
416
+ repeats=1,
417
+ special_tokens=special_tokens,
418
+ )
419
+
420
+ train_dataset = dict(
421
+ type=ConcatDataset, datasets=[
422
+ # sem seg
423
+ # semantic_seg_ade20k_dataset,
424
+ # ref seg
425
+ refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
426
+ refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
427
+ refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
428
+ refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
429
+ # image qa
430
+ llava_vqa_dataset,
431
+ # video res
432
+ video_mevis_dataset, video_revos_dataset, video_refytvos_dataset,
433
+ # video chat
434
+ video_qa_dataset,
435
+ # sam2 pseudo-labeled data
436
+ video_sam2_dataset,
437
+ # gcg data
438
+ glamm_psg_dataset,
439
+ glamm_grandf_dataset,
440
+ glamm_flickr_dataset,
441
+ glamm_refcocog_dataset,
442
+ # visual prompt
443
+ image_osprey_dataset, image_osprey_description_dataset,
444
+ image_osprey_part_dataset, image_osprey_short_dataset,
445
+ image_osprey_positive_neg_dataset,
446
+ ]
447
+ )
448
+ train_dataloader = dict(
449
+ batch_size=batch_size,
450
+ num_workers=dataloader_num_workers,
451
+ dataset=train_dataset,
452
+ sampler=dict(
453
+ type=LengthGroupedSampler,
454
+ length_property='modality_length',
455
+ per_device_batch_size=batch_size * accumulative_counts),
456
+ collate_fn=dict(type=video_lisa_collate_fn)
457
+ )
458
+
459
+ #######################################################################
460
+ # PART 4 Scheduler & Optimizer #
461
+ #######################################################################
462
+ # optimizer
463
+ optim_wrapper = dict(
464
+ type=AmpOptimWrapper,
465
+ optimizer=dict(
466
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
467
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
468
+ accumulative_counts=accumulative_counts,
469
+ loss_scale='dynamic',
470
+ dtype='bfloat16'
471
+ )
472
+
473
+ # learning policy
474
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
475
+ param_scheduler = [
476
+ dict(
477
+ type=LinearLR,
478
+ start_factor=1e-5,
479
+ by_epoch=True,
480
+ begin=0,
481
+ end=warmup_ratio * max_epochs,
482
+ convert_to_iter_based=True),
483
+ dict(
484
+ type=CosineAnnealingLR,
485
+ eta_min=0.0,
486
+ by_epoch=True,
487
+ begin=warmup_ratio * max_epochs,
488
+ end=max_epochs,
489
+ convert_to_iter_based=True)
490
+ ]
491
+
492
+ # train, val, test setting
493
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
494
+
495
+ #######################################################################
496
+ # PART 5 Runtime #
497
+ #######################################################################
498
+ # Log the dialogue periodically during the training process, optional
499
+ custom_hooks = [
500
+ # dict(type=DatasetInfoHook, tokenizer=tokenizer),
501
+ ]
502
+
503
+ # configure default hooks
504
+ default_hooks = dict(
505
+ # record the time of every iteration.
506
+ timer=dict(type=IterTimerHook),
507
+ # print log every 10 iterations.
508
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
509
+ # enable the parameter scheduler.
510
+ param_scheduler=dict(type=ParamSchedulerHook),
511
+ # save checkpoint per `save_steps`.
512
+ checkpoint=dict(
513
+ type=CheckpointHook,
514
+ save_optimizer=False,
515
+ by_epoch=False,
516
+ interval=save_steps,
517
+ max_keep_ckpts=save_total_limit),
518
+ # set sampler seed in distributed environment.
519
+ sampler_seed=dict(type=DistSamplerSeedHook),
520
+ )
521
+
522
+ # configure environment
523
+ env_cfg = dict(
524
+ # whether to enable cudnn benchmark
525
+ cudnn_benchmark=False,
526
+ # set multi process parameters
527
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
528
+ # set distributed parameters
529
+ dist_cfg=dict(backend='nccl'),
530
+ )
531
+
532
+ # set visualizer
533
+ visualizer = None
534
+
535
+ # set log level
536
+ log_level = 'INFO'
537
+
538
+ # load from which checkpoint
539
+ load_from = None
540
+
541
+ # whether to resume training from the loaded checkpoint
542
+ resume = False
543
+
544
+ # Defaults to use random seed and disable `deterministic`
545
+ randomness = dict(seed=None, deterministic=False)
546
+
547
+ # set log processor
548
+ log_processor = dict(by_epoch=False)
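Two quick sanity checks implied by the numbers in this config (a standalone sketch; the world size of 8 GPUs is an assumption, everything else is read from the settings above): the effective batch size, and an approximation of the warmup-plus-cosine schedule assembled in PART 4.

import math

num_gpus = 8                                        # assumed world size
batch_size, accumulative_counts = 2, 4              # PART 1
lr, start_factor, warmup_ratio = 4e-5, 1e-5, 0.05   # PART 1 / PART 4

effective_batch = batch_size * accumulative_counts * num_gpus
print('effective batch size:', effective_batch)     # 64 with the assumed 8 GPUs

def lr_at(progress):
    """Approximate LR at a given fraction of the single training epoch."""
    if progress < warmup_ratio:  # LinearLR: factor ramps from start_factor up to 1
        return lr * (start_factor + (1 - start_factor) * progress / warmup_ratio)
    # CosineAnnealingLR with eta_min=0 over the remaining part of the epoch
    t = (progress - warmup_ratio) / (1 - warmup_ratio)
    return 0.5 * lr * (1 + math.cos(math.pi * t))

for p in (0.0, 0.05, 0.5, 1.0):
    print(f'progress {p:.2f}: lr ~ {lr_at(p):.2e}')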
projects/llava_sam2/datasets/ChatUniVi_Dataset.py ADDED
@@ -0,0 +1,389 @@
1
+ import logging
2
+ import os
3
+ from typing import Literal
4
+
5
+ import torch
6
+ from datasets import Dataset as HFDataset
7
+ from datasets import DatasetDict, load_from_disk
8
+ from mmengine import print_log
9
+ from PIL import Image
10
+ from torch.utils.data import Dataset
11
+ import numpy as np
12
+
13
+ from xtuner.registry import BUILDER
14
+ from xtuner.dataset.huggingface import build_origin_dataset
15
+ import copy
16
+ from .encode_fn import video_lisa_encode_fn
17
+ import json
18
+ import cv2
19
+ import torchvision.transforms as T
20
+ from torchvision.transforms.functional import InterpolationMode
21
+ from decord import VideoReader, cpu
22
+
23
+
24
+ def _get_rawvideo_dec(video_path, select_frames=5):
25
+
26
+ if os.path.exists(video_path):
27
+ vreader = VideoReader(video_path, ctx=cpu(0))
28
+ elif os.path.exists(video_path.replace('mkv', 'mp4')):
29
+ vreader = VideoReader(video_path.replace('mkv', 'mp4'), ctx=cpu(0))
30
+ else:
31
+ print(video_path)
32
+ raise FileNotFoundError
33
+
34
+ fps = vreader.get_avg_fps()
35
+ f_start = 0
36
+ f_end = len(vreader) - 1
37
+ num_frames = f_end - f_start + 1
38
+ assert num_frames > 0, f'num_frames: {num_frames}, f_start: {f_start}, f_end: {f_end}, fps: {fps}, video_path: {video_path}'
39
+ # T x 3 x H x W
40
+ if num_frames <= select_frames:
41
+ sample_pos = range(f_start, f_end + 1)
42
+ else:
43
+ split_point = np.linspace(0, num_frames, num=select_frames+1, dtype=int)
44
+ sample_pos = [np.random.randint(split_point[i], split_point[i+1]) for i in range(select_frames)]
45
+ patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
46
+ return patch_images
47
+
48
+
49
+ class VideoChatUniViDataset(Dataset):
50
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
51
+ IMAGENET_STD = (0.229, 0.224, 0.225)
52
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
53
+ IMG_START_TOKEN = '<img>'
54
+ IMG_END_TOKEN = '</img>'
55
+
56
+ FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
57
+ FAST_IMG_START_TOKEN = '<fast_img>'
58
+ FAST_IMG_END_TOKEN = '</fast_img>'
59
+
60
+ def __init__(self,
61
+ image_folder,
62
+ json_file,
63
+ extra_image_processor=None,
64
+ tokenizer=None,
65
+ sampled_frames=10,
66
+ offline_processed_text_folder=None,
67
+ template_map_fn=None,
68
+ max_length=2048,
69
+ lazy=True,
70
+ repeats=1,
71
+ special_tokens=None,
72
+ use_fast=False,
73
+ n_fast_images=50,
74
+ fast_pool_size=4,
75
+ arch_type: Literal['intern_vl', 'qwen', 'llava'] = 'intern_vl',
76
+ preprocessor=None,
77
+ ):
78
+ assert lazy is True
79
+ self.tokenizer = BUILDER.build(tokenizer)
80
+ self.sampled_frames = sampled_frames
81
+ assert offline_processed_text_folder or (json_file and tokenizer)
82
+ self.lazy = lazy
83
+
84
+ self.max_length = max_length
85
+
86
+ self.template_map_fn = template_map_fn
87
+ if isinstance(self.template_map_fn, dict) and self.lazy:
88
+ _type = self.template_map_fn['type']
89
+ del self.template_map_fn['type']
90
+ self.template_map_fn = _type(**self.template_map_fn)
91
+
92
+ if offline_processed_text_folder and json_file:
93
+ print_log(
94
+ 'Both `offline_processed_text_folder` and '
95
+ '`json_file` are set, and we load dataset from '
96
+ '`offline_processed_text_folder` '
97
+ f'({offline_processed_text_folder})',
98
+ logger='current',
99
+ level=logging.WARNING)
100
+
101
+ if offline_processed_text_folder is not None:
102
+ raise NotImplementedError
103
+ else:
104
+ json_datas = self.json_file_preprocess(json_file)
105
+ self.json_datas = json_datas
106
+ json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
107
+ if self.lazy:
108
+ self.text_data = build_origin_dataset(json_data, 'train')
109
+ else:
110
+ raise NotImplementedError
111
+
112
+ self.image_folder = image_folder
113
+ if extra_image_processor is not None:
114
+ self.extra_image_processor = BUILDER.build(extra_image_processor)
115
+
116
+ self.arch_type = arch_type
117
+ if self.arch_type == 'qwen':
118
+ self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
119
+ self.IMG_START_TOKEN = '<|vision_start|>'
120
+ self.IMG_END_TOKEN = '<|vision_end|>'
121
+ elif self.arch_type == 'llava':
122
+ self.IMG_CONTEXT_TOKEN = '<image>'
123
+ self.IMG_START_TOKEN = ''
124
+ self.IMG_END_TOKEN = ''
125
+ self.repeats = repeats
126
+
127
+ self._system = ''
128
+
129
+ self.downsample_ratio = 0.5
130
+ if self.arch_type == 'llava':
131
+ self.downsample_ratio = 1
132
+ self.image_size = 448
133
+ if self.arch_type == 'llava':
134
+ self.image_size = 336
135
+ patch_size = 14
136
+ self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
137
+ if self.arch_type == 'qwen':
138
+ self.patch_token = 1
139
+
140
+ if preprocessor is None:
141
+ self.transformer = T.Compose([
142
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
143
+ T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
144
+ T.ToTensor(),
145
+ T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
146
+ ])
147
+ self.preprocessor = None
148
+ else:
149
+ self.transformer = None
150
+ self.preprocessor = BUILDER.build(preprocessor)
151
+
152
+ self.arch_type = arch_type
153
+
154
+ if special_tokens is not None:
155
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
156
+
157
+ self.use_fast = use_fast
158
+ self.n_fast_images = n_fast_images
159
+ self.fast_pool_size = fast_pool_size
160
+
161
+ # for visualization debug
162
+ self.save_folder = './work_dirs/video_debug/'
163
+ self.cur_number = 0
164
+
165
+ print("Video Chat dataset contains {} items.".format(len(self.text_data)))
166
+
167
+ def __len__(self):
168
+ return len(self.text_data) * self.repeats
169
+
170
+ @property
171
+ def modality_length(self):
172
+ length_list = []
173
+ for data_dict in self.text_data:
174
+ cur_len = 10000
175
+ length_list.append(cur_len)
176
+ return length_list
177
+
178
+ def real_len(self):
179
+ return len(self.text_data)
180
+
181
+ def json_file_preprocess(self, json_file):
182
+ # prepare expression annotation files
183
+ with open(json_file, 'r') as f:
184
+ json_datas = json.load(f)
185
+ return json_datas
186
+
187
+ def dataset_map_fn(self, data_dict, select_k=5):
188
+ assert 'video' in data_dict
189
+ # video
190
+ video_file = data_dict['video']
191
+ video_file = os.path.join(self.image_folder, video_file)
192
+ images = _get_rawvideo_dec(video_file, select_frames=select_k)
193
+ if self.use_fast:
194
+ fast_images = _get_rawvideo_dec(video_file, select_frames=self.n_fast_images)
195
+ else:
196
+ fast_images = None
197
+
198
+ conversation = data_dict['conversations']
199
+
200
+ # prepare text
201
+ if self.use_fast:
202
+ text_dict = self.prepare_text(
203
+ select_k, conversation, num_image_tokens=self.patch_token,
204
+ n_fast_images=len(fast_images),
205
+ )
206
+ else:
207
+ text_dict = self.prepare_text(
208
+ select_k, conversation, num_image_tokens=self.patch_token,
209
+ )
210
+
211
+
212
+ ret = {'images': images, 'conversation': text_dict['conversation'], 'fast_images': fast_images}
213
+ return ret
214
+
215
+ def prepare_text(self, n_frames, conversation, num_image_tokens=256, n_fast_images=0):
216
+
217
+ if self.use_fast:
218
+ fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
219
+ f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
220
+ f'{self.FAST_IMG_END_TOKEN}' + '\n'
221
+ else:
222
+ fast_frame_token_str = ''
223
+
224
+ frame_token_str = f'{self.IMG_START_TOKEN}' \
225
+ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
226
+ f'{self.IMG_END_TOKEN}'
227
+
228
+ questions = []
229
+ answers = []
230
+
231
+ for conv in conversation:
232
+ if conv['from'] == 'human':
233
+ questions.append(conv['value'].replace('<image>', ''))
234
+ else:
235
+ answers.append(conv['value'])
236
+ assert len(questions) == len(answers)
237
+
238
+ qa_list = []
239
+ for i, (question, answer) in enumerate(zip(questions, answers)):
240
+ if i == 0:
241
+ frame_tokens = frame_token_str + '\n'
242
+ # frame_tokens = '=' + ' '
243
+ frame_tokens = frame_tokens * n_frames
244
+ frame_tokens = frame_tokens.strip()
245
+ frame_tokens = fast_frame_token_str + frame_tokens
246
+ qa_list.append(
247
+ {'from': 'human', 'value': frame_tokens + question}
248
+ )
249
+ else:
250
+ qa_list.append(
251
+ {'from': 'human', 'value': question}
252
+ )
253
+ qa_list.append(
254
+ {'from': 'gpt', 'value': answer}
255
+ )
256
+
257
+ input = ''
258
+ conversation = []
259
+ for msg in qa_list:
260
+ if msg['from'] == 'human':
261
+ input += msg['value']
262
+ elif msg['from'] == 'gpt':
263
+ conversation.append({'input': input, 'output': msg['value']})
264
+ input = ''
265
+ else:
266
+ raise NotImplementedError
267
+
268
+ # add system information
269
+ conversation[0].update({'system': self._system})
270
+ return {'conversation': conversation}
271
+
272
+ def __getitem__(self, index):
273
+ index = index % self.real_len()
274
+ selected_data_dict = copy.deepcopy(self.text_data[index])
275
+ data_dict = self.dataset_map_fn(selected_data_dict, select_k=self.sampled_frames)
276
+
277
+
278
+ assert 'images' in data_dict.keys()
279
+ if self.use_fast:
280
+ assert 'fast_images' in data_dict.keys()
281
+ pixel_values = []
282
+ num_video_tokens = None
283
+ num_frame_tokens = None
284
+ if data_dict.get('images', None) is not None:
285
+ frames_files = data_dict['images']
286
+ for frame_image in frames_files:
287
+ frame_image = frame_image.convert('RGB')
288
+ ori_width, ori_height = frame_image.size
289
+
290
+ if self.preprocessor is not None:
291
+ pass
292
+ else:
293
+ frame_image = self.transformer(frame_image)
294
+ pixel_values.append(frame_image)
295
+
296
+ if self.preprocessor is not None:
297
+ if self.arch_type == 'qwen':
298
+ _data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
299
+ _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
300
+ _data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
301
+ num_frame_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
302
+ num_frames = _data_dict['image_grid_thw'].shape[0]
303
+ num_video_tokens = num_frame_tokens * num_frames
304
+ elif self.arch_type == 'llava':
305
+ _data_dict = self.preprocessor(pixel_values, do_resize=True,
306
+ size=(self.image_size, self.image_size))
307
+ _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
308
+ _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
309
+ else:
310
+ raise NotImplementedError
311
+ data_dict.update(_data_dict)
312
+ else:
313
+ pixel_values = torch.stack(pixel_values, dim=0) # (n_f, 3, h, w)
314
+ data_dict['pixel_values'] = pixel_values
315
+ else:
316
+ data_dict['pixel_values'] = torch.zeros(0, 3, self.image_size, self.image_size)
317
+ data_dict['masks'] = None
318
+
319
+ if num_video_tokens is not None:
320
+ assert self.patch_token == 1
321
+ input_str = data_dict['conversation'][0]['input']
322
+ input_str = input_str.replace(self.IMG_CONTEXT_TOKEN, self.IMG_CONTEXT_TOKEN * num_frame_tokens)
323
+ assert input_str.count(self.IMG_CONTEXT_TOKEN) == num_video_tokens
324
+ data_dict['conversation'][0]['input'] = input_str
325
+
326
+ result = self.template_map_fn(data_dict)
327
+ data_dict.update(result)
328
+ result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
329
+ data_dict.update(result)
330
+
331
+ # for fast branch
332
+ if self.use_fast:
333
+ fast_pixel_values = []
334
+ frames_files = data_dict['fast_images']
335
+ for frame_image in frames_files:
336
+ frame_image = frame_image.convert('RGB')
337
+ ori_width, ori_height = frame_image.size
338
+
339
+ frame_image = self.transformer(frame_image)
340
+ fast_pixel_values.append(frame_image)
341
+
342
+ fast_pixel_values = torch.stack(fast_pixel_values, dim=0) # (n_f, 3, h, w)
343
+ data_dict['fast_pixel_values'] = fast_pixel_values
344
+
345
+
346
+ # # for debug
347
+ # self.visualization_debug(data_dict)
348
+ # if self.cur_number < 10:
349
+ # return self[random.randint(0, len(self))]
350
+
351
+ data_dict['type'] = 'video'
352
+ return data_dict
353
+
354
+ def visualization_debug(self, data_dict):
355
+ save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
356
+ if not os.path.exists(save_folder):
357
+ os.mkdir(save_folder)
358
+ self.cur_number += 1
359
+
360
+ # images
361
+
362
+ show_images = []
363
+
364
+ pixel_values = data_dict['pixel_values']
365
+ save_folder_image = os.path.join(save_folder, 'image')
366
+ if not os.path.exists(save_folder_image):
367
+ os.mkdir(save_folder_image)
368
+ for i_image, image_pixel_value in enumerate(pixel_values):
369
+ # print(image_pixel_value.shape)
370
+ image_pixel_value[0] = image_pixel_value[0] * 0.2686
371
+ image_pixel_value[1] = image_pixel_value[1] * 0.2613
372
+ image_pixel_value[2] = image_pixel_value[2] * 0.2757
373
+ image_pixel_value[0] = image_pixel_value[0] + 0.4814
374
+ image_pixel_value[1] = image_pixel_value[1] + 0.4578
375
+ image_pixel_value[2] = image_pixel_value[2] + 0.4082
376
+ image_pixel_value = image_pixel_value * 255
377
+ image_pixel_value = image_pixel_value.permute(1, 2, 0)
378
+ image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
379
+ # print(os.path.join(save_folder_image, '{}.jpg'.format(i_image)))
380
+ # print(image_pixel_value.shape)
381
+ show_images.append(image_pixel_value)
382
+ cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)
383
+
384
+ # text
385
+ input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
386
+ with open(os.path.join(save_folder, 'text.json'), 'w') as f:
387
+ json.dump([input_text], f)
388
+
389
+ return
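The frame sampling in `_get_rawvideo_dec` above is the one non-obvious step of this dataset: when the clip has more frames than requested, it is split into `select_frames` equal chunks and one random index is drawn from each chunk, so the selection stays roughly uniform over the whole video while remaining stochastic. A standalone sketch of just that rule (the clip length of 137 frames is made up):

import numpy as np

num_frames, select_frames = 137, 5   # hypothetical clip length, frames to keep
split_point = np.linspace(0, num_frames, num=select_frames + 1, dtype=int)
sample_pos = [np.random.randint(split_point[i], split_point[i + 1])
              for i in range(select_frames)]

print(split_point)   # [  0  27  54  82 109 137] -> five roughly equal chunks
print(sample_pos)    # one random frame index per chunk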
projects/llava_sam2/datasets/GCG_Dataset.py ADDED
@@ -0,0 +1,375 @@
1
+ import json
2
+ import os
3
+
4
+ import torch
5
+ from datasets import Dataset as HFDataset
6
+ from datasets import DatasetDict, load_from_disk
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+ from pycocotools import mask
10
+ import numpy as np
11
+ import copy
12
+
13
+ from xtuner.registry import BUILDER
14
+ from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
15
+ import torchvision.transforms as T
16
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN
17
+ from torchvision.transforms.functional import InterpolationMode
18
+ from .encode_fn import video_lisa_encode_fn
19
+ from .utils import dynamic_preprocess
20
+
21
+ from .gcg_process import glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn, glamm_refcocog_map_fn
22
+
23
+ class GCGDataset(Dataset):
24
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
25
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
26
+ IMG_START_TOKEN = '<img>'
27
+ IMG_END_TOKEN = '</img>'
28
+
29
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
30
+ IMAGENET_STD = (0.229, 0.224, 0.225)
31
+ def __init__(self,
32
+ image_folder,
33
+ data_path=None,
34
+ tokenizer=None,
35
+ max_length=8196,
36
+ special_tokens=None,
37
+ template_map_fn=None,
38
+ extra_image_processor=None,
39
+ lazy=True,
40
+ repeats=1,
41
+ single_image_mode=False,
42
+ ):
43
+ super().__init__()
44
+ assert lazy
45
+ self.lazy = lazy
46
+ self.max_length = max_length
47
+
48
+ json_data = self.json_file_preprocess(data_path)
49
+ json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
50
+ self.text_data = build_origin_dataset(json_data, 'train')
51
+
52
+ self.image_folder = image_folder
53
+
54
+ self.tokenizer = BUILDER.build(tokenizer)
55
+ if special_tokens is not None:
56
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
57
+
58
+ self.template_map_fn = template_map_fn
59
+ if isinstance(self.template_map_fn, dict) and self.lazy:
60
+ _type = self.template_map_fn['type']
61
+ del self.template_map_fn['type']
62
+ self.template_map_fn = _type(**self.template_map_fn)
63
+
64
+ if extra_image_processor is not None:
65
+ self.extra_image_processor = BUILDER.build(extra_image_processor)
66
+
67
+ self.repeats = repeats
68
+
69
+ self._system = ''
70
+
71
+ self.min_dynamic_patch = 1
72
+ self.max_dynamic_patch = 12
73
+ self.downsample_ratio = 0.5
74
+ self.image_size = 448
75
+ self.use_thumbnail = True
76
+ patch_size = 14
77
+ self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
78
+
79
+ self.transformer = T.Compose([
80
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
81
+ T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
82
+ T.ToTensor(),
83
+ T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
84
+ ])
85
+
86
+ if special_tokens is not None:
87
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
88
+
89
+ self.single_image_mode = single_image_mode
90
+
91
+ def json_file_preprocess(self, data_path):
92
+ with open(data_path, 'r') as f:
93
+ json_data = json.load(f)
94
+ return json_data
95
+
96
+ @property
97
+ def modality_length(self):
98
+ length_list = []
99
+ for data_dict in self.text_data:
100
+ if self.lazy:
101
+ cur_len = 100
102
+ else:
103
+ cur_len = len(data_dict['input_ids'])
104
+ if data_dict.get('image', None) is None:
105
+ cur_len = -cur_len
106
+ length_list.append(cur_len)
107
+ return length_list * self.repeats
108
+
109
+ def __len__(self):
110
+ return len(self.text_data) * self.repeats
111
+
112
+ def real_len(self):
113
+ return len(self.text_data)
114
+
115
+ def decode_mask(self, object_masks, ori_height, ori_width):
116
+ binary_masks = []
117
+ for object_mask in object_masks:
118
+ binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
119
+ for seg in object_mask:
120
+ rles = mask.frPyObjects([seg], ori_height, ori_width)
121
+ m = mask.decode(rles)
122
+ m = m.astype(np.uint8)
123
+ binary_mask += m.squeeze()
124
+
125
+ binary_masks.append(binary_mask)
126
+ if len(binary_masks) == 0:
127
+ return None
128
+ masks = np.stack(binary_masks, axis=0)
129
+ masks = torch.from_numpy(masks)
130
+ return masks
131
+
132
+ def dataset_map_fn(self, data_dict):
133
+ data_dict = glamm_refcocog_map_fn(data_dict)
134
+ return data_dict
135
+
136
+ def replace_image_str(self, data_dict, image_str):
137
+ data_dict['conversation'][0]['input'] = \
138
+ data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
139
+ return data_dict
140
+
141
+ def __getitem__(self, index):
142
+
143
+ index = index % self.real_len()
144
+ data_dict = copy.deepcopy(self.text_data[index])
145
+
146
+ # parse datasets
147
+ result = self.dataset_map_fn(data_dict)
148
+ data_dict.update(result)
149
+
150
+ # process image
151
+ image_file = data_dict['image']
152
+ image = Image.open(os.path.join(self.image_folder,
153
+ image_file)).convert('RGB')
154
+ ori_width, ori_height = image.size
155
+ if hasattr(self, 'extra_image_processor'):
156
+ g_image = np.array(image) # for grounding
157
+ g_image = self.extra_image_processor.apply_image(g_image)
158
+ g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
159
+ data_dict['g_pixel_values'] = g_pixel_values
160
+
161
+ if self.single_image_mode:
162
+ images = [image]
163
+ else:
164
+ images = dynamic_preprocess(image, self.min_dynamic_patch,
165
+ self.max_dynamic_patch,
166
+ self.image_size, self.use_thumbnail)
167
+ pixel_values = [self.transformer(image) for image in images]
168
+ pixel_values = torch.stack(pixel_values)
169
+ data_dict['pixel_values'] = pixel_values
170
+
171
+ num_image_tokens = pixel_values.shape[0] * self.patch_token
172
+ image_token_str = f'{self.IMG_START_TOKEN}' \
173
+ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
174
+ f'{self.IMG_END_TOKEN}'
175
+
176
+ data_dict = self.replace_image_str(data_dict, image_token_str)
177
+
178
+ result = self.template_map_fn(data_dict)
179
+ data_dict.update(result)
180
+ result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
181
+ with_image_token=True)
182
+ data_dict.update(result)
183
+ # process mask
184
+ data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width)
185
+
186
+ if data_dict['masks'] is None:
187
+ return self.__getitem__(0)
188
+
189
+ return data_dict
190
+
191
+ class RefCOCOgGCGDataset(GCGDataset):
192
+ def __init__(self,
193
+ image_folder,
194
+ data_path=None,
195
+ tokenizer=None,
196
+ max_length=8196,
197
+ special_tokens=None,
198
+ template_map_fn=None,
199
+ extra_image_processor=None,
200
+ lazy=True,
201
+ repeats=1,
202
+ single_image_mode=False,
203
+ ):
204
+ super().__init__(
205
+ image_folder=image_folder,
206
+ data_path=data_path,
207
+ tokenizer=tokenizer,
208
+ max_length=max_length,
209
+ special_tokens=special_tokens,
210
+ template_map_fn=template_map_fn,
211
+ extra_image_processor=extra_image_processor,
212
+ lazy=lazy,
213
+ repeats=repeats,
214
+ single_image_mode=single_image_mode,
215
+ )
216
+
217
+ def json_file_preprocess(self, data_path):
218
+ json_data = json.load(open(data_path))
219
+
220
+ # convert {id: dict} to dict(..., id=xx)
221
+ for idx in range(len(json_data)):
222
+ id = list(json_data[idx].keys())[0]
223
+ json_data[idx] = json_data[idx][id]
224
+ json_data[idx].update({'id': id})
225
+ return json_data
226
+
227
+ class GranDfGCGDataset(GCGDataset):
228
+ def __init__(self,
229
+ image_folder,
230
+ data_path=None,
231
+ tokenizer=None,
232
+ max_length=8196,
233
+ special_tokens=None,
234
+ template_map_fn=None,
235
+ extra_image_processor=None,
236
+ lazy=True,
237
+ repeats=1,
238
+ single_image_mode=False,
239
+ ):
240
+ super().__init__(
241
+ image_folder=image_folder,
242
+ data_path=data_path,
243
+ tokenizer=tokenizer,
244
+ max_length=max_length,
245
+ special_tokens=special_tokens,
246
+ template_map_fn=template_map_fn,
247
+ extra_image_processor=extra_image_processor,
248
+ lazy=lazy,
249
+ repeats=repeats,
250
+ single_image_mode=single_image_mode,
251
+ )
252
+
253
+ def dataset_map_fn(self, data_dict):
254
+ data_dict = glamm_granf_map_fn(data_dict)
255
+ return data_dict
256
+
257
+ def decode_mask(self, object_masks, ori_height, ori_width):
258
+ binary_masks = []
259
+ for object_mask in object_masks:
260
+ binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
261
+
262
+ for rle in object_mask:
263
+ m = mask.decode(rle).astype(np.uint8)
264
+ binary_mask += m.squeeze()
265
+
266
+ binary_masks.append(binary_mask)
267
+ if len(binary_masks) == 0:
268
+ return None
269
+ masks = np.stack(binary_masks, axis=0)
270
+ masks = torch.from_numpy(masks)
271
+ return masks
272
+
273
+ class OpenPsgGCGDataset(GranDfGCGDataset):
274
+ def __init__(self,
275
+ image_folder,
276
+ data_path=None,
277
+ tokenizer=None,
278
+ max_length=8196,
279
+ special_tokens=None,
280
+ template_map_fn=None,
281
+ extra_image_processor=None,
282
+ lazy=True,
283
+ repeats=1,
284
+ single_image_mode=False,
285
+ ):
286
+ super().__init__(
287
+ image_folder=image_folder,
288
+ data_path=data_path,
289
+ tokenizer=tokenizer,
290
+ max_length=max_length,
291
+ special_tokens=special_tokens,
292
+ template_map_fn=template_map_fn,
293
+ extra_image_processor=extra_image_processor,
294
+ lazy=lazy,
295
+ repeats=repeats,
296
+ single_image_mode=single_image_mode,
297
+ )
298
+ def dataset_map_fn(self, data_dict):
299
+ data_dict = glamm_openpsg_map_fn(data_dict)
300
+ return data_dict
301
+
302
+
303
+ class FlickrGCGDataset(GCGDataset):
304
+ def __init__(self,
305
+ image_folder,
306
+ data_path=None,
307
+ tokenizer=None,
308
+ max_length=8196,
309
+ special_tokens=None,
310
+ template_map_fn=None,
311
+ extra_image_processor=None,
312
+ lazy=True,
313
+ repeats=1,
314
+ single_image_mode=False,
315
+ ):
316
+ super().__init__(
317
+ image_folder=image_folder,
318
+ data_path=data_path,
319
+ tokenizer=tokenizer,
320
+ max_length=max_length,
321
+ special_tokens=special_tokens,
322
+ template_map_fn=template_map_fn,
323
+ extra_image_processor=extra_image_processor,
324
+ lazy=lazy,
325
+ repeats=repeats,
326
+ single_image_mode=single_image_mode,
327
+ )
328
+
329
+ def dataset_map_fn(self, data_dict):
330
+ data_dict = glamm_flickr_map_fn(data_dict)
331
+ return data_dict
332
+
333
+ def json_file_preprocess(self, data_path):
334
+ def filter_images(data_infos, min_size):
335
+ return [i for i, info in enumerate(data_infos) if min(info['width'], info['height']) >= min_size]
336
+
337
+ # build per-image info dicts from the COCO-format Flickr annotation file
338
+ from pycocotools.coco import COCO
339
+ self.coco = COCO(data_path)
340
+ self.image_ids = self.coco.getImgIds()
341
+ data_infos = []
342
+ total_ann_ids = []
343
+ removed_img_count = 0
344
+ for img_id in self.image_ids:
345
+ info = self.coco.loadImgs([img_id])[0]
346
+ if len(info['caption'].split(' ')) < 3:
347
+ removed_img_count += 1
348
+ continue
349
+ info['filename'] = info['file_name'].split('_')[-1]
350
+ info['height'] = int(info['height'])
351
+ info['width'] = int(info['width'])
352
+ data_infos.append(info)
353
+ ann_ids = self.coco.getAnnIds(imgIds=[img_id])
354
+ total_ann_ids.extend(ann_ids)
355
+ assert len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{data_path}'!"
356
+ print(f'Removed {removed_img_count} images.')
357
+ data_infos = [data_infos[i] for i in filter_images(data_infos, min_size=32)]
358
+
359
+ # obtain_annotations
360
+ for data_info in data_infos:
361
+ ann_ids = self.coco.getAnnIds(imgIds=data_info['id'])
362
+ ann_info = self.coco.loadAnns(ann_ids)
363
+ data_info.update({'ann_info': ann_info})
364
+ return data_infos
365
+
366
+ def decode_mask(self, object_masks, ori_height, ori_width):
367
+ binary_masks = []
368
+ for object_mask in object_masks:
369
+ binary_mask = mask.decode(object_mask).astype(np.uint8)
370
+ binary_masks.append(binary_mask)
371
+ if len(binary_masks) == 0:
372
+ return None
373
+ masks = np.stack(binary_masks, axis=0)
374
+ masks = torch.from_numpy(masks)
375
+ return masks
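Note: a rough sketch (counts are illustrative, not from this commit) of the image-token accounting shared by the GCG datasets above. With image_size=448, patch_size=14 and downsample_ratio=0.5, each 448x448 crop produced by dynamic_preprocess contributes 256 <IMG_CONTEXT> tokens.

# Sketch only; mirrors the constants set in GCGDataset.__init__.
image_size, patch_size, downsample_ratio = 448, 14, 0.5
patch_token = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)  # 32 * 32 * 0.25 = 256
num_crops = 7                               # e.g. 6 dynamic tiles + 1 thumbnail (illustrative)
num_image_tokens = num_crops * patch_token  # 1792 tokens spliced into the conversation
image_token_str = '<img>' + '<IMG_CONTEXT>' * num_image_tokens + '</img>'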
projects/llava_sam2/datasets/Grand_Dataset.py ADDED
@@ -0,0 +1,241 @@
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ import torch
6
+ from datasets import Dataset as HFDataset
7
+ from datasets import DatasetDict, load_from_disk
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset
10
+ from pycocotools import mask
11
+ import numpy as np
12
+ import copy
13
+
14
+ from xtuner.registry import BUILDER
15
+ from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
16
+ import torchvision.transforms as T
17
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN
18
+ from torchvision.transforms.functional import InterpolationMode
19
+ from .encode_fn import video_lisa_encode_fn
20
+ from .utils import dynamic_preprocess
21
+
22
+ from .grand_process import glamm_grand_map_fn
23
+
24
+ class GranDDataset(Dataset):
25
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
26
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
27
+ IMG_START_TOKEN = '<img>'
28
+ IMG_END_TOKEN = '</img>'
29
+
30
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
31
+ IMAGENET_STD = (0.229, 0.224, 0.225)
32
+ def __init__(self,
33
+ image_folder,
34
+ json_folder=None,
35
+ tokenizer=None,
36
+ max_length=8196,
37
+ special_tokens=None,
38
+ template_map_fn=None,
39
+ extra_image_processor=None,
40
+ lazy=True,
41
+ repeats=1,
42
+ single_image_mode=False,
43
+ image_list_save_path='./work_dirs/grand_image.json',
44
+ json_list_save_path='./work_dirs/grand_jsons.json',
45
+ ):
46
+ super().__init__()
47
+ assert lazy
48
+ self.lazy = lazy
49
+ self.max_length = max_length
50
+
51
+ self.image_list_save_path = image_list_save_path
52
+ self.json_list_save_path = json_list_save_path
53
+
54
+ json_files, image_path_dict = self.json_file_preprocess(image_folder, json_folder)
55
+ self.json_data = json_files
56
+ self.image_path_dict = image_path_dict
57
+
58
+ self.image_folder = image_folder
59
+
60
+ self.tokenizer = BUILDER.build(tokenizer)
61
+ if special_tokens is not None:
62
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
63
+
64
+ self.template_map_fn = template_map_fn
65
+ if isinstance(self.template_map_fn, dict) and self.lazy:
66
+ _type = self.template_map_fn['type']
67
+ del self.template_map_fn['type']
68
+ self.template_map_fn = _type(**self.template_map_fn)
69
+
70
+ if extra_image_processor is not None:
71
+ self.extra_image_processor = BUILDER.build(extra_image_processor)
72
+
73
+ self.repeats = repeats
74
+
75
+ self._system = ''
76
+
77
+ self.min_dynamic_patch = 1
78
+ self.max_dynamic_patch = 12
79
+ self.downsample_ratio = 0.5
80
+ self.image_size = 448
81
+ self.use_thumbnail = True
82
+ patch_size = 14
83
+ self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
84
+
85
+ self.transformer = T.Compose([
86
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
87
+ T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
88
+ T.ToTensor(),
89
+ T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
90
+ ])
91
+
92
+ if special_tokens is not None:
93
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
94
+
95
+ self.single_image_mode = single_image_mode
96
+
97
+ def json_file_preprocess(self, image_folder, json_folder):
98
+
99
+ # list jsons
100
+ print("Processing GRAND json files !!!")
101
+ if os.path.exists(self.json_list_save_path):
102
+ with open(self.json_list_save_path, 'r') as f:
103
+ json_files = json.load(f)
104
+ else:
105
+ json_files = os.listdir(json_folder)
106
+ _json_files = []
107
+ for _file in json_files:
108
+ if '.json' in _file:
109
+ _json_files.append(os.path.join(json_folder, _file))
110
+ json_files = _json_files
111
+ with open(self.json_list_save_path, 'w') as f:
112
+ json.dump(json_files, f)
113
+ print(f"Finished, {len(json_files)} json files !")
114
+
115
+ # list images
116
+ print("Processing GRAND image files !!!")
117
+ if os.path.exists(self.image_list_save_path):
118
+ with open(self.image_list_save_path, 'r') as f:
119
+ image_path_dict = json.load(f)
120
+ else:
121
+ sub_folders = os.listdir(image_folder)
122
+ _sub_folders = []
123
+ for folder_name in sub_folders:
124
+ if 'sa_00' in folder_name:
125
+ _sub_folders.append(folder_name)
126
+ sub_folders = _sub_folders
127
+ sub_folders = [os.path.join(image_folder, folder_name) for folder_name in sub_folders]
128
+
129
+ image_path_dict = {}
130
+ for sub_folder in sub_folders:
131
+ files = os.listdir(sub_folder)
132
+ for _file in files:
133
+ if '.jpg' in _file:
134
+ image_path_dict[_file] = os.path.join(sub_folder, _file)
135
+
136
+ with open(self.image_list_save_path, 'w') as f:
137
+ json.dump(image_path_dict, f)
138
+ print(f"Finished, {len(image_path_dict)} image files !")
139
+
140
+ return json_files, image_path_dict
141
+
142
+ @property
143
+ def modality_length(self):
144
+ length_list = [10000] * len(self.json_data)
145
+ return length_list * self.repeats
146
+
147
+ def __len__(self):
148
+ return len(self.json_data) * self.repeats
149
+
150
+ def real_len(self):
151
+ return len(self.json_data)
152
+
153
+ def decode_mask(self, object_masks, ori_height, ori_width):
154
+ binary_masks = []
155
+ for object_mask in object_masks:
156
+ binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
157
+ for seg in object_mask:
158
+ m = mask.decode(seg)
159
+ m = m.astype(np.uint8)
160
+ binary_mask += m.squeeze()
161
+
162
+ binary_masks.append(binary_mask)
163
+ if len(binary_masks) == 0:
164
+ return None
165
+ masks = np.stack(binary_masks, axis=0)
166
+ masks = torch.from_numpy(masks)
167
+ return masks
168
+
169
+ def dataset_map_fn(self, data_dict):
170
+ data_dict = glamm_grand_map_fn(data_dict)
171
+ return data_dict
172
+
173
+ def replace_image_str(self, data_dict, image_str):
174
+ data_dict['conversation'][0]['input'] = \
175
+ data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
176
+ return data_dict
177
+
178
+ def __getitem__(self, index):
179
+
180
+ index = index % self.real_len()
181
+ json_file_path = self.json_data[index]
182
+ with open(json_file_path, 'r') as f:
183
+ json_dict = json.load(f)
184
+
185
+ image_name = list(json_dict.keys())[0]
186
+
187
+ if image_name not in self.image_path_dict.keys():
188
+ return self.__getitem__(random.randint(0, len(self.json_data) - 1))
189
+ image_path = self.image_path_dict[image_name]
190
+
191
+ json_dict = json_dict[image_name]
192
+ # parse datasets
193
+ result = self.dataset_map_fn(json_dict)
194
+ json_dict.update(result)
195
+ data_dict = json_dict
196
+
197
+ data_dict['image'] = image_path
198
+
199
+ # process image
200
+ image_file = data_dict['image']
201
+ try:
202
+ image = Image.open(os.path.join(self.image_folder,
203
+ image_file)).convert('RGB')
204
+ except Exception:
205
+ return self.__getitem__(random.randint(0, len(self.json_data) - 1))
206
+ ori_width, ori_height = image.size
207
+ if hasattr(self, 'extra_image_processor'):
208
+ g_image = np.array(image) # for grounding
209
+ g_image = self.extra_image_processor.apply_image(g_image)
210
+ g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
211
+ data_dict['g_pixel_values'] = g_pixel_values
212
+
213
+ if self.single_image_mode:
214
+ images = [image]
215
+ else:
216
+ images = dynamic_preprocess(image, self.min_dynamic_patch,
217
+ self.max_dynamic_patch,
218
+ self.image_size, self.use_thumbnail)
219
+ pixel_values = [self.transformer(image) for image in images]
220
+ pixel_values = torch.stack(pixel_values)
221
+ data_dict['pixel_values'] = pixel_values
222
+
223
+ num_image_tokens = pixel_values.shape[0] * self.patch_token
224
+ image_token_str = f'{self.IMG_START_TOKEN}' \
225
+ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
226
+ f'{self.IMG_END_TOKEN}'
227
+
228
+ data_dict = self.replace_image_str(data_dict, image_token_str)
229
+
230
+ result = self.template_map_fn(data_dict)
231
+ data_dict.update(result)
232
+ result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
233
+ with_image_token=True)
234
+ data_dict.update(result)
235
+ # process mask
236
+ data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width)
237
+
238
+ if data_dict['masks'] is None:
239
+ return self.__getitem__(random.randint(0, len(self.json_data) - 1))
240
+
241
+ return data_dict
projects/llava_sam2/datasets/MeVIS_Dataset.py ADDED
@@ -0,0 +1,5 @@
1
+ from .ReVOS_Dataset import VideoReVOSDataset
2
+
3
+
4
+ class VideoMeVISDataset(VideoReVOSDataset):
5
+ pass
projects/llava_sam2/datasets/Osprey_Dataset.py ADDED
@@ -0,0 +1,463 @@
1
+ import json
2
+ import os
3
+
4
+ import torch
5
+ from datasets import Dataset as HFDataset
6
+ from datasets import DatasetDict, load_from_disk
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+ from pycocotools import mask as maskUtils
10
+ import numpy as np
11
+ import copy
12
+
13
+ from xtuner.registry import BUILDER
14
+ from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
15
+ import torchvision.transforms as T
16
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN
17
+ from torchvision.transforms.functional import InterpolationMode
18
+ from .encode_fn import video_lisa_encode_fn
19
+ from .utils import dynamic_preprocess
20
+
21
+ import random
22
+
23
+ import torch.nn.functional as F
24
+
25
+ class OspreyDataset(Dataset):
26
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
27
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
28
+ IMG_START_TOKEN = '<img>'
29
+ IMG_END_TOKEN = '</img>'
30
+
31
+ LIMIT = ''
32
+
33
+ VP_START_TOKEN = '<vp>'
34
+ VP_END_TOKEN = '</vp>'
35
+
36
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
37
+ IMAGENET_STD = (0.229, 0.224, 0.225)
38
+ def __init__(self,
39
+ image_folder,
40
+ data_path=None,
41
+ tokenizer=None,
42
+ max_length=8196,
43
+ special_tokens=None,
44
+ template_map_fn=None,
45
+ extra_image_processor=None,
46
+ lazy=True,
47
+ repeats=1,
48
+ single_image_mode=False,
49
+ ):
50
+ super().__init__()
51
+ assert lazy
52
+ self.lazy = lazy
53
+ self.max_length = max_length
54
+
55
+ json_data = self.json_file_preprocess(data_path)
56
+ self.text_data = json_data
57
+
58
+ self.image_folder = image_folder
59
+
60
+ self.tokenizer = BUILDER.build(tokenizer)
61
+ if special_tokens is not None:
62
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
63
+
64
+ self.template_map_fn = template_map_fn
65
+ if isinstance(self.template_map_fn, dict) and self.lazy:
66
+ _type = self.template_map_fn['type']
67
+ del self.template_map_fn['type']
68
+ self.template_map_fn = _type(**self.template_map_fn)
69
+
70
+ if extra_image_processor is not None:
71
+ self.extra_image_processor = BUILDER.build(extra_image_processor)
72
+
73
+ self.repeats = repeats
74
+
75
+ self._system = ''
76
+
77
+ self.min_dynamic_patch = 1
78
+ self.max_dynamic_patch = 12
79
+ self.downsample_ratio = 0.5
80
+ self.image_size = 448
81
+ self.use_thumbnail = True
82
+ patch_size = 14
83
+ self.patch_size = patch_size
84
+ self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
85
+
86
+ self.transformer = T.Compose([
87
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
88
+ T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
89
+ T.ToTensor(),
90
+ T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
91
+ ])
92
+
93
+ if special_tokens is not None:
94
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
95
+
96
+ self.single_image_mode = single_image_mode
97
+
98
+ def json_file_preprocess(self, data_path):
99
+ with open(data_path, 'r') as f:
100
+ json_data = json.load(f)
101
+ return json_data
102
+
103
+ @property
104
+ def modality_length(self):
105
+ length_list = []
106
+ for data_dict in self.text_data:
107
+ if self.lazy:
108
+ cur_len = 100
109
+ else:
110
+ cur_len = len(data_dict['input_ids'])
111
+ if data_dict.get('image', None) is None:
112
+ cur_len = -cur_len
113
+ length_list.append(cur_len)
114
+ return length_list * self.repeats
115
+
116
+ def __len__(self):
117
+ return len(self.text_data) * self.repeats
118
+
119
+ def real_len(self):
120
+ return len(self.text_data)
121
+
122
+ def annToMask(self, mask_ann, h, w):
123
+ if isinstance(mask_ann, list):
124
+ rles = maskUtils.frPyObjects(mask_ann, h, w)
125
+ rle = maskUtils.merge(rles)
126
+ elif isinstance(mask_ann['counts'], list):
127
+ # uncompressed RLE
128
+ rle = maskUtils.frPyObjects(mask_ann, h, w)
129
+ else:
130
+ # rle
131
+ rle = mask_ann
132
+ mask = maskUtils.decode(rle)
133
+ return mask
134
+
135
+ def decode_mask(self, object_masks, ori_height, ori_width):
136
+ binary_masks = []
137
+ for object_mask in object_masks:
138
+ binary_mask = self.annToMask(object_mask, ori_height, ori_width)
139
+ binary_masks.append(binary_mask)
140
+ if len(binary_masks) == 0:
141
+ return None
142
+ masks = np.stack(binary_masks, axis=0)
143
+ masks = torch.from_numpy(masks)
144
+ return masks
145
+
146
+ def _process_conversation(self, converations, n_regions, region_pixels):
147
+ start_region_str = '<image> There are {} part regions in the picture: '.format(n_regions)
148
+ for i in range(n_regions):
149
+ start_region_str = start_region_str + \
150
+ f"region{i+1}" + self.VP_START_TOKEN + self.IMG_CONTEXT_TOKEN * region_pixels[i] + self.VP_END_TOKEN
151
+ if i == n_regions - 1:
152
+ start_region_str = start_region_str + '.\n'
153
+ else:
154
+ start_region_str = start_region_str + ', '
155
+
156
+ for i, item in enumerate(converations):
157
+ item['value'] = item['value'].replace('<', '').replace('>', '')
158
+ if item['from'] == 'human':
159
+ item['value'] = item['value'] + self.LIMIT
160
+ # first conv process
161
+ if i == 0:
162
+ assert item['from'] == "human"
163
+ item['value'] = start_region_str + item['value']
164
+
165
+ messages = converations
166
+ input = ''
167
+
168
+ conversation = []
169
+ while messages and messages[0]['from'] == 'gpt':
170
+ # Skip the first one if it is from gpt
171
+ messages = messages[1:]
172
+ for msg in messages:
173
+ if msg['from'] == 'human':
174
+ if DEFAULT_IMAGE_TOKEN in msg['value']:
175
+ msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
176
+ '').strip()
177
+ msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
178
+ msg['value'] = msg['value'].strip()
179
+ input += msg['value']
180
+
181
+ elif msg['from'] == 'gpt':
182
+ conversation.append({'input': input, 'output': msg['value']})
183
+ input = ''
184
+ else:
185
+ raise NotImplementedError
186
+
187
+ return conversation
188
+
189
+ def _get_region_infos(self, masks):
190
+ # masks tensor, (n_obj, h, w)
191
+ masks = F.interpolate(
192
+ masks.unsqueeze(0),
193
+ size=(int(self.image_size // self.patch_size * self.downsample_ratio),
194
+ int(self.image_size // self.patch_size * self.downsample_ratio)),
195
+ mode='nearest').squeeze(0)
196
+ region_pixels = []
197
+ for mask in masks:
198
+ region_pixels.append(mask.bool().to(torch.int64).sum())
199
+ return masks, region_pixels
200
+
201
+ def dataset_map_fn(self, data_dict):
202
+ file_name = data_dict['file_name'] # image file name
203
+ conversations = data_dict['conversations']
204
+ masks = [anno["segmentation"] for anno in data_dict["annotation"]]
205
+ height = data_dict['height']
206
+ width = data_dict['width']
207
+ _ret = {}
208
+
209
+ _ret['image'] = file_name
210
+ _ret['height'] = height
211
+ _ret['width'] = width
212
+
213
+ masks = self.decode_mask(masks, height, width)
214
+ if masks is None:  # bail out before _get_region_infos, which cannot handle a missing mask
215
+ return None
216
+
217
+ masks, region_pixels = self._get_region_infos(masks)
218
+
219
+ conversations = self._process_conversation(conversations, len(masks), region_pixels)
220
+ _ret['conversation'] = conversations
221
+ _ret['prompt_masks'] = masks
222
+ return _ret
223
+
224
+ def replace_image_str(self, data_dict, image_str):
225
+ data_dict['conversation'][0]['input'] = \
226
+ data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
227
+ return data_dict
228
+
229
+ def __getitem__(self, index):
230
+
231
+ index = index % self.real_len()
232
+ data_dict = copy.deepcopy(self.text_data[index])
233
+
234
+ # parse datasets
235
+ result = self.dataset_map_fn(data_dict) # {'image', 'height', 'width', 'conversation', 'masks'}
236
+ if result is None or result['prompt_masks'] is None:
237
+ return self.__getitem__(0)
238
+
239
+ data_dict = result
240
+
241
+ # process image
242
+ image_file = data_dict['image']
243
+ if isinstance(self.image_folder, list):
244
+ for image_folder in self.image_folder:
245
+ image_path = os.path.join(image_folder, image_file)
246
+ if os.path.exists(image_path):
247
+ image = Image.open(image_path).convert('RGB')
248
+ break
249
+ else:
250
+ image = Image.open(os.path.join(self.image_folder,
251
+ image_file)).convert('RGB')
252
+ ori_width, ori_height = image.size
253
+
254
+ if self.single_image_mode:
255
+ images = [image]
256
+ else:
257
+ images = dynamic_preprocess(image, self.min_dynamic_patch,
258
+ self.max_dynamic_patch,
259
+ self.image_size, self.use_thumbnail)
260
+ vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
261
+ data_dict['vp_overall_mask'] = vp_overall_mask
262
+
263
+ pixel_values = [self.transformer(image) for image in images]
264
+ pixel_values = torch.stack(pixel_values)
265
+ data_dict['pixel_values'] = pixel_values
266
+
267
+ num_image_tokens = pixel_values.shape[0] * self.patch_token
268
+ image_token_str = f'{self.IMG_START_TOKEN}' \
269
+ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
270
+ f'{self.IMG_END_TOKEN}'
271
+
272
+ data_dict = self.replace_image_str(data_dict, image_token_str)
273
+
274
+ result = self.template_map_fn(data_dict)
275
+ data_dict.update(result)
276
+ result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
277
+ with_image_token=True)
278
+ data_dict.update(result)
279
+ # process mask
280
+ # data_dict['prompt_masks'] = data_dict['prompt_masks']
281
+
282
+ if data_dict['prompt_masks'] is None:
283
+ return self.__getitem__(0)
284
+
285
+ return data_dict
286
+
287
+
288
+ DETAILED_QUESTIONS = [
289
+ 'Can you provide me with a detailed description of the region in the picture marked by <region>?',
290
+ "I'm curious about the region represented by <region> in the picture. Could you describe it in detail?",
291
+ 'What can you tell me about the region indicated by <region> in the image?',
292
+ "I'd like to know more about the area in the photo labeled <region>. Can you give me a detailed description?",
293
+ 'Could you describe the region shown as <region> in the picture in great detail?',
294
+ 'What details can you give me about the region outlined by <region> in the photo?',
295
+ 'Please provide me with a comprehensive description of the region marked with <region> in the image.',
296
+ 'Can you give me a detailed account of the region labeled as <region> in the picture?',
297
+ "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail?",
298
+ 'What is the region outlined by <region> in the picture like? Could you give me a detailed description?',
299
+ 'Can you provide me with a detailed description of the region in the picture marked by <region>, please?',
300
+ "I'm curious about the region represented by <region> in the picture. Could you describe it in detail, please?",
301
+ 'What can you tell me about the region indicated by <region> in the image, exactly?',
302
+ "I'd like to know more about the area in the photo labeled <region>, please. Can you give me a detailed description?",
303
+ 'Could you describe the region shown as <region> in the picture in great detail, please?',
304
+ 'What details can you give me about the region outlined by <region> in the photo, please?',
305
+ 'Please provide me with a comprehensive description of the region marked with <region> in the image, please.',
306
+ 'Can you give me a detailed account of the region labeled as <region> in the picture, please?',
307
+ "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail, please?",
308
+ 'What is the region outlined by <region> in the picture like, please? Could you give me a detailed description?',
309
+ 'Please describe the region <region> in the image in detail.',
310
+ 'Can you offer a thorough analysis of the region <region> in the image?',
311
+ 'Could you elaborate on the region highlighted by <region> in the picture provided?',
312
+ 'Please share more information about the zone emphasized with <region> in the photo.',
313
+ 'What insights can you give about the area denoted by <region> in the image presented?',
314
+ 'Can you share a comprehensive rundown of the region denoted by <region> in the presented image?',
315
+ "I'd like to know more about the region highlighted by <region> in the picture provided.",
316
+ 'Work through the important details of the area <region> in the image.',
317
+ 'Illustrate the area represented by <region> through a descriptive explanation.',
318
+ 'Examine the region <region> closely and share its details.'
319
+ ]
320
+
321
+ class OspreyDescriptionDataset(OspreyDataset):
322
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
323
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
324
+ IMG_START_TOKEN = '<img>'
325
+ IMG_END_TOKEN = '</img>'
326
+
327
+ VP_START_TOKEN = '<vp>'
328
+ VP_END_TOKEN = '</vp>'
329
+
330
+ LIMIT=''
331
+
332
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
333
+ IMAGENET_STD = (0.229, 0.224, 0.225)
334
+ def __init__(self,
335
+ image_folder,
336
+ data_path=None,
337
+ tokenizer=None,
338
+ max_length=8196,
339
+ special_tokens=None,
340
+ template_map_fn=None,
341
+ extra_image_processor=None,
342
+ lazy=True,
343
+ repeats=1,
344
+ single_image_mode=False,
345
+ ):
346
+ super(OspreyDescriptionDataset, self).__init__(
347
+ image_folder=image_folder,
348
+ data_path=data_path,
349
+ tokenizer=tokenizer,
350
+ max_length=max_length,
351
+ special_tokens=special_tokens,
352
+ template_map_fn=template_map_fn,
353
+ extra_image_processor=extra_image_processor,
354
+ lazy=lazy,
355
+ repeats=repeats,
356
+ single_image_mode=single_image_mode,
357
+ )
358
+
359
+ def dataset_map_fn(self, data_dict):
360
+ file_name = data_dict['file_name'] # image file name
361
+ descriptions = data_dict['description']
362
+ masks = [anno["segmentation"] for anno in data_dict["annotation"]]
363
+ height = data_dict['height']
364
+ width = data_dict['width']
365
+ _ret = {}
366
+
367
+ _ret['image'] = file_name
368
+ _ret['height'] = height
369
+ _ret['width'] = width
370
+
371
+ masks = self.decode_mask(masks, height, width)
372
+ if masks is None:  # bail out before _get_region_infos, which cannot handle a missing mask
373
+ return None
374
+
375
+ masks, region_pixels = self._get_region_infos(masks)
376
+
377
+ conversations = self._process_conversation(descriptions, len(masks), region_pixels)
378
+ _ret['conversation'] = conversations
379
+ _ret['prompt_masks'] = masks
380
+ return _ret
381
+
382
+ def _process_conversation(self, descriptions, n_regions, region_pixels):
383
+ start_region_str = '<image> There are {} part regions in the picture: '.format(n_regions)
384
+ for i in range(n_regions):
385
+ start_region_str = start_region_str + \
386
+ f"region{i+1}" + self.VP_START_TOKEN + self.IMG_CONTEXT_TOKEN * region_pixels[i] + self.VP_END_TOKEN
387
+ if i == n_regions - 1:
388
+ start_region_str = start_region_str + '.\n'
389
+ else:
390
+ start_region_str = start_region_str + ', '
391
+
392
+ converations = []
393
+ for i, item in enumerate(descriptions):
394
+ question = random.choice(DETAILED_QUESTIONS).strip().replace('<region>', f"region{i+1}") + self.LIMIT
395
+ answer = item.replace('<', '').replace('>', '')
396
+ # first conv process
397
+ if i == 0:
398
+ question = start_region_str + question
399
+ converations.append({'from': 'human', 'value': question})
400
+ converations.append({'from': 'gpt', 'value': answer})
401
+
402
+ messages = converations
403
+ input = ''
404
+
405
+ conversation = []
406
+ while messages and messages[0]['from'] == 'gpt':
407
+ # Skip the first one if it is from gpt
408
+ messages = messages[1:]
409
+ for msg in messages:
410
+ if msg['from'] == 'human':
411
+ if DEFAULT_IMAGE_TOKEN in msg['value']:
412
+ msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
413
+ '').strip()
414
+ msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
415
+ msg['value'] = msg['value'].strip()
416
+ input += msg['value']
417
+
418
+ elif msg['from'] == 'gpt':
419
+ conversation.append({'input': input, 'output': msg['value']})
420
+ input = ''
421
+ else:
422
+ raise NotImplementedError
423
+ return conversation
424
+
425
+
426
+ class OspreyShortDescriptionDataset(OspreyDataset):
427
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
428
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
429
+ IMG_START_TOKEN = '<img>'
430
+ IMG_END_TOKEN = '</img>'
431
+
432
+ VP_START_TOKEN = '<vp>'
433
+ VP_END_TOKEN = '</vp>'
434
+
435
+ LIMIT = ' Answer the question using a single word or phrase.'
436
+
437
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
438
+ IMAGENET_STD = (0.229, 0.224, 0.225)
439
+
440
+ def __init__(self,
441
+ image_folder,
442
+ data_path=None,
443
+ tokenizer=None,
444
+ max_length=8196,
445
+ special_tokens=None,
446
+ template_map_fn=None,
447
+ extra_image_processor=None,
448
+ lazy=True,
449
+ repeats=1,
450
+ single_image_mode=False,
451
+ ):
452
+ super(OspreyShortDescriptionDataset, self).__init__(
453
+ image_folder=image_folder,
454
+ data_path=data_path,
455
+ tokenizer=tokenizer,
456
+ max_length=max_length,
457
+ special_tokens=special_tokens,
458
+ template_map_fn=template_map_fn,
459
+ extra_image_processor=extra_image_processor,
460
+ lazy=lazy,
461
+ repeats=repeats,
462
+ single_image_mode=single_image_mode,
463
+ )
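Note: a rough sketch of how OspreyDataset sizes each <vp> ... </vp> span. Prompt masks are resized to the 16x16 vision-feature grid (448 / 14 * 0.5) by _get_region_infos, and every occupied grid cell contributes one <IMG_CONTEXT> token; the toy masks below are made up for illustration.

# Sketch only; follows OspreyDataset._get_region_infos.
import torch
import torch.nn.functional as F

masks = torch.zeros(2, 480, 640)   # two toy regions at the original image resolution
masks[0, :240, :320] = 1           # region 1: top-left quarter of the image
masks[1, :30, :40] = 1             # region 2: a tiny patch in the corner
grid = F.interpolate(masks.unsqueeze(0), size=(16, 16), mode='nearest').squeeze(0)
region_pixels = [int(m.bool().sum()) for m in grid]
print(region_pixels)               # [64, 1] -> tokens reserved per region in the prompt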
projects/llava_sam2/datasets/ReSAM2_Dataset.py ADDED
@@ -0,0 +1,489 @@
1
+ import logging
2
+ import os
3
+ import torch
4
+ from datasets import Dataset as HFDataset
5
+ from datasets import DatasetDict, load_from_disk
6
+ from mmengine import print_log
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+ import numpy as np
10
+
11
+ from xtuner.registry import BUILDER
12
+ from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
13
+ import copy
14
+ from .encode_fn import video_lisa_encode_fn
15
+ import json
16
+ import random
17
+ import pycocotools.mask as maskUtils
18
+ import cv2
19
+ import torchvision.transforms as T
20
+ from torchvision.transforms.functional import InterpolationMode
21
+
22
+ SEG_QUESTIONS = [
23
+ "Please segment the object according to the description: {class_name}",
24
+ ]
25
+
26
+ SEG_QUESTIONS_SHORT = [
27
+ "Can you segment the {class_name} in this image?",
28
+ "Please segment {class_name} in this image.",
29
+ "What is {class_name} in this image? Please respond with segmentation mask.",
30
+ "What is {class_name} in this image? Please output segmentation mask.",
31
+
32
+ "Can you segment the {class_name} in this image",
33
+ "Please segment {class_name} in this image",
34
+ "What is {class_name} in this image? Please respond with segmentation mask",
35
+ "What is {class_name} in this image? Please output segmentation mask",
36
+
37
+ "Could you provide a segmentation mask for the {class_name} in this image?",
38
+ "Please identify and segment the {class_name} in this image.",
39
+ "Where is the {class_name} in this picture? Please respond with a segmentation mask.",
40
+ "Can you highlight the {class_name} in this image with a segmentation mask?",
41
+
42
+ "Could you provide a segmentation mask for the {class_name} in this image",
43
+ "Please identify and segment the {class_name} in this image",
44
+ "Where is the {class_name} in this picture? Please respond with a segmentation mask",
45
+ "Can you highlight the {class_name} in this image with a segmentation mask",
46
+ ]
47
+
48
+ ANSWER_LIST = [
49
+ "It is [SEG].",
50
+ "Sure, [SEG].",
51
+ "Sure, it is [SEG].",
52
+ "Sure, the segmentation result is [SEG].",
53
+ "[SEG].",
54
+ ]
55
+
56
+ class VideoSAM2Dataset(Dataset):
57
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
58
+ IMAGENET_STD = (0.229, 0.224, 0.225)
59
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
60
+ IMG_START_TOKEN = '<img>'
61
+ IMG_END_TOKEN = '</img>'
62
+
63
+ FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
64
+ FAST_IMG_START_TOKEN = '<fast_img>'
65
+ FAST_IMG_END_TOKEN = '</fast_img>'
66
+
67
+ def __init__(self,
68
+ sam2_folder,
69
+ expression_file,
70
+ extra_image_processor=None,
71
+ tokenizer=None,
72
+ select_number=5,
73
+ sampled_frames=5,
74
+ offline_processed_text_folder=None,
75
+ template_map_fn=None,
76
+ max_length=8196,
77
+ lazy=True,
78
+ repeats=1,
79
+ special_tokens=None,
80
+ use_fast=False,
81
+ n_fast_images=50,
82
+ fast_pool_size=4,
83
+ mode='long',
84
+ frame_contiguous_sample=False,
85
+ ):
86
+ assert mode in ['long', 'long_short', 'short']
87
+ self.mode = mode
88
+ self.cur_mode = mode
89
+ assert lazy is True
90
+ self.tokenizer = BUILDER.build(tokenizer)
91
+ self.select_number = select_number
92
+ self.sampled_frames = sampled_frames
93
+ assert offline_processed_text_folder or (expression_file and tokenizer)
94
+ self.lazy = lazy
95
+
96
+ self.max_length = max_length
97
+
98
+ self.template_map_fn = template_map_fn
99
+ if isinstance(self.template_map_fn, dict) and self.lazy:
100
+ _type = self.template_map_fn['type']
101
+ del self.template_map_fn['type']
102
+ self.template_map_fn = _type(**self.template_map_fn)
103
+
104
+ if offline_processed_text_folder and expression_file:
105
+ print_log(
106
+ 'Both `offline_processed_text_folder` and '
107
+ '`data_path` are set, and we load dataset from'
108
+ '`offline_processed_text_folder` '
109
+ f'({offline_processed_text_folder})',
110
+ logger='current',
111
+ level=logging.WARNING)
112
+
113
+ if offline_processed_text_folder is not None:
114
+ raise NotImplementedError
115
+ else:
116
+ video_ids, anno_dict = self.json_file_preprocess(expression_file)
117
+ if self.lazy:
118
+ self.video_ids = video_ids
119
+ self.anno_dict = anno_dict
120
+ else:
121
+ raise NotImplementedError
122
+
123
+ self.sam2_folder = sam2_folder
124
+ if extra_image_processor is not None:
125
+ self.extra_image_processor = BUILDER.build(extra_image_processor)
126
+ self.down_ratio = 1
127
+ self.repeats = repeats
128
+
129
+ self._system = ''
130
+
131
+ self.downsample_ratio = 0.5
132
+ self.image_size = 448
133
+ patch_size = 14
134
+ self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
135
+
136
+ self.transformer = T.Compose([
137
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
138
+ T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
139
+ T.ToTensor(),
140
+ T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
141
+ ])
142
+
143
+ if special_tokens is not None:
144
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
145
+
146
+ self.use_fast = use_fast
147
+ self.n_fast_images = n_fast_images
148
+ self.fast_pool_size = fast_pool_size
149
+
150
+ self.frame_contiguous_sample = frame_contiguous_sample
151
+
152
+ # for visualization debug
153
+ self.save_folder = './work_dirs/video_debug/'
154
+ self.cur_number = 0
155
+
156
+ print("Video res dataset (ref-sam2), include {} items.".format(len(self.video_ids)))
157
+
158
+ def __len__(self):
159
+ return len(self.video_ids) * self.repeats
160
+
161
+ @property
162
+ def modality_length(self):
163
+ length_list = []
164
+ for data_dict in self.video_ids:
165
+ cur_len = 20000
166
+ length_list.append(cur_len)
167
+ return length_list
168
+
169
+ def real_len(self):
170
+ return len(self.video_ids)
171
+
172
+ def json_file_preprocess(self, expression_file):
173
+ # prepare expression annotation files
174
+ with open(expression_file, 'r') as f:
175
+ expression_datas = json.load(f)
176
+
177
+ video_ids = list(expression_datas.keys())
178
+ return video_ids, expression_datas
179
+
180
+ def dataset_map_fn(self, objects_expression_infos, n_frames, n_fast_frames=0):
181
+ # prepare text
182
+ if self.mode == 'long':
183
+ expressions = [object_info['formated'] for object_info in objects_expression_infos]
184
+ self.cur_mode = self.mode
185
+ elif self.mode == 'short':
186
+ expressions = [object_info['short_caps'][random.randint(0, len(object_info['short_caps'])-1)] for object_info in objects_expression_infos]
187
+ self.cur_mode = self.mode
188
+ else:
189
+ if random.random() < 0.5:
190
+ expressions = [object_info['formated'] for object_info in objects_expression_infos]
191
+ self.cur_mode = 'long'
192
+ else:
193
+ expressions = [object_info['short_caps'][random.randint(0, len(object_info['short_caps']) - 1)] for
194
+ object_info in objects_expression_infos]
195
+ self.cur_mode = 'short'
196
+ text_dict = self.prepare_text(n_frames, expressions, num_image_tokens=self.patch_token,
197
+ n_fast_frames=n_fast_frames)
198
+ ret = {'conversation': text_dict['conversation']}
199
+ return ret
200
+
201
+ def prepare_text(self, n_frames, expressions, num_image_tokens=256, n_fast_frames=0):
202
+
203
+ if self.use_fast:
204
+ fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
205
+ f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_frames * self.fast_pool_size * self.fast_pool_size}' \
206
+ f'{self.FAST_IMG_END_TOKEN}' + '\n'
207
+ else:
208
+ fast_frame_token_str = ''
209
+
210
+ frame_token_str = f'{self.IMG_START_TOKEN}' \
211
+ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
212
+ f'{self.IMG_END_TOKEN}'
213
+
214
+ questions = []
215
+ answers = []
216
+ for i, exp in enumerate(expressions):
217
+ if self.cur_mode == 'short':
218
+ question_template = random.choice(SEG_QUESTIONS_SHORT)
219
+ exp = exp.replace("A ", '')
220
+ else:
221
+ question_template = random.choice(SEG_QUESTIONS)
222
+ questions.append(question_template.format(class_name=exp))
223
+ answers.append(random.choice(ANSWER_LIST))
224
+ qa_list = []
225
+ for i, (question, answer) in enumerate(zip(questions, answers)):
226
+ if i == 0:
227
+ frame_tokens = frame_token_str + '\n'
228
+ # frame_tokens = '=' + ' '
229
+ frame_tokens = frame_tokens * n_frames
230
+ frame_tokens = frame_tokens.strip()
231
+ frame_tokens = fast_frame_token_str + frame_tokens
232
+ qa_list.append(
233
+ {'from': 'human', 'value': frame_tokens + question}
234
+ )
235
+ else:
236
+ qa_list.append(
237
+ {'from': 'human', 'value': question}
238
+ )
239
+ qa_list.append(
240
+ {'from': 'gpt', 'value': answer}
241
+ )
242
+
243
+ input = ''
244
+ conversation = []
245
+ for msg in qa_list:
246
+ if msg['from'] == 'human':
247
+ input += msg['value']
248
+ elif msg['from'] == 'gpt':
249
+ conversation.append({'input': input, 'output': msg['value']})
250
+ input = ''
251
+ else:
252
+ raise NotImplementedError
253
+
254
+ # add system information
255
+ conversation[0].update({'system': self._system})
256
+ return {'conversation': conversation}
257
+
258
+ def __getitem__(self, index):
259
+ index = index % self.real_len()
260
+ video_id = self.video_ids[index]
261
+ expression_dict = self.anno_dict[video_id]
262
+ object_ids = list(expression_dict['objects'].keys())
263
+
264
+ video_path = os.path.join(self.sam2_folder, expression_dict['video_path'])
265
+ anno_path = os.path.join(self.sam2_folder, expression_dict['anno_path'])
266
+
267
+ video_frames = get_video_frames(video_path)
268
+
269
+ if self.use_fast:
270
+ # sample fast branch
271
+ fast_interval = len(video_frames) / (self.n_fast_images + 1e-4)
272
+ sampled_fast_frame_idxs = [min(int(i * fast_interval), len(video_frames) - 1) for i in range(self.n_fast_images)]
273
+ fast_video_frames = [video_frames[_idx] for _idx in sampled_fast_frame_idxs]
274
+ else:
275
+ fast_video_frames = None
276
+
277
+ video_frames = video_frames[::4]
278
+
279
+ # mask annotation
280
+ with open(anno_path, 'r') as f:
281
+ mask_data = json.load(f)
282
+ masklents = decode_masklet(mask_data['masklet'])
283
+
284
+ n_frames = len(masklents)
285
+ n_objects = len(object_ids)
286
+
287
+ # sample object
288
+ if n_objects > self.select_number:
289
+ selected_indexes = np.random.choice(n_objects, self.select_number, replace=False)
290
+ else:
291
+ selected_indexes = np.random.choice(n_objects, self.select_number, replace=True)
292
+
293
+ selected_object_ids = [object_ids[_idx] for _idx in selected_indexes]
294
+ objects_expression_infos = [expression_dict['objects'][_idx] for _idx in selected_object_ids]
295
+ _masklents = []
296
+ for _mask in masklents:
297
+ _mask_selected = []
298
+ for _idx in selected_object_ids:
299
+ _mask_selected.append(_mask[:, :, int(_idx)])
300
+ _mask_selected = np.stack(_mask_selected, axis=2)
301
+ _masklents.append(_mask_selected)
302
+ masklents = _masklents
303
+
304
+ # sample video frames
305
+ # prepare images, random select k frames
306
+ if n_frames > self.sampled_frames + 1:
307
+ if self.frame_contiguous_sample and random.random() < 0.5:
308
+ # do contiguous sample
309
+ selected_start_frame = np.random.choice(n_frames - self.sampled_frames, 1, replace=False)
310
+ selected_frame_indexes = [selected_start_frame[0] + _i for _i in range(self.sampled_frames)]
311
+ else:
312
+ selected_frame_indexes = np.random.choice(n_frames, self.sampled_frames, replace=False)
313
+ else:
314
+ selected_frame_indexes = np.random.choice(n_frames, self.sampled_frames, replace=True)
315
+ selected_frame_indexes.sort()
316
+
317
+ video_frames = [video_frames[_idx] for _idx in selected_frame_indexes]
318
+ masklents = [masklents[_idx] for _idx in selected_frame_indexes]
319
+
320
+ data_dict = self.dataset_map_fn(objects_expression_infos, len(video_frames), n_fast_frames=self.n_fast_images)
321
+ result = self.template_map_fn(data_dict)
322
+ data_dict.update(result)
323
+ result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
324
+ data_dict.update(result)
325
+
326
+ pixel_values = []
327
+ extra_pixel_values = []
328
+ for frame in video_frames:
329
+ frame = frame[:, :, ::-1]
330
+ frame_image = Image.fromarray(frame).convert('RGB')
331
+ ori_width, ori_height = frame_image.size
332
+ if hasattr(self, 'extra_image_processor'):  # set in __init__ only when an extra image processor is configured
333
+ g_image = np.array(frame_image) # for grounding
334
+ g_image = self.extra_image_processor.apply_image(g_image)
335
+ g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
336
+ extra_pixel_values.append(g_pixel_values)
337
+
338
+ frame_image = self.transformer(frame_image)
339
+ pixel_values.append(frame_image)
340
+
341
+ pixel_values = torch.stack(pixel_values, dim=0) # (n_f, 3, h, w)
342
+ data_dict['pixel_values'] = pixel_values
343
+ if hasattr(self, 'extra_image_processor'):
344
+ data_dict['g_pixel_values'] = extra_pixel_values
345
+
346
+ # for fast branch
347
+ if self.use_fast:
348
+ fast_pixel_values = []
349
+ for frame_image in fast_video_frames:
350
+ frame = frame_image[:, :, ::-1]
351
+ frame_image = Image.fromarray(frame).convert('RGB')
352
+ ori_width, ori_height = frame_image.size
353
+
354
+ frame_image = self.transformer(frame_image)
355
+ fast_pixel_values.append(frame_image)
356
+
357
+ fast_pixel_values = torch.stack(fast_pixel_values, dim=0) # (n_f, 3, h, w)
358
+ data_dict['fast_pixel_values'] = fast_pixel_values
359
+
360
+ # process and get masks
361
+ masklents = np.stack(masklents, axis=0) # (n_frames, h, w, n_obj)
362
+ masklents = torch.from_numpy(masklents).permute(3, 0, 1, 2)
363
+ masklents = masklents.flatten(0, 1)
364
+ # print('sam2-mask_shape:', masklents.shape)
365
+ # print('sam2-pixel_values:', data_dict['pixel_values'].shape)
366
+ # print('sam2-g_pixel_values:', len(data_dict['g_pixel_values']), ', ', data_dict['g_pixel_values'][0].shape)
367
+ data_dict['masks'] = masklents
368
+ data_dict['type'] = 'video'
369
+ return data_dict
370
+
371
+ def visualization_debug(self, data_dict):
372
+ save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
373
+ if not os.path.exists(save_folder):
374
+ os.mkdir(save_folder)
375
+ self.cur_number += 1
376
+
377
+ # images
378
+
379
+ show_images = []
380
+
381
+ pixel_values = data_dict['pixel_values']
382
+ save_folder_image = os.path.join(save_folder, 'image')
383
+ if not os.path.exists(save_folder_image):
384
+ os.mkdir(save_folder_image)
385
+ for i_image, image_pixel_value in enumerate(pixel_values):
386
+ # print(image_pixel_value.shape)
387
+ image_pixel_value[0] = image_pixel_value[0] * 0.229  # un-normalize with the ImageNet std/mean used by self.transformer
388
+ image_pixel_value[1] = image_pixel_value[1] * 0.224
389
+ image_pixel_value[2] = image_pixel_value[2] * 0.225
390
+ image_pixel_value[0] = image_pixel_value[0] + 0.485
391
+ image_pixel_value[1] = image_pixel_value[1] + 0.456
392
+ image_pixel_value[2] = image_pixel_value[2] + 0.406
393
+ image_pixel_value = image_pixel_value * 255
394
+ image_pixel_value = image_pixel_value.permute(1, 2, 0)
395
+ image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
396
+ # print(os.path.join(save_folder_image, '{}.jpg'.format(i_image)))
397
+ # print(image_pixel_value.shape)
398
+ show_images.append(image_pixel_value)
399
+ cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)
400
+
401
+ # text
402
+ input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
403
+ with open(os.path.join(save_folder, 'text.json'), 'w') as f:
404
+ json.dump([input_text], f)
405
+
406
+ # masks
407
+ save_folder_mask = os.path.join(save_folder, 'mask')
408
+ if not os.path.exists(save_folder_mask):
409
+ os.mkdir(save_folder_mask)
410
+ n_frames = len(pixel_values)
411
+ masks = data_dict['masks']
412
+ _, h, w = masks.shape
413
+ masks = masks.reshape(-1, n_frames, h, w)
414
+ for i_obj, obj_masks in enumerate(masks):
415
+ save_folder_mask_obj_folder = os.path.join(save_folder_mask, 'obj_{}'.format(i_obj))
416
+ if not os.path.exists(save_folder_mask_obj_folder):
417
+ os.mkdir(save_folder_mask_obj_folder)
418
+ for i_frame, f_mask in enumerate(obj_masks):
419
+ f_mask = f_mask.numpy()
420
+ f_mask = f_mask * 255
421
+ f_mask = np.stack([f_mask * 1, f_mask * 0, f_mask * 0], axis=2)
422
+ f_mask = show_images[i_frame] * 0.3 + 0.7 * f_mask
423
+ f_mask = f_mask.astype(np.uint8)
424
+ cv2.imwrite(os.path.join(save_folder_mask_obj_folder, '{}.png'.format(i_frame)), f_mask)
425
+ return
426
+
427
+ def get_video_frames(video_path):
428
+ cap = cv2.VideoCapture(video_path)
429
+
430
+ if not cap.isOpened():
431
+ print("Error: Cannot open video file.")
432
+ return []  # return an empty list so callers get a consistent type
433
+
434
+ frames = []
435
+
436
+ frame_id = 0
437
+ while True:
438
+ ret, frame = cap.read()
439
+
440
+ if not ret:
441
+ break
442
+
443
+ frames.append(frame)
444
+
445
+ frame_id += 1
446
+
447
+ cap.release()
448
+ return frames
449
+
450
+
451
+ def images_to_video(frames, video_name, fps=6):
452
+ height, width, layers = frames[0].shape
453
+
454
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
455
+ video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
456
+
457
+ for frame in frames:
458
+ video.write(frame)
459
+
460
+ # cv2.destroyAllWindows()
461
+ video.release()
462
+ return
463
+
464
+ def decode_masklet(masklet):
465
+ masks = []
466
+ for _rle in masklet:
467
+ mask = maskUtils.decode(_rle)
468
+ masks.append(mask)
469
+ return masks
470
+
471
+ def draw_mask(image, mask):
472
+ obj_mask = mask * 255
473
+ obj_mask = np.stack([obj_mask * 1, obj_mask * 0, obj_mask * 0], axis=2)
474
+ obj_mask = obj_mask * 0.5 + copy.deepcopy(image) * 0.5
475
+ obj_mask = obj_mask.astype(np.uint8)
476
+ return obj_mask
477
+
478
+ def add_mask2images(frames, masklets):
479
+ show_videos = []
480
+ for i_frames, (frame, masks) in enumerate(zip(frames, masklets)):
481
+ if i_frames == 0:
482
+ n_obj = masks.shape[-1]
483
+ for i_obj in range(n_obj):
484
+ show_videos.append([])
485
+
486
+ n_obj = masks.shape[-1]
487
+ for i_obj in range(n_obj):
488
+ show_videos[i_obj].append(draw_mask(copy.deepcopy(frame), masks[:, :, i_obj]))
489
+ return show_videos
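
The helper functions above (get_video_frames, decode_masklet, add_mask2images, images_to_video) can be chained to overlay SAM-2 masklets on a source clip for quick inspection. A minimal sketch, assuming a local video file and a `masklet` list of per-frame RLE lists; both names are placeholders, not files shipped with this repo:

    # Visualization sketch only; 'video.mp4' and `masklet` are hypothetical inputs.
    frames = get_video_frames('video.mp4')        # list of BGR frames read with OpenCV
    masks_per_frame = decode_masklet(masklet)     # each item: (H, W, n_obj) uint8 array
    per_object_clips = add_mask2images(frames, masks_per_frame)
    for i_obj, obj_frames in enumerate(per_object_clips):
        images_to_video(obj_frames, 'obj_{}.mp4'.format(i_obj), fps=6)
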
projects/llava_sam2/datasets/ReVOS_Dataset.py ADDED
@@ -0,0 +1,602 @@
1
+ import logging
2
+ import os
3
+ from typing import Literal
4
+
5
+ import torch
6
+ from datasets import Dataset as HFDataset
7
+ from datasets import DatasetDict
8
+ from mmengine import print_log
9
+ from PIL import Image
10
+ from torch.utils.data import Dataset
11
+ import numpy as np
12
+
13
+ from xtuner.registry import BUILDER
14
+ from xtuner.dataset.huggingface import build_origin_dataset
15
+ import copy
16
+
17
+ from .encode_fn import video_lisa_encode_fn
18
+ import json
19
+ import random
20
+ import pycocotools.mask as maskUtils
21
+ import cv2
22
+ import torchvision.transforms as T
23
+ from torchvision.transforms.functional import InterpolationMode
24
+
25
+ SEG_QUESTIONS = [
26
+ "Can you segment the {class_name} in this image?",
27
+ "Please segment {class_name} in this image.",
28
+ "What is {class_name} in this image? Please respond with segmentation mask.",
29
+ "What is {class_name} in this image? Please output segmentation mask.",
30
+
31
+ "Can you segment the {class_name} in this image",
32
+ "Please segment {class_name} in this image",
33
+ "What is {class_name} in this image? Please respond with segmentation mask",
34
+ "What is {class_name} in this image? Please output segmentation mask",
35
+
36
+ "Could you provide a segmentation mask for the {class_name} in this image?",
37
+ "Please identify and segment the {class_name} in this image.",
38
+ "Where is the {class_name} in this picture? Please respond with a segmentation mask.",
39
+ "Can you highlight the {class_name} in this image with a segmentation mask?",
40
+
41
+ "Could you provide a segmentation mask for the {class_name} in this image",
42
+ "Please identify and segment the {class_name} in this image",
43
+ "Where is the {class_name} in this picture? Please respond with a segmentation mask",
44
+ "Can you highlight the {class_name} in this image with a segmentation mask",
45
+ ]
46
+
47
+ ANSWER_LIST = [
48
+ "It is [SEG].",
49
+ "Sure, [SEG].",
50
+ "Sure, it is [SEG].",
51
+ "Sure, the segmentation result is [SEG].",
52
+ "[SEG].",
53
+ ]
54
+
55
+ class VideoReVOSDataset(Dataset):
56
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
57
+ IMAGENET_STD = (0.229, 0.224, 0.225)
58
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
59
+ IMG_START_TOKEN = '<img>'
60
+ IMG_END_TOKEN = '</img>'
61
+
62
+ FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
63
+ FAST_IMG_START_TOKEN = '<fast_img>'
64
+ FAST_IMG_END_TOKEN = '</fast_img>'
65
+
66
+ def __init__(self,
67
+ image_folder,
68
+ expression_file,
69
+ mask_file,
70
+ extra_image_processor=None,
71
+ tokenizer=None,
72
+ select_number=5,
73
+ sampled_frames=10,
74
+ offline_processed_text_folder=None,
75
+ template_map_fn=None,
76
+ max_length=2048,
77
+ lazy=True,
78
+ repeats=1,
79
+ special_tokens=None,
80
+ frame_contiguous_sample=False,
81
+ use_fast=False,
82
+ arch_type: Literal['intern_vl', 'qwen'] = 'intern_vl',
83
+ preprocessor=None,
84
+ # only work if use_fast = True
85
+ n_fast_images=50,
86
+ fast_pool_size=4,
87
+ fast_token_after_question=False,
88
+ ):
89
+ assert lazy is True
90
+ self.tokenizer = BUILDER.build(tokenizer)
91
+ self.select_number = select_number
92
+ self.sampled_frames = sampled_frames
93
+ assert offline_processed_text_folder or (expression_file and tokenizer)
94
+ self.lazy = lazy
95
+
96
+ self.max_length = max_length
97
+
98
+ self.template_map_fn = template_map_fn
99
+ if isinstance(self.template_map_fn, dict) and self.lazy:
100
+ _type = self.template_map_fn['type']
101
+ del self.template_map_fn['type']
102
+ self.template_map_fn = _type(**self.template_map_fn)
103
+
104
+ if offline_processed_text_folder and expression_file:
105
+ print_log(
106
+ 'Both `offline_processed_text_folder` and '
107
+ '`expression_file` are set, and we load dataset from '
108
+ '`offline_processed_text_folder` '
109
+ f'({offline_processed_text_folder})',
110
+ logger='current',
111
+ level=logging.WARNING)
112
+
113
+ self.arch_type = arch_type
114
+ if self.arch_type == 'qwen':
115
+ self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
116
+ self.IMG_START_TOKEN = '<|vision_start|>'
117
+ self.IMG_END_TOKEN = '<|vision_end|>'
118
+ elif self.arch_type == 'llava':
119
+ self.IMG_CONTEXT_TOKEN = '<image>'
120
+ self.IMG_START_TOKEN = ''
121
+ self.IMG_END_TOKEN = ''
122
+
123
+
124
+ if offline_processed_text_folder is not None:
125
+ raise NotImplementedError
126
+ else:
127
+ vid2metaid, metas, mask_dict = self.json_file_preprocess(expression_file, mask_file)
128
+ self.vid2metaid = vid2metaid
129
+ self.videos = list(self.vid2metaid.keys())
130
+ self.mask_dict = mask_dict
131
+ self.json_datas = metas
132
+ json_datas = metas
133
+ json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
134
+ if self.lazy:
135
+ self.text_data = build_origin_dataset(json_data, 'train')
136
+ else:
137
+ raise NotImplementedError
138
+
139
+ self.image_folder = image_folder
140
+ if extra_image_processor is not None:
141
+ self.extra_image_processor = BUILDER.build(extra_image_processor)
142
+ self.down_ratio = 1
143
+ self.repeats = repeats
144
+
145
+ self._system = ''
146
+
147
+ self.downsample_ratio = 0.5
148
+ if self.arch_type == 'llava':
149
+ self.downsample_ratio = 1
150
+ self.image_size = 448
151
+ if self.arch_type == 'llava':
152
+ self.image_size = 336
153
+ patch_size = 14
154
+ self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
155
+ if self.arch_type == 'qwen':
156
+ self.patch_token = 1
157
+
158
+ if preprocessor is None:
159
+ self.transformer = T.Compose([
160
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
161
+ T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
162
+ T.ToTensor(),
163
+ T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
164
+ ])
165
+ self.preprocessor = None
166
+ else:
167
+ self.transformer = None
168
+ self.preprocessor = BUILDER.build(preprocessor)
169
+
170
+ if special_tokens is not None:
171
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
172
+
173
+ self.use_fast = use_fast
174
+ self.n_fast_images = n_fast_images
175
+ self.fast_pool_size = fast_pool_size
176
+
177
+ self.frame_contiguous_sample = frame_contiguous_sample
178
+
179
+ # for visualization debug
180
+ self.save_folder = './work_dirs/video_debug/'
181
+ self.cur_number = 0
182
+
183
+ # exist_thr
184
+ self.exist_thr = 8
185
+ self.fast_token_after_question = fast_token_after_question
186
+ if self.fast_token_after_question:
187
+ assert self.use_fast
188
+
189
+ print("Video RES dataset, includes {} items.".format(len(self.vid2metaid)))
190
+
191
+ def __len__(self):
192
+ return len(self.vid2metaid) * self.repeats
193
+
194
+ @property
195
+ def modality_length(self):
196
+ length_list = []
197
+ for data_dict in self.vid2metaid:
198
+ cur_len = 10000
199
+ length_list.append(cur_len)
200
+ return length_list
201
+
202
+ def real_len(self):
203
+ return len(self.vid2metaid)
204
+
205
+ def json_file_preprocess(self, expression_file, mask_file):
206
+ # prepare expression annotation files
207
+ with open(expression_file, 'r') as f:
208
+ expression_datas = json.load(f)['videos']
209
+
210
+ metas = []
211
+ anno_count = 0 # serve as anno_id
212
+ vid2metaid = {}
213
+ for vid_name in expression_datas:
214
+ vid_express_data = expression_datas[vid_name]
215
+
216
+ vid_frames = sorted(vid_express_data['frames'])
217
+ vid_len = len(vid_frames)
218
+
219
+ exp_id_list = sorted(list(vid_express_data['expressions'].keys()))
220
+ for exp_id in exp_id_list:
221
+ exp_dict = vid_express_data['expressions'][exp_id]
222
+ meta = {}
223
+ meta['video'] = vid_name
224
+ meta['exp'] = exp_dict['exp'] # str
225
+ meta['mask_anno_id'] = exp_dict['anno_id']
226
+
227
+ if 'obj_id' in exp_dict.keys():
228
+ meta['obj_id'] = exp_dict['obj_id']
229
+ else:
230
+ meta['obj_id'] = [0, ] # Ref-Youtube-VOS only has one object per expression
231
+ meta['anno_id'] = [str(anno_count), ]
232
+ anno_count += 1
233
+ meta['frames'] = vid_frames
234
+ meta['exp_id'] = exp_id
235
+
236
+ meta['length'] = vid_len
237
+ metas.append(meta)
238
+ if vid_name not in vid2metaid.keys():
239
+ vid2metaid[vid_name] = []
240
+ vid2metaid[vid_name].append(len(metas) - 1)
241
+
242
+ # process mask annotation files
243
+ with open(mask_file, 'rb') as f:
244
+ mask_dict = json.load(f)
245
+
246
+ return vid2metaid, metas, mask_dict
247
+
248
+ def create_img_to_refs_mapping(self, refs_train):
249
+ img2refs = {}
250
+ for ref in refs_train:
251
+ img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ]
252
+ return img2refs
253
+
254
+ def decode_mask(self, video_masks, image_size):
255
+ ret_masks = []
256
+ for object_masks in video_masks:
257
+ # None object
258
+ if len(object_masks) == 0:
259
+ if len(ret_masks) != 0:
260
+ _object_masks = ret_masks[0] * 0
261
+ else:
262
+ _object_masks = np.zeros(
263
+ (self.sampled_frames, image_size[0], image_size[1]), dtype=np.uint8)
264
+ else:
265
+ _object_masks = []
266
+ for i_frame in range(len(object_masks[0])):
267
+ _mask = np.zeros(image_size, dtype=np.uint8)
268
+ for i_anno in range(len(object_masks)):
269
+ if object_masks[i_anno][i_frame] is None:
270
+ continue
271
+ m = maskUtils.decode(object_masks[i_anno][i_frame])
272
+ if m.ndim == 3:
273
+ m = m.sum(axis=2).astype(np.uint8)
274
+ else:
275
+ m = m.astype(np.uint8)
276
+ _mask = _mask | m
277
+ _object_masks.append(_mask)
278
+ _object_masks = np.stack(_object_masks, axis=0)
279
+ # if self.pad_image_to_square:
280
+ # _object_masks = expand2square_mask(_object_masks)
281
+ ret_masks.append(_object_masks)
282
+ _shape = ret_masks[0].shape
283
+ for item in ret_masks:
284
+ if item.shape != _shape:
285
+ print([_ret_mask.shape for _ret_mask in ret_masks])
286
+ return None
287
+ ret_masks = np.stack(ret_masks, axis=0) # (n_obj, n_frames, h, w)
288
+
289
+ ret_masks = torch.from_numpy(ret_masks)
290
+ # ret_masks = F.interpolate(ret_masks, size=(self.image_size // self.down_ratio,
291
+ # self.image_size // self.down_ratio), mode='nearest')
292
+ ret_masks = ret_masks.flatten(0, 1)
293
+ return ret_masks
294
+
295
+ def dataset_map_fn(self, data_dict, select_k=5):
296
+ images = []
297
+
298
+ len_frames = len(data_dict[0]['frames'])
299
+ for object_info in data_dict:
300
+ assert len_frames == len(object_info['frames'])
301
+
302
+ # prepare images, random select k frames
303
+ if len_frames > select_k + 1:
304
+ if self.frame_contiguous_sample and random.random() < 0.5:
305
+ # do contiguous sample
306
+ selected_start_frame = np.random.choice(len_frames - select_k, 1, replace=False)
307
+ selected_frame_indexes = [selected_start_frame[0] + _i for _i in range(select_k)]
308
+ else:
309
+ selected_frame_indexes = np.random.choice(len_frames, select_k, replace=False)
310
+ else:
311
+ selected_frame_indexes = np.random.choice(len_frames, select_k, replace=True)
312
+ selected_frame_indexes.sort()
313
+
314
+ if self.use_fast:
315
+ # sample fast branch
316
+ fast_interval = len_frames / (self.n_fast_images + 1e-4)
317
+ sampled_fast_frame_idxs = [min(int(i * fast_interval), len_frames - 1) for i in range(self.n_fast_images)]
318
+ fast_video_frames = []
319
+ for selected_frame_index in sampled_fast_frame_idxs:
320
+ frame_id = data_dict[0]['frames'][selected_frame_index]
321
+ fast_video_frames.append(os.path.join(data_dict[0]['video'], frame_id + '.jpg'))
322
+ else:
323
+ fast_video_frames = None
324
+ sampled_fast_frame_idxs = None
325
+
326
+ for selected_frame_index in selected_frame_indexes:
327
+ frame_id = data_dict[0]['frames'][selected_frame_index]
328
+ images.append(os.path.join(data_dict[0]['video'], frame_id + '.jpg'))
329
+
330
+ # prepare text
331
+ expressions = [object_info['exp'] for object_info in data_dict]
332
+ if self.use_fast:
333
+ text_dict = self.prepare_text(select_k, expressions, num_image_tokens=self.patch_token,
334
+ n_fast_images=len(fast_video_frames),)
335
+ else:
336
+ text_dict = self.prepare_text(select_k, expressions, num_image_tokens=self.patch_token)
337
+
338
+
339
+ # prepare masks
340
+ video_masks = []
341
+ for object_info in data_dict:
342
+ anno_ids = object_info['mask_anno_id']
343
+ # print('anno_ids: ', anno_ids)
344
+ obj_masks = []
345
+ for anno_id in anno_ids:
346
+ anno_id = str(anno_id)
347
+ frames_masks = self.mask_dict[anno_id]
348
+ frames_masks_ = []
349
+ for frame_idx in selected_frame_indexes:
350
+ frames_masks_.append(copy.deepcopy(frames_masks[frame_idx]))
351
+ obj_masks.append(frames_masks_)
352
+ video_masks.append(obj_masks)
353
+
354
+ if self.use_fast:
355
+ fast_video_masks = []
356
+ assert sampled_fast_frame_idxs is not None
357
+ for object_info in data_dict:
358
+ anno_ids = object_info['mask_anno_id']
359
+ obj_masks = []
360
+ for anno_id in anno_ids:
361
+ anno_id = str(anno_id)
362
+ frames_masks = self.mask_dict[anno_id]
363
+ frames_masks_ = []
364
+ for frame_idx in sampled_fast_frame_idxs:
365
+ frames_masks_.append(copy.deepcopy(frames_masks[frame_idx]))
366
+ obj_masks.append(frames_masks_)
367
+ fast_video_masks.append(obj_masks)
368
+ else:
369
+ fast_video_masks = None
370
+
371
+ ret = {'images': images, 'video_masks': video_masks, 'conversation': text_dict['conversation'],
372
+ 'fast_images': fast_video_frames, 'fast_video_masks': fast_video_masks}
373
+ return ret
374
+
375
+ def prepare_text(self, n_frames, expressions, num_image_tokens=256, n_fast_images=50):
376
+
377
+ if self.use_fast and not self.fast_token_after_question:
378
+ fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
379
+ f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
380
+ f'{self.FAST_IMG_END_TOKEN}' + '\n'
381
+ else:
382
+ fast_frame_token_str = ''
383
+
384
+ frame_token_str = f'{self.IMG_START_TOKEN}' \
385
+ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
386
+ f'{self.IMG_END_TOKEN}'
387
+ if self.fast_token_after_question:
388
+ assert self.use_fast
389
+ after_question_str = f'{self.FAST_IMG_START_TOKEN}' \
390
+ f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
391
+ f'{self.FAST_IMG_END_TOKEN}'
392
+ else:
393
+ after_question_str = ''
394
+
395
+ questions = []
396
+ answers = []
397
+ for i, exp in enumerate(expressions):
398
+ # the exp is a question
399
+ if '?' in exp:
400
+ questions.append(exp)
401
+ else:
402
+ exp = exp.replace('.', '').strip()
403
+ question_template = random.choice(SEG_QUESTIONS)
404
+ questions.append(question_template.format(class_name=exp.lower()))
405
+
406
+ answers.append(random.choice(ANSWER_LIST))
407
+ qa_list = []
408
+ for i, (question, answer) in enumerate(zip(questions, answers)):
409
+ if i == 0:
410
+ frame_tokens = frame_token_str + '\n'
411
+ # frame_tokens = '=' + ' '
412
+ frame_tokens = frame_tokens * n_frames
413
+ frame_tokens = frame_tokens.strip()
414
+ frame_tokens = fast_frame_token_str + frame_tokens
415
+ qa_list.append(
416
+ {'from': 'human', 'value': frame_tokens + question + after_question_str}
417
+ )
418
+ else:
419
+ qa_list.append(
420
+ {'from': 'human', 'value': question + after_question_str}
421
+ )
422
+ qa_list.append(
423
+ {'from': 'gpt', 'value': answer}
424
+ )
425
+
426
+ input = ''
427
+ conversation = []
428
+ for msg in qa_list:
429
+ if msg['from'] == 'human':
430
+ input += msg['value']
431
+ elif msg['from'] == 'gpt':
432
+ conversation.append({'input': input, 'output': msg['value']})
433
+ input = ''
434
+ else:
435
+ raise NotImplementedError
436
+
437
+ # add system information
438
+ conversation[0].update({'system': self._system})
439
+ return {'conversation': conversation}
440
+
441
+ def __getitem__(self, index):
442
+ index = index % self.real_len()
443
+ selected_video_objects = self.vid2metaid[self.videos[index]]
444
+ video_objects_infos = [copy.deepcopy(self.text_data[idx]) for idx in selected_video_objects]
445
+
446
+ if len(video_objects_infos) > self.select_number:
447
+ selected_indexes = np.random.choice(len(video_objects_infos), self.select_number, replace=False)
448
+ video_objects_infos = [video_objects_infos[_idx] for _idx in selected_indexes]
449
+ else:
450
+ selected_indexes = np.random.choice(len(video_objects_infos), self.select_number, replace=True)
451
+ video_objects_infos = [video_objects_infos[_idx] for _idx in selected_indexes]
452
+
453
+ data_dict = self.dataset_map_fn(video_objects_infos, select_k=self.sampled_frames)
454
+
455
+ assert 'images' in data_dict.keys()
456
+ pixel_values = []
457
+ extra_pixel_values = []
458
+ num_video_tokens = None
459
+ num_frame_tokens = None
460
+ if data_dict.get('images', None) is not None:
461
+ frames_files = data_dict['images']
462
+ frames_files = [os.path.join(self.image_folder, frame_file) for frame_file in frames_files]
463
+ for frame_path in frames_files:
464
+ frame_image = Image.open(frame_path).convert('RGB')
465
+ ori_width, ori_height = frame_image.size
466
+ if self.extra_image_processor is not None:
467
+ g_image = np.array(frame_image) # for grounding
468
+ g_image = self.extra_image_processor.apply_image(g_image)
469
+ g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
470
+ extra_pixel_values.append(g_pixel_values)
471
+
472
+ if self.preprocessor is not None:
473
+ pass
474
+ else:
475
+ frame_image = self.transformer(frame_image)
476
+ pixel_values.append(frame_image)
477
+
478
+ if self.preprocessor is not None:
479
+ if self.arch_type == 'qwen':
480
+ _data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
481
+ _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
482
+ _data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
483
+ num_frame_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
484
+ num_frames = _data_dict['image_grid_thw'].shape[0]
485
+ num_video_tokens = num_frame_tokens * num_frames
486
+ elif self.arch_type == 'llava':
487
+ _data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
488
+ _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
489
+ _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
490
+ else:
491
+ raise NotImplementedError
492
+ data_dict.update(_data_dict)
493
+ else:
494
+ pixel_values = torch.stack(pixel_values, dim=0) # (n_f, 3, h, w)
495
+ data_dict['pixel_values'] = pixel_values
496
+ if self.extra_image_processor is not None:
497
+ data_dict['g_pixel_values'] = extra_pixel_values
498
+
499
+ # process and get masks
500
+ masks = self.decode_mask(data_dict['video_masks'], image_size=(ori_height, ori_width))
501
+ if masks is None:
502
+ return self.__getitem__(random.randint(0, self.real_len() - 1))
503
+ data_dict['masks'] = masks
504
+ else:
505
+ data_dict['pixel_values'] = torch.zeros(0, 3, self.image_size, self.image_size)
506
+ data_dict['masks'] = None
507
+
508
+ if num_video_tokens is not None:
509
+ assert self.patch_token == 1
510
+ input_str = data_dict['conversation'][0]['input']
511
+ input_str = input_str.replace(self.IMG_CONTEXT_TOKEN, self.IMG_CONTEXT_TOKEN * num_frame_tokens)
512
+ assert input_str.count(self.IMG_CONTEXT_TOKEN) == num_video_tokens
513
+ data_dict['conversation'][0]['input'] = input_str
514
+
515
+ result = self.template_map_fn(data_dict)
516
+ data_dict.update(result)
517
+ result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length)
518
+ data_dict.update(result)
519
+
520
+ # for fast branch
521
+ if self.use_fast:
522
+ fast_pixel_values = []
523
+ frames_files = data_dict['fast_images']
524
+ frames_files = [os.path.join(self.image_folder, frame_file) for frame_file in frames_files]
525
+ for frame_path in frames_files:
526
+ frame_image = Image.open(frame_path).convert('RGB')
527
+ ori_width, ori_height = frame_image.size
528
+
529
+ frame_image = self.transformer(frame_image)
530
+ fast_pixel_values.append(frame_image)
531
+
532
+ fast_pixel_values = torch.stack(fast_pixel_values, dim=0) # (n_f, 3, h, w)
533
+ data_dict['fast_pixel_values'] = fast_pixel_values
534
+
535
+ # process and get masks
536
+ masks = self.decode_mask(data_dict['fast_video_masks'], image_size=(ori_height, ori_width))
537
+
538
+ if masks is None:
539
+ return self.__getitem__(random.randint(0, self.real_len() - 1))
540
+
541
+ data_dict['fast_exists'] = masks.to(dtype=torch.int).sum(dim=(-2, -1)).ge(self.exist_thr).unsqueeze(-1)
542
+
543
+
544
+ del data_dict['fast_video_masks']
545
+ data_dict['type'] = 'video'
546
+ return data_dict
547
+
548
+ def visualization_debug(self, data_dict):
549
+ save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
550
+ if not os.path.exists(save_folder):
551
+ os.makedirs(save_folder, exist_ok=True)
552
+ self.cur_number += 1
553
+
554
+ # images
555
+
556
+ show_images = []
557
+
558
+ pixel_values = data_dict['pixel_values']
559
+ save_folder_image = os.path.join(save_folder, 'image')
560
+ if not os.path.exists(save_folder_image):
561
+ os.mkdir(save_folder_image)
562
+ for i_image, image_pixel_value in enumerate(pixel_values):
563
+ # print(image_pixel_value.shape)
564
+ image_pixel_value[0] = image_pixel_value[0] * 0.2686
565
+ image_pixel_value[1] = image_pixel_value[1] * 0.2613
566
+ image_pixel_value[2] = image_pixel_value[2] * 0.2757
567
+ image_pixel_value[0] = image_pixel_value[0] + 0.4814
568
+ image_pixel_value[1] = image_pixel_value[1] + 0.4578
569
+ image_pixel_value[2] = image_pixel_value[2] + 0.4082
570
+ image_pixel_value = image_pixel_value * 255
571
+ image_pixel_value = image_pixel_value.permute(1, 2, 0)
572
+ image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
573
+ # print(os.path.join(save_folder_image, '{}.jpg'.format(i_image)))
574
+ # print(image_pixel_value.shape)
575
+ show_images.append(image_pixel_value)
576
+ cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)
577
+
578
+ # text
579
+ input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
580
+ with open(os.path.join(save_folder, 'text.json'), 'w') as f:
581
+ json.dump([input_text], f)
582
+
583
+ # masks
584
+ save_folder_mask = os.path.join(save_folder, 'mask')
585
+ if not os.path.exists(save_folder_mask):
586
+ os.mkdir(save_folder_mask)
587
+ n_frames = len(pixel_values)
588
+ masks = data_dict['masks']
589
+ _, h, w = masks.shape
590
+ masks = masks.reshape(-1, n_frames, h, w)
591
+ for i_obj, obj_masks in enumerate(masks):
592
+ save_folder_mask_obj_folder = os.path.join(save_folder_mask, 'obj_{}'.format(i_obj))
593
+ if not os.path.exists(save_folder_mask_obj_folder):
594
+ os.mkdir(save_folder_mask_obj_folder)
595
+ for i_frame, f_mask in enumerate(obj_masks):
596
+ f_mask = f_mask.numpy()
597
+ f_mask = f_mask * 255
598
+ f_mask = np.stack([f_mask * 1, f_mask * 0, f_mask * 0], axis=2)
599
+ f_mask = show_images[i_frame] * 0.3 + 0.7 * f_mask
600
+ f_mask = f_mask.astype(np.uint8)
601
+ cv2.imwrite(os.path.join(save_folder_mask_obj_folder, '{}.png'.format(i_frame)), f_mask)
602
+ return
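
VideoReVOSDataset wires the ReVOS expression/mask annotations, the frame sampler and the conversation templates into a single dataset. A minimal construction sketch, assuming placeholder data paths and the vicuna prompt template that the RefCOCO example later in this commit also uses (the actual training setup lives under projects/llava_sam2/configs):

    # Construction sketch only; the paths and the LLM checkpoint below are placeholders.
    from transformers import AutoTokenizer
    from xtuner.dataset.map_fns import template_map_fn_factory
    from xtuner.utils.templates import PROMPT_TEMPLATE

    tokenizer = dict(type=AutoTokenizer.from_pretrained,
                     pretrained_model_name_or_path='lmsys/vicuna-7b-v1.5')
    dataset = VideoReVOSDataset(
        image_folder='data/revos/',                                 # placeholder path
        expression_file='data/revos/meta_expressions_train.json',   # placeholder path
        mask_file='data/revos/mask_dict.json',                      # placeholder path
        tokenizer=tokenizer,
        special_tokens=['[SEG]'],
        template_map_fn=dict(type=template_map_fn_factory,
                             template=PROMPT_TEMPLATE.vicuna),
        sampled_frames=5,
        select_number=5,
        max_length=2048,
    )
    sample = dataset[0]  # dict with 'pixel_values', 'masks', 'input_ids', 'labels', 'type', ...
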
projects/llava_sam2/datasets/RefCOCO_Dataset.py ADDED
@@ -0,0 +1,338 @@
1
+ import copy
2
+ import random
3
+ import glob
4
+ import json
5
+ import logging
6
+ import os
7
+ from typing import Literal
8
+
9
+ import torch
10
+
11
+ from mmengine import print_log
12
+ from mmengine.config import Config, ConfigDict
13
+ from PIL import Image
14
+ from torch.utils.data import Dataset
15
+ import numpy as np
16
+ import torch.nn.functional as F
17
+ import torchvision.transforms as T
18
+ from torchvision.transforms.functional import InterpolationMode
19
+ from pycocotools.coco import COCO
20
+ from pycocotools import mask as mask_utils
21
+
22
+ from xtuner.registry import BUILDER
23
+ from xtuner.utils import IGNORE_INDEX
24
+ from xtuner.dataset.utils import encode_fn
25
+ from xtuner.dataset.map_fns import llava_map_fn
26
+
27
+ from projects.glamm.datasets.utils.utils import expand2square
28
+
29
+ from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
30
+ from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
31
+
32
+ from third_parts.mmdet.datasets.refcoco import RefCocoDataset
33
+
34
+ from .utils import dynamic_preprocess
35
+
36
+
37
+ class ReferSegmDataset(RefCocoDataset):
38
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
39
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
40
+ IMG_START_TOKEN = '<img>'
41
+ IMG_END_TOKEN = '</img>'
42
+
43
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
44
+ IMAGENET_STD = (0.229, 0.224, 0.225)
45
+
46
+ def __init__(self,
47
+ data_root,
48
+ ann_file=None,
49
+ split_file=None,
50
+ special_tokens=None,
51
+ prompt_template=None,
52
+ extra_image_processor=None,
53
+ data_prefix=dict(img_path='train2014/'),
54
+ tokenizer=None,
55
+ max_length=2048,
56
+ num_classes_per_sample=3,
57
+ single_image_mode=False,
58
+ arch_type: Literal['intern_vl', 'qwen'] = 'intern_vl',
59
+ preprocessor=None,
60
+ **kwargs):
61
+ super().__init__(
62
+ data_root=data_root,
63
+ data_prefix=data_prefix,
64
+ pipeline=None,
65
+ ann_file=ann_file,
66
+ split_file=split_file,
67
+ **kwargs,
68
+ )
69
+ self.begin_str = f'{DEFAULT_IMAGE_TOKEN}\n'
70
+ if extra_image_processor is not None:
71
+ self.extra_image_processor = BUILDER.build(extra_image_processor)
72
+
73
+ self.arch_type = arch_type
74
+ if self.arch_type == 'qwen':
75
+ self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
76
+ self.IMG_START_TOKEN = '<|vision_start|>'
77
+ self.IMG_END_TOKEN = '<|vision_end|>'
78
+ elif self.arch_type == 'llava':
79
+ self.IMG_CONTEXT_TOKEN = '<image>'
80
+ self.IMG_START_TOKEN = ''
81
+ self.IMG_END_TOKEN = ''
82
+
83
+ self.tokenizer = BUILDER.build(tokenizer)
84
+ if special_tokens is not None:
85
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
86
+
87
+ self.image_folder = data_root
88
+ self.template = prompt_template
89
+ self.max_length = max_length
90
+ if self.arch_type == 'intern_vl':
91
+ # self._system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'
92
+ self._system = ''
93
+ self.template['INSTRUCTION'] = '<|user|>\n{input}<|end|><|assistant|>\n'
94
+ elif self.arch_type == 'qwen':
95
+ self._system = ''
96
+ elif self.arch_type == 'llava':
97
+ self._system = ''
98
+
99
+ self.num_classes_per_sample = num_classes_per_sample
100
+ self.min_dynamic_patch = 1
101
+ self.max_dynamic_patch = 12
102
+ self.downsample_ratio = 0.5
103
+ if self.arch_type == 'llava':
104
+ self.downsample_ratio = 1
105
+ self.image_size = 448
106
+ if self.arch_type == 'llava':
107
+ self.image_size = 336
108
+ self.use_thumbnail = True
109
+ patch_size = 14
110
+ self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
111
+
112
+ if preprocessor is None:
113
+ self.transformer = T.Compose([
114
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
115
+ T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
116
+ T.ToTensor(),
117
+ T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
118
+ ])
119
+ self.preprocessor = None
120
+ else:
121
+ self.transformer = None
122
+ self.preprocessor = BUILDER.build(preprocessor)
123
+ self.arch_type = arch_type
124
+ self.single_image_mode = single_image_mode
125
+ self._max_refetch = 1000
126
+
127
+ print("Image RES dataset, includes {} items.".format(len(self)))
128
+
129
+ @property
130
+ def modality_length(self):
131
+ import pickle
132
+ length_list = []
133
+ for idx in range(len(self)):
134
+ length_list.append(100)
135
+ return length_list
136
+
137
+ def _parse_annotations(self, ann_info):
138
+ image_path = ann_info['img_path']
139
+ image = Image.open(image_path).convert('RGB')
140
+ width, height = image.size
141
+
142
+ masks, phrases = [], []
143
+ instances, text = ann_info['instances'], ann_info['text']
144
+ # index = np.random.choice(range(len(instances)), min(
145
+ # len(instances), self.num_classes_per_sample))
146
+ index = np.random.choice(range(len(instances)), self.num_classes_per_sample, replace=True)
147
+ for idx in index:
148
+ inst = instances[idx]
149
+ phrase = text[idx].lower()
150
+ if '.' == phrase[-1]:
151
+ phrase = phrase[:-1]
152
+ phrases.append(phrase)
153
+ binary_mask = np.zeros((height, width), dtype=np.uint8)
154
+ for seg in inst["mask"]:
155
+ rles = mask_utils.frPyObjects([seg], height, width)
156
+ m = mask_utils.decode(rles)
157
+ m = m.astype(np.uint8)
158
+ binary_mask += m.squeeze()
159
+ masks.append(binary_mask)
160
+
161
+ conversation = []
162
+ for i, phrase in enumerate(phrases):
163
+ question = random.choice(SEG_QUESTIONS).format(class_name=phrase)
164
+ if i == 0:
165
+ question = self.begin_str + question
166
+ conversation.append({'from': 'human', 'value': question})
167
+ conversation.append({'from': 'gpt', 'value': random.choice(ANSWER_LIST)})
168
+ masks = torch.stack([torch.from_numpy(mask) for mask in masks], dim=0)
169
+
170
+ ann_info.update({
171
+ 'masks': masks,
172
+ 'conversations': conversation,
173
+ 'image': image_path
174
+ })
175
+ return ann_info
176
+
177
+ def prepare_data(self, index):
178
+ data_dict = super().prepare_data(index)
179
+ data_dict = self._parse_annotations(data_dict)
180
+ if data_dict is None:
181
+ return None
182
+
183
+ out_data_dict = {}
184
+ if 'masks' in data_dict:
185
+ out_data_dict['masks'] = data_dict['masks']
186
+
187
+ if data_dict.get('image', None) is not None:
188
+ image_file = data_dict['image']
189
+ try:
190
+ image = Image.open(image_file).convert('RGB')
191
+ except Exception as e:
192
+ print(f'Error: {e}', flush=True)
193
+ print_log(f'Error: {e}', logger='current')
194
+ return None
195
+ if hasattr(self, 'extra_image_processor'):
196
+ g_image = np.array(image) # for grounding
197
+ g_image = self.extra_image_processor.apply_image(g_image)
198
+ g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
199
+ out_data_dict['g_pixel_values'] = g_pixel_values
200
+
201
+ if self.single_image_mode:
202
+ images = [image]
203
+ else:
204
+ images = dynamic_preprocess(image, self.min_dynamic_patch,
205
+ self.max_dynamic_patch,
206
+ self.image_size, self.use_thumbnail)
207
+ if self.preprocessor is not None:
208
+ if self.arch_type == 'qwen':
209
+ _data_dict = self.preprocessor(images, do_resize=True)
210
+ _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
211
+ _data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
212
+ num_image_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
213
+ elif self.arch_type == 'llava':
214
+ _data_dict = self.preprocessor(images, do_resize=True, size=(self.image_size, self.image_size))
215
+ _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
216
+ _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
217
+ num_image_tokens = _data_dict['pixel_values'].shape[0] * self.patch_token
218
+ else:
219
+ raise NotImplementedError
220
+ out_data_dict.update(_data_dict)
221
+ else:
222
+ pixel_values = [self.transformer(image) for image in images]
223
+ pixel_values = torch.stack(pixel_values)
224
+ out_data_dict['pixel_values'] = pixel_values
225
+
226
+ num_image_tokens = pixel_values.shape[0] * self.patch_token
227
+ image_token_str = f'{self.IMG_START_TOKEN}' \
228
+ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
229
+ f'{self.IMG_END_TOKEN}'
230
+ token_dict = self.get_inputid_labels(data_dict['conversations'], image_token_str)
231
+ out_data_dict.update(token_dict)
232
+ else:
233
+ token_dict = self.get_inputid_labels(data_dict['conversations'], None)
234
+ out_data_dict.update(token_dict)
235
+ out_data_dict['pixel_values'] = torch.zeros(1, 3, self.image_size, self.image_size)
236
+ return out_data_dict
237
+
238
+ def get_inputid_labels(self, conversations, image_token_str) -> dict:
239
+ input = ''
240
+ out_conversation = []
241
+ while conversations and conversations[0]['from'] == 'gpt':
242
+ # Skip the first one if it is from gpt
243
+ conversations = conversations[1:]
244
+ for msg in conversations:
245
+ if msg['from'] == 'human':
246
+ if image_token_str is None and '<image>' in msg['value']:
247
+ msg['value'] = msg['value'].replace('<image>', '')
248
+ if '<image>' in msg['value']:
249
+ msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
250
+ input += msg['value'].strip()
251
+ elif msg['from'] == 'gpt':
252
+ out_conversation.append({
253
+ 'input': input,
254
+ 'output': msg['value'].strip()
255
+ })
256
+ input = ''
257
+ else:
258
+ raise NotImplementedError
259
+
260
+ input_ids, labels = [], []
261
+ for i, single_turn_conversation in enumerate(out_conversation):
262
+ input = single_turn_conversation.get('input', '')
263
+ if input is None:
264
+ input = ''
265
+ input_text = self.template.INSTRUCTION.format(
266
+ input=input, round=i + 1)
267
+
268
+ if i == 0:
269
+ if self._system != '' and self._system is not None:
270
+ system = self.template.SYSTEM.format(system=self._system)
271
+ input_text = system + input_text
272
+ input_encode = self.tokenizer.encode(
273
+ input_text, add_special_tokens=True)
274
+ else:
275
+ input_encode = self.tokenizer.encode(
276
+ input_text, add_special_tokens=False)
277
+ input_ids += input_encode
278
+ labels += [IGNORE_INDEX] * len(input_encode)
279
+
280
+ output_text = single_turn_conversation.get('output', '')
281
+ if self.template.get('SUFFIX', None):
282
+ output_text += self.template.SUFFIX
283
+ output_encode = self.tokenizer.encode(
284
+ output_text, add_special_tokens=False)
285
+ input_ids += output_encode
286
+ labels += copy.deepcopy(output_encode)
287
+
288
+ if len(input_ids) > self.max_length:
289
+ input_ids = input_ids[:self.max_length]
290
+ labels = labels[:self.max_length]
291
+ # print('len_ids: ', len(input_ids))
292
+ return {'input_ids': input_ids, 'labels': labels}
293
+
294
+ def __getitem__(self, index):
295
+ for _ in range(self._max_refetch + 1):
296
+ data = self.prepare_data(index)
297
+ # Broken images may cause the returned data to be None
298
+ if data is None:
299
+ index = self._rand_another()
300
+ continue
301
+ return data
302
+
303
+
304
+ if __name__ == '__main__':
305
+ from transformers import CLIPImageProcessor, AutoTokenizer
306
+ from third_parts.segment_anything.utils.transforms import ResizeLongestSide
307
+
308
+ pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
309
+ llm_name_or_path = 'lmsys/vicuna-7b-v1.5'
310
+
311
+ tokenizer = dict(
312
+ type=AutoTokenizer.from_pretrained,
313
+ pretrained_model_name_or_path=llm_name_or_path)
314
+ image_processor = dict(
315
+ type=CLIPImageProcessor.from_pretrained,
316
+ pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
317
+ extra_image_processor = dict(
318
+ type=ResizeLongestSide,
319
+ target_length=1024,
320
+ )
321
+ from xtuner.utils.templates import PROMPT_TEMPLATE
322
+
323
+ prompt_template = PROMPT_TEMPLATE.vicuna
324
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
325
+ from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
326
+
327
+ dataset = ReferSegmDataset(
328
+ tokenizer=tokenizer,
329
+ special_tokens=['[SEG]'],
330
+ extra_image_processor=extra_image_processor,
331
+ prompt_template=prompt_template,
332
+ data_root='data/coco/',
333
+ data_prefix=dict(img_path='train2014/'),
334
+ ann_file='refcoco+/instances.json',
335
+ split_file='refcoco+/refs(unc).p',
336
+ )
337
+ for i in range(1000):
338
+ dataset[i]
projects/llava_sam2/datasets/RefYoutubeVOS_Dataset.py ADDED
@@ -0,0 +1,47 @@
1
+ from .ReVOS_Dataset import VideoReVOSDataset
2
+ import json
3
+ import pickle
4
+
5
+ class VideoRefYoutubeVOSDataset(VideoReVOSDataset):
6
+
7
+ def json_file_preprocess(self, expression_file, mask_file):
8
+ # prepare expression annotation files
9
+ with open(expression_file, 'r') as f:
10
+ expression_datas = json.load(f)['videos']
11
+
12
+ metas = []
13
+ anno_count = 0 # serve as anno_id
14
+ vid2metaid = {}
15
+ for vid_name in expression_datas:
16
+ vid_express_data = expression_datas[vid_name]
17
+
18
+ vid_frames = sorted(vid_express_data['frames'])
19
+ vid_len = len(vid_frames)
20
+
21
+ exp_id_list = sorted(list(vid_express_data['expressions'].keys()))
22
+ for exp_id in exp_id_list:
23
+ exp_dict = vid_express_data['expressions'][exp_id]
24
+ meta = {}
25
+ meta['video'] = vid_name
26
+ meta['exp'] = exp_dict['exp'] # str
27
+ meta['mask_anno_id'] = [str(anno_count), ]
28
+
29
+ if 'obj_id' in exp_dict.keys():
30
+ meta['obj_id'] = exp_dict['obj_id']
31
+ else:
32
+ meta['obj_id'] = [0, ] # Ref-Youtube-VOS only has one object per expression
33
+ meta['anno_id'] = [str(anno_count), ]
34
+ anno_count += 1
35
+ meta['frames'] = vid_frames
36
+ meta['exp_id'] = exp_id
37
+
38
+ meta['length'] = vid_len
39
+ metas.append(meta)
40
+ if vid_name not in vid2metaid.keys():
41
+ vid2metaid[vid_name] = []
42
+ vid2metaid[vid_name].append(len(metas) - 1)
43
+
44
+ # process mask annotation files
45
+ with open(mask_file, 'rb') as f:
46
+ mask_dict = pickle.load(f)
47
+ return vid2metaid, metas, mask_dict
projects/llava_sam2/datasets/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ from .collect_fns import video_lisa_collate_fn
2
+ from .MeVIS_Dataset import VideoMeVISDataset
3
+ from .ReVOS_Dataset import VideoReVOSDataset
4
+ from .RefYoutubeVOS_Dataset import VideoRefYoutubeVOSDataset
5
+ from .encode_fn import video_lisa_encode_fn
6
+ from .RefCOCO_Dataset import ReferSegmDataset
7
+ from .ReSAM2_Dataset import VideoSAM2Dataset
8
+ from .vqa_dataset import LLaVADataset, InfinityMMDataset
9
+
10
+ from .GCG_Dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset
11
+ from .Grand_Dataset import GranDDataset
12
+
13
+ from .Osprey_Dataset import OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset
14
+
15
+ from .ChatUniVi_Dataset import VideoChatUniViDataset
projects/llava_sam2/datasets/collect_fns.py ADDED
@@ -0,0 +1,206 @@
1
+ from typing import Dict, Sequence
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn.utils.rnn import pad_sequence
6
+
7
+ from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
8
+ pad_for_sequence_parallel)
9
+ from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
10
+
11
+
12
+ def video_lisa_collate_fn(instances: Sequence[Dict],
13
+ pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
14
+ return_hf_format: bool = False,
15
+ use_varlen_attn: bool = False):
16
+ seq_parallel_world_size = get_sequence_parallel_world_size()
17
+
18
+ input_ids, labels = [], []
19
+ has_image = any(inst.get('pixel_values') is not None for inst in instances)
20
+ has_pe = any(inst.get('image_grid_thw', None) is not None for inst in instances)
21
+ has_fast_image = any(inst.get('fast_pixel_values', None) is not None for inst in instances)
22
+ has_grounding_image = any(inst.get('g_pixel_values') is not None for inst in instances)
23
+ has_mask = any(inst.get('masks') is not None for inst in instances)
24
+ has_bboxes = any(inst.get('bboxes') is not None for inst in instances)
25
+ has_points = any(inst.get('points') is not None for inst in instances)
26
+ has_fast_exists = any(inst.get('fast_exists') is not None for inst in instances)
27
+
28
+ has_vp = any(inst.get('vp_overall_mask') is not None for inst in instances)
29
+ has_prompt_mask = any(inst.get('prompt_masks') is not None for inst in instances)
30
+
31
+ if use_varlen_attn:
32
+ position_ids, cumulative_len = [], []
33
+ assert len(instances) == 1, (
34
+ f'If utilizing varlen attention, the batch size should be'
35
+ f' set to 1, but got {len(instances)}')
36
+ assert not has_image, 'Currently, it is not configured to '
37
+ 'accommodate the use of varlen Attention in multimodal training'
38
+
39
+ if has_image:
40
+ pixel_values = []
41
+ frames_per_batch = []
42
+ image_grid_thw = []
43
+ if has_grounding_image:
44
+ grounding_pixel_values = []
45
+ if has_mask:
46
+ object_masks = []
47
+ if has_bboxes:
48
+ object_bboxes = []
49
+ if has_points:
50
+ prompt_points = []
51
+ if has_fast_image:
52
+ fast_pixel_values = []
53
+ if has_fast_exists:
54
+ fast_exists = []
55
+ if has_vp:
56
+ vp_overall_mask = []
57
+ else:
58
+ vp_overall_mask = None
59
+
60
+ if has_prompt_mask:
61
+ prompt_masks = []
62
+ else:
63
+ prompt_masks = None
64
+
65
+ for example in instances:
66
+ input_ids.append(torch.LongTensor(example['input_ids']))
67
+ labels.append(torch.LongTensor(example['labels']))
68
+ if use_varlen_attn:
69
+ cumulative_len.append(torch.IntTensor(example['cumulative_len']))
70
+ position_ids.append(torch.LongTensor(example['position_ids']))
71
+
72
+ if has_image:
73
+ pixel_values.append(example['pixel_values'])
74
+ if has_pe:
75
+ image_grid_thw.append(example['image_grid_thw'])
76
+ if has_vp:
77
+ if 'vp_overall_mask' in example.keys() and example['vp_overall_mask'] is not None:
78
+ vp_overall_mask.append(example['vp_overall_mask'])
79
+ else:
80
+ vp_overall_mask.append(torch.Tensor([False] * len(pixel_values[-1])))
81
+ if has_fast_image:
82
+ if 'fast_pixel_values' in example.keys() and example['fast_pixel_values'] is not None:
83
+ fast_pixel_values.append(example['fast_pixel_values'])
84
+ if has_fast_exists:
85
+ if 'fast_exists' in example.keys() and example['fast_exists'] is not None:
86
+ fast_exists.append(example['fast_exists'])
87
+ if has_grounding_image and 'g_pixel_values' in example.keys():
88
+ if isinstance(example['g_pixel_values'], list):
89
+ grounding_pixel_values += example['g_pixel_values']
90
+ frames_per_batch.append(len(example['g_pixel_values']))
91
+ else:
92
+ grounding_pixel_values.append(example['g_pixel_values'])
93
+ frames_per_batch.append(1)
94
+
95
+ if has_mask:
96
+ if 'masks' in example.keys() and example['masks'] is not None:
97
+ if isinstance(example['masks'], list):
98
+ if isinstance(example['masks'][0], np.ndarray):
99
+ _masks = np.stack(example['masks'], axis=0)
100
+ _masks = torch.from_numpy(_masks)
101
+ object_masks.append(_masks)
102
+ else:
103
+ object_masks.append(torch.stack(example['masks'], dim=0))
104
+ else:
105
+ object_masks.append(example['masks'])
106
+ if has_bboxes:
107
+ if 'bboxes' in example.keys() and example['bboxes'] is not None:
108
+ object_bboxes.append(example['bboxes'])
109
+ if has_points:
110
+ if 'points' in example.keys() and example['points'] is not None:
111
+ prompt_points.append(example['points'])
112
+
113
+ if has_prompt_mask:
114
+ if 'prompt_masks' in example.keys():
115
+ prompt_masks.append(example['prompt_masks'])
116
+
117
+ ori_length = [len(ids) for ids in input_ids]
118
+ if len(instances) > 1:
119
+ input_ids = pad_sequence(
120
+ input_ids, batch_first=True, padding_value=pad_index)
121
+ labels = pad_sequence(
122
+ labels, batch_first=True, padding_value=IGNORE_INDEX)
123
+ else:
124
+ input_ids = torch.stack(input_ids)
125
+ labels = torch.stack(labels)
126
+
127
+ if use_varlen_attn:
128
+ assert input_ids.size(1) % seq_parallel_world_size == 0
129
+ attention_mask = None
130
+ position_ids = torch.stack(position_ids, dim=0)
131
+ else:
132
+ # Some tokenizers have the same eos token and pad token, so input_ids
133
+ # cannot be masked directly based on the pad token id.
134
+ attention_mask = torch.zeros_like(input_ids).bool()
135
+ for i, length in enumerate(ori_length):
136
+ attention_mask[i, :length] = True
137
+
138
+ bs, seq_len = input_ids.shape
139
+ position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
140
+
141
+ if seq_parallel_world_size > 1:
142
+ input_ids = pad_for_sequence_parallel(input_ids, pad_index)
143
+ labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
144
+ position_ids = pad_for_sequence_parallel(position_ids, 0)
145
+ if attention_mask is not None:
146
+ attention_mask = pad_for_sequence_parallel(attention_mask, 0)
147
+
148
+ if use_varlen_attn:
149
+ max_seqlen = (
150
+ cumulative_len[0][1:] - # noqa: W504
151
+ cumulative_len[0][:-1]).max().item()
152
+ data_dict = {
153
+ 'input_ids': input_ids,
154
+ 'cumulative_len': cumulative_len,
155
+ 'position_ids': position_ids,
156
+ 'labels': labels,
157
+ 'max_seqlen': max_seqlen
158
+ }
159
+ else:
160
+ data_dict = {
161
+ 'input_ids': input_ids,
162
+ 'attention_mask': attention_mask,
163
+ 'position_ids': position_ids,
164
+ 'labels': labels
165
+ }
166
+
167
+ if has_image:
168
+ if all(x.shape == pixel_values[0].shape for x in pixel_values):
169
+ pixel_values = torch.stack(pixel_values, dim=0)
170
+ data_dict['frames_per_batch'] = frames_per_batch
171
+ data_dict['pixel_values'] = pixel_values
172
+ if has_pe:
173
+ data_dict['image_grid_thw'] = image_grid_thw
174
+
175
+ if has_fast_image:
176
+ if all(x.shape == fast_pixel_values[0].shape for x in fast_pixel_values):
177
+ fast_pixel_values = torch.stack(fast_pixel_values, dim=0)
178
+ data_dict['fast_pixel_values'] = fast_pixel_values
179
+
180
+ if has_fast_exists:
181
+ data_dict['fast_exists'] = fast_exists
182
+
183
+ if has_vp:
184
+ data_dict['vp_overall_mask'] = torch.cat(vp_overall_mask, dim=0)
185
+
186
+ if has_prompt_mask:
187
+ data_dict['prompt_masks'] = prompt_masks
188
+
189
+ if has_grounding_image:
190
+ # if all(x.shape == grounding_pixel_values[0].shape for x in grounding_pixel_values):
191
+ # grounding_pixel_values = torch.stack(grounding_pixel_values, dim=0)
192
+ data_dict['g_pixel_values'] = grounding_pixel_values
193
+
194
+ if has_mask:
195
+ data_dict['masks'] = object_masks
196
+
197
+ if has_bboxes:
198
+ data_dict['bboxes'] = object_bboxes
199
+
200
+ if has_points:
201
+ data_dict['points'] = prompt_points
202
+
203
+ if return_hf_format:
204
+ return data_dict
205
+ else:
206
+ return {'data': data_dict, 'data_samples': None}
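
video_lisa_collate_fn pads input_ids/labels, rebuilds the attention mask and regroups the optional visual fields (pixel values, grounding images, masks, fast-branch tensors). A minimal wiring sketch, assuming a single-process run where xtuner's sequence-parallel world size resolves to 1; `dataset` stands for any of the datasets defined in this package:

    # Wiring sketch only; `dataset` and the batch size are placeholders.
    from functools import partial
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=2, shuffle=True,
                        collate_fn=partial(video_lisa_collate_fn, return_hf_format=False))
    batch = next(iter(loader))
    data = batch['data']  # padded 'input_ids', 'labels', 'attention_mask', 'position_ids',
                          # plus 'pixel_values', 'masks', 'g_pixel_values', ... when present
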
projects/llava_sam2/datasets/encode_fn.py ADDED
@@ -0,0 +1,144 @@
1
+ import copy
2
+ from xtuner.dataset.utils import get_bos_eos_token_ids
3
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
4
+
5
+ def video_lisa_encode_fn(
6
+ example,
7
+ tokenizer,
8
+ max_length,
9
+ input_ids_with_output=True,
10
+ **kwargs
11
+ ):
12
+ """We only support the following three scenarios:
13
+
14
+ 1. Incremental pretraining dataset.
15
+ example['conversation'] = [
16
+ {
17
+ 'input': '',
18
+ 'output': '### Human: Can you write xxx'
19
+ }
20
+ ]
21
+
22
+ 2. Single-turn conversation dataset.
23
+ example['conversation'] = [
24
+ {
25
+ 'input': 'Give three tips for staying healthy.',
26
+ 'output': '1.Eat a balanced diet xxx'
27
+ }
28
+ ]
29
+
30
+ 3. Multi-turn conversation dataset.
31
+ example['conversation'] = [
32
+ {
33
+ 'input': 'Give three tips for staying healthy.',
34
+ 'output': '1.Eat a balanced diet xxx'
35
+ },
36
+ {
37
+ 'input': 'Please expand on the second point.',
38
+ 'output': 'Here is an expanded explanation of the xxx'
39
+ }
40
+ ]
41
+ """
42
+ bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
43
+ is_multi_turn_conversation = len(example['conversation']) > 1
44
+ if is_multi_turn_conversation:
45
+ assert input_ids_with_output
46
+
47
+ input_ids, labels = [], []
48
+ next_needs_bos_token = True
49
+ for single_turn_conversation in example['conversation']:
50
+ input = single_turn_conversation['input']
51
+ input_encode = tokenizer.encode(input, add_special_tokens=False)
52
+ if next_needs_bos_token:
53
+ input_ids += bos_token_id
54
+ labels += [IGNORE_INDEX] * len(bos_token_id)
55
+ input_ids += input_encode
56
+ labels += [IGNORE_INDEX] * len(input_encode)
57
+ if input_ids_with_output:
58
+ # Add output
59
+ output_with_loss = single_turn_conversation.get(
60
+ 'output_with_loss', True)
61
+ output = single_turn_conversation['output']
62
+ output_encode = tokenizer.encode(output, add_special_tokens=False)
63
+ input_ids += output_encode
64
+ if output_with_loss:
65
+ labels += copy.deepcopy(output_encode)
66
+ else:
67
+ labels += [IGNORE_INDEX] * len(output_encode)
68
+ # Add EOS_TOKEN (with loss)
69
+ if single_turn_conversation.get('need_eos_token', True):
70
+ next_needs_bos_token = True
71
+ input_ids += eos_token_id
72
+ if output_with_loss:
73
+ labels += copy.deepcopy(eos_token_id)
74
+ else:
75
+ labels += [IGNORE_INDEX] * len(eos_token_id)
76
+ else:
77
+ next_needs_bos_token = False
78
+ # Add SEP (without loss)
79
+ sep = single_turn_conversation.get('sep', '')
80
+ if sep != '':
81
+ sep_encode = tokenizer.encode(sep, add_special_tokens=False)
82
+ input_ids += sep_encode
83
+ labels += [IGNORE_INDEX] * len(sep_encode)
84
+
85
+ if len(input_ids) > max_length:
86
+ input_ids = input_ids[:max_length]
87
+ labels = labels[:max_length]
88
+ return {'input_ids': input_ids, 'labels': labels}
89
+
90
+
91
+ def video_lisa_encode_multi_conv_fn(
92
+ example,
93
+ tokenizer,
94
+ max_length,
95
+ input_ids_with_output=True
96
+ ):
97
+ """We only support the following three scenarios:
98
+
99
+ 1. Incremental pretraining dataset.
100
+ example['conversation'] = [
101
+ {
102
+ 'input': '',
103
+ 'output': '### Human: Can you write xxx'
104
+ }
105
+ ]
106
+
107
+ 2. Single-turn conversation dataset.
108
+ example['conversation'] = [
109
+ {
110
+ 'input': 'Give three tips for staying healthy.',
111
+ 'output': '1.Eat a balanced diet xxx'
112
+ }
113
+ ]
114
+
115
+ 3. Multi-turn conversation dataset.
116
+ example['conversation'] = [
117
+ {
118
+ 'input': 'Give three tips for staying healthy.',
119
+ 'output': '1.Eat a balanced diet xxx'
120
+ },
121
+ {
122
+ 'input': 'Please expand on the second point.',
123
+ 'output': 'Here is an expanded explanation of the xxx'
124
+ }
125
+ ]
126
+ """
127
+ bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
128
+ assert not input_ids_with_output
129
+ input_id_list = []
130
+ for conv in example['conversation']:
131
+ input_ids = []
132
+ next_needs_bos_token = True
133
+ for single_turn_conversation in conv:
134
+ input = single_turn_conversation['input']
135
+ input_encode = tokenizer.encode(input, add_special_tokens=False)
136
+ if next_needs_bos_token:
137
+ input_ids += bos_token_id
138
+ input_ids += input_encode
139
+
140
+ if len(input_ids) > max_length:
141
+ input_ids = input_ids[:max_length]
142
+
143
+ input_id_list.append(input_ids)
144
+ return {'input_ids': input_id_list}
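
video_lisa_encode_fn turns a 'conversation' list into input_ids/labels with the prompt tokens masked out, so the loss only covers the answers. A small illustrative call, assuming any HuggingFace tokenizer works here (the checkpoint name below is a placeholder):

    # Illustration only; the tokenizer checkpoint is a placeholder.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
    example = {'conversation': [{'input': 'Please segment the person in this image.',
                                 'output': 'Sure, [SEG].'}]}
    out = video_lisa_encode_fn(example, tokenizer=tok, max_length=2048)
    assert len(out['input_ids']) == len(out['labels'])
    # prompt positions in out['labels'] equal IGNORE_INDEX (-100); only the answer
    # tokens and the trailing EOS token contribute to the training loss.
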
projects/llava_sam2/datasets/gcg_process.py ADDED
@@ -0,0 +1,297 @@
1
+ import numpy as np
2
+ import random
3
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN
4
+
5
+ GCG_QUESTIONS = [
6
+ DEFAULT_IMAGE_TOKEN + 'Could you please give me a brief description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
7
+ DEFAULT_IMAGE_TOKEN + 'Can you provide a brief description of this image? Please output with interleaved segmentation masks for the corresponding phrases.',
8
+ DEFAULT_IMAGE_TOKEN + 'Please briefly describe the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
9
+ DEFAULT_IMAGE_TOKEN + 'Could you give a brief explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
10
+ DEFAULT_IMAGE_TOKEN + 'Could you give me a brief explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
11
+ DEFAULT_IMAGE_TOKEN + 'Could you provide me with a brief analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
12
+ ]
13
+
14
+ def refcocog_parse_annotations(example):
15
+ # example {'id': str, 'refs': [{"sentence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str}
16
+ annotations = {'labels': [], 'caption': [], 'masks': [], 'tokens_positive': [],
17
+ 'file_name': example['img_file_name'], 'image': example['img_file_name']}
18
+
19
+ orig_caption = example['caption'].strip('"').strip()
20
+ annotations['caption'] = orig_caption.lower()
21
+
22
+ for detail in example['refs']:
23
+ phrase = detail['sentence']
24
+ if phrase.lower() in annotations['caption']:
25
+ annotations['labels'].append(phrase)
26
+ index = annotations['caption'].find(phrase)
27
+ end_index = index + len(phrase) if index != -1 else -1
28
+ annotations['tokens_positive'].append([index, end_index])
29
+ # still polygon or rle
30
+ annotations['masks'].append(detail["segmentation"])
31
+
32
+ # Sort tokens_positive and corresponding lists
33
+ tokens_positive = annotations['tokens_positive']
34
+ sorted_indices = sorted(range(len(tokens_positive)), key=lambda i: tokens_positive[i][0])
35
+ annotations['tokens_positive'] = [tokens_positive[i] for i in sorted_indices]
36
+ annotations['masks'] = [annotations['masks'][i] for i in sorted_indices]
37
+ annotations['labels'] = [annotations['labels'][i] for i in sorted_indices]
38
+
39
+ # Trimming overlapping intervals
40
+ for i in range(len(tokens_positive)):
41
+ for j in range(i + 1, len(tokens_positive)):
42
+ # If there is overlap
43
+ if tokens_positive[i][1] >= tokens_positive[j][0]:
44
+ # Modify the end index of phrase i to be one less than the start index of phrase j
45
+ tokens_positive[i][1] = tokens_positive[j][0] - 1
46
+ # Modify the phrases to reflect the change in indices
47
+ annotations['labels'][i] = orig_caption[tokens_positive[i][0]:tokens_positive[i][1] + 1]
48
+ break # Exit inner loop since i was modified
49
+
50
+ return annotations
51
+
52
+ def refcocog_conversation(caption, tokens_positive):
53
+ # insert <p> </p> tags and [SEG] tokens into the caption, and select a question
54
+ question = random.choice(GCG_QUESTIONS).strip()
55
+
56
+ # Prepare caption with tags
57
+ def tag_caption(caption, tokens):
58
+ for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
59
+ caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
60
+ return caption
61
+
62
+ detailed_answer = tag_caption(caption, tokens_positive)
63
+
64
+ conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
65
+ return conversations
66
+
67
+ def refcocog_preprocess(example):
68
+ data_labels = example['labels']
69
+ masks = example['masks']
70
+ caption = example['caption']
71
+ tokens_positive = example['tokens_positive']
72
+
73
+ # Function to sort elements based on the start index of each phrase
74
+ def sort_by_start_index(items, order):
75
+ return [items[i] for i in order]
76
+
77
+ # Sort phrases based on their appearance in the sentence
78
+ phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
79
+ masks = sort_by_start_index(masks, phrase_order)
80
+ data_labels = sort_by_start_index(data_labels, phrase_order)
81
+ tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
82
+
83
+ conversations = refcocog_conversation(caption, tokens_positive)
84
+ example['conversations'] = conversations
85
+ example['labels'] = data_labels
86
+ example['masks'] = masks
87
+ example['tokens_positive'] = tokens_positive
88
+
89
+ return example
90
+
91
+ def glamm_refcocog_map_fn(example):
92
+ # example {'id': str, 'refs': [{"sentence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str}
93
+
94
+ example = refcocog_parse_annotations(example)
95
+ # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file
96
+
97
+ example = refcocog_preprocess(example)
98
+
99
+ # do llava preprocess
100
+ messages = example['conversations']
101
+ input = ''
102
+ conversation = []
103
+ while messages and messages[0]['from'] == 'gpt':
104
+ # Skip the first one if it is from gpt
105
+ messages = messages[1:]
106
+ for msg in messages:
107
+ if msg['from'] == 'human':
108
+ if DEFAULT_IMAGE_TOKEN in msg['value']:
109
+ msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
110
+ '').strip()
111
+ msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
112
+ msg['value'] = msg['value'].strip()
113
+ input += msg['value']
114
+
115
+ elif msg['from'] == 'gpt':
116
+ conversation.append({'input': input, 'output': msg['value']})
117
+ input = ''
118
+ else:
119
+ raise NotImplementedError
120
+ example.update({'conversation': conversation})
121
+ return example
122
+
123
+ def grandf_parse_annotations(example):
124
+ image_path = example['file_name']
125
+ annotations = {
126
+ 'labels': [], 'caption': [], 'masks': [],
127
+ 'tokens_positive': [], 'file_name': image_path,
128
+ 'image': image_path}
129
+ annotations['caption'] = example['caption'].strip('"').strip()
130
+
131
+ for word, grounding in example["groundings"].items():
132
+ if grounding is None:
133
+ continue
134
+ annotations['labels'].append(word)
135
+ annotations['tokens_positive'].append(grounding["token_positives"])
136
+ annotations['masks'].append(grounding["rle_masks"])
137
+
138
+ return annotations
139
+
140
+ def grandf_conversation(caption, tokens_positive):
141
+ question = random.choice(GCG_QUESTIONS).strip()
142
+
143
+ # Prepare caption with tags
144
+ def tag_caption(caption, tokens):
145
+ for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
146
+ caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
147
+ return caption
148
+
149
+ detailed_answer = tag_caption(caption, tokens_positive)
150
+
151
+ conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
152
+ return conversations
153
+ def grandf_preprocess(example):
154
+ data_labels = example['labels']
155
+ masks = example['masks']
156
+ caption = example['caption']
157
+ tokens_positive = example['tokens_positive']
158
+
159
+ # Function to sort elements based on the start index of each phrase
160
+ def sort_by_start_index(items, order):
161
+ return [items[i] for i in order]
162
+
163
+ # Sort phrases based on their appearance in the sentence
164
+ phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
165
+ masks = sort_by_start_index(masks, phrase_order)
166
+ data_labels = sort_by_start_index(data_labels, phrase_order)
167
+ tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
168
+
169
+ conversations = grandf_conversation(caption, tokens_positive)
170
+ example['conversations'] = conversations
171
+ example['labels'] = data_labels
172
+ example['masks'] = masks
173
+ example['tokens_positive'] = tokens_positive
174
+ return example
175
+
176
+ def glamm_granf_map_fn(example):
177
+ # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str",
178
+ # "groundings": {ground_words: {'token_positives', 'rle_masks', }}}
179
+ example = grandf_parse_annotations(example)
180
+ # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file
181
+
182
+ example = grandf_preprocess(example)
183
+
184
+ # do llava preprocess
185
+ messages = example['conversations']
186
+ input = ''
187
+ conversation = []
188
+ while messages and messages[0]['from'] == 'gpt':
189
+ # Skip the first one if it is from gpt
190
+ messages = messages[1:]
191
+ for msg in messages:
192
+ if msg['from'] == 'human':
193
+ if DEFAULT_IMAGE_TOKEN in msg['value']:
194
+ msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
195
+ '').strip()
196
+ msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
197
+ msg['value'] = msg['value'].strip()
198
+ input += msg['value']
199
+
200
+ elif msg['from'] == 'gpt':
201
+ conversation.append({'input': input, 'output': msg['value']})
202
+ input = ''
203
+ else:
204
+ raise NotImplementedError
205
+ example.update({'conversation': conversation})
206
+ return example
207
+
208
+ glamm_openpsg_map_fn = glamm_granf_map_fn
209
+
210
+ def flickr_parse_annotations(example):
211
+ annotations = {'bboxes': [], 'labels': [], 'bboxes_ignore': [], 'caption': example['caption'], 'masks': [],
212
+ 'tokens_positive': [], 'image': example['file_name']}
213
+ ann_info = example["ann_info"]
214
+ for ann in ann_info:
215
+ if ann.get('ignore', False):
216
+ continue
217
+ x1, y1, w, h = ann['bbox']
218
+ inter_w = max(0, min(x1 + w, example['width']) - max(x1, 0))
219
+ inter_h = max(0, min(y1 + h, example['height']) - max(y1, 0))
220
+ if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1:
221
+ continue
222
+ bbox = [x1, y1, x1 + w, y1 + h]
223
+ annotations['bboxes'].append(bbox)
224
+ tokens_positive = ann['tokens_positive']
225
+ gt_label = [example['caption'][span[0]:span[1]] for span in tokens_positive]
226
+ annotations['labels'].append(gt_label[0])
227
+ annotations['tokens_positive'].append(tokens_positive[0])
228
+
229
+ rle = ann['sam_mask']
230
+ annotations['masks'].append(rle)
231
+
232
+ # Convert bounding boxes to numpy arrays
233
+ annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[
234
+ 'bboxes'] else np.zeros((0, 4), dtype=np.float32)
235
+ annotations['bboxes_ignore'] = np.array(annotations['bboxes_ignore'], dtype=np.float32) if annotations[
236
+ 'bboxes_ignore'] else np.zeros((0, 4), dtype=np.float32)
237
+ return annotations
238
+
239
+ def flickr_preprocess(example):
240
+ data_labels = example['labels']
241
+ masks = example['masks']
242
+ caption = example['caption']
243
+ tokens_positive = example['tokens_positive']
244
+
245
+ # Function to sort elements based on the start index of each phrase
246
+ def sort_by_start_index(items, order):
247
+ return [items[i] for i in order]
248
+
249
+ # Sort phrases based on their appearance in the sentence
250
+ phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
251
+ masks = sort_by_start_index(masks, phrase_order)
252
+ data_labels = sort_by_start_index(data_labels, phrase_order)
253
+ tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
254
+
255
+ conversations = grandf_conversation(caption, tokens_positive)
256
+ example['conversations'] = conversations
257
+ example['labels'] = data_labels
258
+ example['masks'] = masks
259
+ example['tokens_positive'] = tokens_positive
260
+ return example
261
+
262
+ def glamm_flickr_map_fn(example):
263
+ # example {'file_name': str, "height": int, "width": int, "image_id": str, 'caption': str,
264
+ # "ann_info": [{'bbox', 'area', 'tokens_positive', 'sam_mask', ...}, ...]}
265
+
266
+ example = flickr_parse_annotations(example)
267
+
268
+ example = flickr_preprocess(example)
269
+
270
+ # do llava preprocess
271
+ messages = example['conversations']
272
+ input = ''
273
+ conversation = []
274
+ while messages and messages[0]['from'] == 'gpt':
275
+ # Skip the first one if it is from gpt
276
+ messages = messages[1:]
277
+ for msg in messages:
278
+ if msg['from'] == 'human':
279
+ if DEFAULT_IMAGE_TOKEN in msg['value']:
280
+ msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
281
+ '').strip()
282
+ msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
283
+ msg['value'] = msg['value'].strip()
284
+ input += msg['value']
285
+
286
+ elif msg['from'] == 'gpt':
287
+ conversation.append({'input': input, 'output': msg['value']})
288
+ input = ''
289
+ else:
290
+ raise NotImplementedError
291
+ example.update({'conversation': conversation})
292
+ return example
293
+
294
+
295
+
296
+
297
+
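To make the tagging step used by the conversation builders above concrete, here is a small hedged example (the caption and character spans are invented) showing how the tag_caption helper wraps each grounded phrase in <p> ... </p> and appends [SEG], processing spans right-to-left so earlier indices remain valid:

def tag_caption(caption, tokens):
    # same logic as the helper above: insert tags starting from the rightmost span
    for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
        caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
    return caption

caption = "a dog chasing a red ball"
tokens_positive = [[2, 5], [16, 24]]  # made-up spans for "dog" and "red ball"
print(tag_caption(caption, tokens_positive))
# -> a <p> dog </p> [SEG] chasing a <p> red ball </p> [SEG]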
projects/llava_sam2/datasets/grand_process.py ADDED
@@ -0,0 +1,110 @@
1
+ import numpy as np
2
+ import random
3
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN
4
+
5
+ GCG_QUESTIONS = [
6
+ DEFAULT_IMAGE_TOKEN + 'Could you please give me a brief description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
7
+ DEFAULT_IMAGE_TOKEN + 'Can you provide a brief description of this image? Please output with interleaved segmentation masks for the corresponding phrases.',
8
+ DEFAULT_IMAGE_TOKEN + 'Please briefly describe the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
9
+ DEFAULT_IMAGE_TOKEN + 'Could you give a brief explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
10
+ DEFAULT_IMAGE_TOKEN + 'Could you give me a brief explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
11
+ DEFAULT_IMAGE_TOKEN + 'Could you provide me with a brief analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
12
+ ]
13
+
14
+ def grand_parse_annotations(example):
15
+ annotations = {
16
+ 'caption': [], 'masks': [],
17
+ 'tokens_positive': [], 'labels': []}
18
+ annotations['caption'] = example['dense_caption']['caption'].strip('"').strip()
19
+ object_infos = example['dense_caption']['details']
20
+
21
+ all_seg_objects_dict = {}
22
+ for seg_object_dict in example["objects"]:
23
+ all_seg_objects_dict[seg_object_dict['id']] = seg_object_dict
24
+ for seg_object_dict in example["floating_objects"]:
25
+ all_seg_objects_dict[seg_object_dict['id']] = seg_object_dict
26
+
27
+ for object_info in object_infos:
28
+ ids = object_info["ids"]
29
+ if object_info["tokens_positive"] is None:
30
+ continue
31
+ annotations['labels'].append(object_info["phrase"])
32
+ annotations['tokens_positive'].append(object_info["tokens_positive"])
33
+ _masks = []
34
+ for _id in ids:
35
+ _masks.append(all_seg_objects_dict[_id]['segmentation'])
36
+ annotations['masks'].append(_masks)
37
+ return annotations
38
+
39
+ def grand_conversation(caption, tokens_positive):
40
+ question = random.choice(GCG_QUESTIONS).strip()
41
+
42
+ # Prepare caption with tags
43
+ def tag_caption(caption, tokens):
44
+ for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
45
+ caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
46
+ return caption
47
+
48
+ detailed_answer = tag_caption(caption, tokens_positive)
49
+
50
+ conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
51
+ return conversations
52
+
53
+ def grand_preprocess(example):
54
+ data_labels = example['labels']
55
+ masks = example['masks']
56
+ caption = example['caption']
57
+ tokens_positive = example['tokens_positive']
58
+
59
+ # Function to sort elements based on the start index of each phrase
60
+ def sort_by_start_index(items, order):
61
+ return [items[i] for i in order]
62
+
63
+ # Sort phrases based on their appearance in the sentence
64
+ phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
65
+ masks = sort_by_start_index(masks, phrase_order)
66
+ data_labels = sort_by_start_index(data_labels, phrase_order)
67
+ tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
68
+
69
+ conversations = grand_conversation(caption, tokens_positive)
70
+ example['conversations'] = conversations
71
+ example['labels'] = data_labels
72
+ example['masks'] = masks
73
+ example['tokens_positive'] = tokens_positive
74
+ return example
75
+
76
+ def glamm_grand_map_fn(example):
77
+ # example {'dense_caption': {'caption': str, 'details': [{'ids', 'phrase', 'tokens_positive'}, ...]},
78
+ # "objects": [{'id', 'segmentation'}, ...], "floating_objects": [{'id', 'segmentation'}, ...]}
79
+ example = grand_parse_annotations(example)
80
+ # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file
81
+
82
+ example = grand_preprocess(example)
83
+
84
+ # do llava preprocess
85
+ messages = example['conversations']
86
+ input = ''
87
+ conversation = []
88
+ while messages and messages[0]['from'] == 'gpt':
89
+ # Skip the first one if it is from gpt
90
+ messages = messages[1:]
91
+ for msg in messages:
92
+ if msg['from'] == 'human':
93
+ if DEFAULT_IMAGE_TOKEN in msg['value']:
94
+ msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
95
+ '').strip()
96
+ msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
97
+ msg['value'] = msg['value'].strip()
98
+ input += msg['value']
99
+
100
+ elif msg['from'] == 'gpt':
101
+ conversation.append({'input': input, 'output': msg['value']})
102
+ input = ''
103
+ else:
104
+ raise NotImplementedError
105
+ example.update({'conversation': conversation})
106
+ return example
107
+
108
+
109
+
110
+
projects/llava_sam2/datasets/utils.py ADDED
@@ -0,0 +1,58 @@
1
+
2
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
3
+ image_size):
4
+ best_ratio_diff = float('inf')
5
+ best_ratio = (1, 1)
6
+ area = width * height
7
+ for ratio in target_ratios:
8
+ target_aspect_ratio = ratio[0] / ratio[1]
9
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
10
+ if ratio_diff < best_ratio_diff:
11
+ best_ratio_diff = ratio_diff
12
+ best_ratio = ratio
13
+ elif ratio_diff == best_ratio_diff:
14
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
15
+ best_ratio = ratio
16
+ return best_ratio
17
+
18
+ def dynamic_preprocess(image,
19
+ min_num=1,
20
+ max_num=6,
21
+ image_size=448,
22
+ use_thumbnail=False):
23
+ orig_width, orig_height = image.size
24
+ aspect_ratio = orig_width / orig_height
25
+
26
+ # calculate the existing image aspect ratio
27
+ target_ratios = {(i, j)
28
+ for n in range(min_num, max_num + 1)
29
+ for i in range(1, n + 1) for j in range(1, n + 1)
30
+ if i * j <= max_num and i * j >= min_num}
31
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
32
+
33
+ # find the closest aspect ratio to the target
34
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
35
+ target_ratios, orig_width,
36
+ orig_height, image_size)
37
+
38
+ # calculate the target width and height
39
+ target_width = image_size * target_aspect_ratio[0]
40
+ target_height = image_size * target_aspect_ratio[1]
41
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
42
+
43
+ # resize the image
44
+ resized_img = image.resize((target_width, target_height))
45
+ processed_images = []
46
+ for i in range(blocks):
47
+ box = ((i % (target_width // image_size)) * image_size,
48
+ (i // (target_width // image_size)) * image_size,
49
+ ((i % (target_width // image_size)) + 1) * image_size,
50
+ ((i // (target_width // image_size)) + 1) * image_size)
51
+ # split the image
52
+ split_img = resized_img.crop(box)
53
+ processed_images.append(split_img)
54
+ assert len(processed_images) == blocks
55
+ if use_thumbnail and len(processed_images) != 1:
56
+ thumbnail_img = image.resize((image_size, image_size))
57
+ processed_images.append(thumbnail_img)
58
+ return processed_images
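A short usage sketch for the tiling helper above (the image path is a placeholder): it resizes a PIL image to the closest supported aspect ratio and splits it into up to max_num crops of image_size x image_size, optionally appending a square thumbnail.

from PIL import Image
from projects.llava_sam2.datasets.utils import dynamic_preprocess

img = Image.open('example.jpg').convert('RGB')  # hypothetical input image
tiles = dynamic_preprocess(img, min_num=1, max_num=6, image_size=448, use_thumbnail=True)
# each tile is a 448x448 PIL image; a thumbnail is appended when more than one crop is produced
print(len(tiles), [t.size for t in tiles])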
projects/llava_sam2/datasets/vqa_dataset.py ADDED
@@ -0,0 +1,509 @@
1
+ import copy
2
+ import random
3
+ import glob
4
+ import json
5
+ import logging
6
+ import os
7
+ from typing import Literal
8
+
9
+ import torch
10
+
11
+ from mmengine import print_log
12
+ from mmengine.config import Config, ConfigDict
13
+ from PIL import Image
14
+ from torch.utils.data import Dataset
15
+ import numpy as np
16
+ import torch.nn.functional as F
17
+ import torchvision.transforms as T
18
+ from torchvision.transforms.functional import InterpolationMode
19
+ from pycocotools.coco import COCO
20
+ from pycocotools import mask as mask_utils
21
+
22
+ from xtuner.registry import BUILDER
23
+ from xtuner.utils import IGNORE_INDEX
24
+ from xtuner.dataset.utils import encode_fn
25
+ from xtuner.dataset.map_fns import llava_map_fn
26
+
27
+ from projects.glamm.datasets.utils.utils import expand2square
28
+
29
+ from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
30
+ from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
31
+
32
+ from .utils import dynamic_preprocess
33
+
34
+
35
+ class InfinityMMDataset(Dataset):
36
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
37
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
38
+ IMG_START_TOKEN = '<img>'
39
+ IMG_END_TOKEN = '</img>'
40
+
41
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
42
+ IMAGENET_STD = (0.229, 0.224, 0.225)
43
+
44
+ def __init__(self,
45
+ tokenizer,
46
+ data_path,
47
+ prompt_template,
48
+ special_tokens=None,
49
+ max_length=8192,
50
+ offline_save_path='./work_dirs/infinityMM.json',
51
+ ):
52
+ self.offline_save_path = offline_save_path
53
+ self.tokenizer = BUILDER.build(tokenizer)
54
+ if special_tokens is not None:
55
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
56
+ self._system = ''
57
+
58
+ self.template = prompt_template
59
+ self.max_length = max_length
60
+
61
+ self.min_dynamic_patch = 1
62
+ self.max_dynamic_patch = 12
63
+ self.downsample_ratio = 0.5
64
+ self.image_size = 448
65
+ self.use_thumbnail = True
66
+ patch_size = 14
67
+ self.patch_token = int(
68
+ (self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
69
+
70
+ self.transformer = T.Compose([
71
+ T.Lambda(lambda img: img.convert('RGB')
72
+ if img.mode != 'RGB' else img),
73
+ T.Resize((self.image_size, self.image_size),
74
+ interpolation=InterpolationMode.BICUBIC),
75
+ T.ToTensor(),
76
+ T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
77
+ ])
78
+
79
+ self.data = self._load_annotations(data_path)
80
+ self._max_refetch = 1000
81
+
82
+ def _load_annotations(self, data_path):
83
+ if os.path.exists(self.offline_save_path):
84
+ with open(self.offline_save_path, 'r') as f:
85
+ ret = json.load(f)
86
+ print(f"Load InfinityMM file list from {self.offline_save_path}, {len(ret)} items !!!")
87
+ return ret
88
+ sub_folders = []
89
+ for sub_folder in os.listdir(data_path):
90
+ if '.' not in sub_folder:
91
+ # a folder
92
+ if "LVIS_111k" in sub_folder:
93
+ # special case, have subsub folder
94
+ subsub_folders = os.listdir(os.path.join(data_path, sub_folder))
95
+ for subsub_folder in subsub_folders:
96
+ sub_folders.append(os.path.join(data_path, sub_folder, subsub_folder))
97
+ else:
98
+ sub_folders.append(os.path.join(data_path, sub_folder))
99
+
100
+ all_jsons = []
101
+ for sub_folder in sub_folders:
102
+ print(f"Processing {sub_folder} !!!")
103
+ _files = os.listdir(sub_folder)
104
+ _num = 0
105
+ for _file in _files:
106
+ if '.json' in _file:
107
+ _json_path = os.path.join(sub_folder, _file)
108
+ _num += 1
109
+ all_jsons.append(os.path.join(sub_folder, _file))
110
+ print(f"Finished {sub_folder} has {_num} items.")
111
+
112
+ with open(self.offline_save_path, 'w') as f:
113
+ json.dump(all_jsons, f)
114
+
115
+ return all_jsons
116
+
117
+ def __getitem__(self, index):
118
+ for _ in range(self._max_refetch + 1):
119
+ data = self.prepare_data(index)
120
+ # Broken images may cause the returned data to be None
121
+ if data is None:
122
+ index = self._rand_another()
123
+ continue
124
+ return data
125
+
126
+ def __len__(self):
127
+ return len(self.data)
128
+
129
+ @property
130
+ def modality_length(self):
131
+ self.group_length = []
132
+ for data_dict in self.data:
133
+ self.group_length.append(100)
134
+ return self.group_length
135
+
136
+ @property
137
+ def length(self):
138
+ group_length = np.array(self.group_length)
139
+ group_length = np.abs(group_length).tolist()
140
+ return group_length
141
+
142
+ def prepare_data(self, index):
143
+ data_path = self.data[index]
144
+
145
+ with open(data_path, 'r') as f:
146
+ data_dict = json.load(f)
147
+ if 'image' in data_dict.keys():
148
+ data_dict['image'] = data_path.replace('.json', '.jpg')
149
+
150
+ if data_dict is None:
151
+ return None
152
+
153
+ out_data_dict = {}
154
+
155
+ if data_dict.get('image', None) is not None:
156
+ image_file = data_dict['image']
157
+ try:
158
+ image = Image.open(image_file).convert('RGB')
159
+ except Exception as e:
160
+ print(f'Error: {e}', flush=True)
161
+ print_log(f'Error: {e}', logger='current')
162
+ return None
163
+
164
+ images = dynamic_preprocess(image, self.min_dynamic_patch,
165
+ self.max_dynamic_patch,
166
+ self.image_size, self.use_thumbnail)
167
+ pixel_values = [self.transformer(image) for image in images]
168
+ pixel_values = torch.stack(pixel_values)
169
+ out_data_dict['pixel_values'] = pixel_values
170
+
171
+ num_image_tokens = pixel_values.shape[0] * self.patch_token
172
+ image_token_str = f'{self.IMG_START_TOKEN}' \
173
+ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
174
+ f'{self.IMG_END_TOKEN}'
175
+ token_dict = self.get_inputid_labels(
176
+ data_dict['conversations'], image_token_str)
177
+ out_data_dict.update(token_dict)
178
+ else:
179
+ token_dict = self.get_inputid_labels(
180
+ data_dict['conversations'], None)
181
+ out_data_dict.update(token_dict)
182
+ out_data_dict['pixel_values'] = torch.zeros(
183
+ 1, 3, self.image_size, self.image_size)
184
+ return out_data_dict
185
+
186
+ def _rand_another(self) -> int:
187
+ return np.random.randint(0, len(self.data))
188
+
189
+ def get_inputid_labels(self, conversations, image_token_str) -> dict:
190
+ input = ''
191
+ out_conversation = []
192
+ while conversations and conversations[0]['from'] == 'gpt':
193
+ # Skip the first one if it is from gpt
194
+ conversations = conversations[1:]
195
+ for i, msg in enumerate(conversations):
196
+ if msg['from'] == 'human':
197
+
198
+ # change to 1 image
199
+ if '<image>' in msg['value']:
200
+ msg['value'] = msg['value'].replace('<image>\n', '').replace('<image>', '')
201
+ if i == 0:
202
+ msg['value'] = "<image>\n" + msg['value']
203
+
204
+ if image_token_str is None and '<image>' in msg['value']:
205
+ msg['value'] = msg['value'].replace('<image>', '')
206
+ if '<image>' in msg['value']:
207
+ msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
208
+ input += msg['value'].strip()
209
+ elif msg['from'] == 'gpt':
210
+ out_conversation.append({
211
+ 'input': input,
212
+ 'output': msg['value'].strip()
213
+ })
214
+ input = ''
215
+ else:
216
+ raise NotImplementedError
217
+
218
+ input_ids, labels = [], []
219
+ for i, single_turn_conversation in enumerate(out_conversation):
220
+ input = single_turn_conversation.get('input', '')
221
+ if input is None:
222
+ input = ''
223
+ input_text = self.template.INSTRUCTION.format(
224
+ input=input, round=i + 1)
225
+
226
+ if i == 0:
227
+ if self._system != '' and self._system is not None:
228
+ system = self.template.SYSTEM.format(system=self._system)
229
+ input_text = system + input_text
230
+ input_encode = self.tokenizer.encode(
231
+ input_text, add_special_tokens=True)
232
+ else:
233
+ input_encode = self.tokenizer.encode(
234
+ input_text, add_special_tokens=False)
235
+ input_ids += input_encode
236
+ labels += [IGNORE_INDEX] * len(input_encode)
237
+
238
+ output_text = single_turn_conversation.get('output', '')
239
+ if self.template.get('SUFFIX', None):
240
+ output_text += self.template.SUFFIX
241
+ output_encode = self.tokenizer.encode(
242
+ output_text, add_special_tokens=False)
243
+ input_ids += output_encode
244
+ labels += copy.deepcopy(output_encode)
245
+
246
+ if len(input_ids) > self.max_length:
247
+ input_ids = input_ids[:self.max_length]
248
+ labels = labels[:self.max_length]
249
+ print_log(
250
+ f'Warning: input_ids length({len(input_ids)}) '
251
+ f'is longer than max_length, cut to {self.max_length}',
252
+ logger='current')
253
+ return {'input_ids': input_ids, 'labels': labels}
254
+
255
+
256
+ class LLaVADataset(Dataset):
257
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
258
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
259
+ IMG_START_TOKEN = '<img>'
260
+ IMG_END_TOKEN = '</img>'
261
+
262
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
263
+ IMAGENET_STD = (0.229, 0.224, 0.225)
264
+
265
+ def __init__(self,
266
+ tokenizer,
267
+ data_path,
268
+ prompt_template,
269
+ special_tokens=None,
270
+ image_folder=None,
271
+ max_length=8192,
272
+ arch_type: Literal['intern_vl', 'qwen'] = 'intern_vl',
273
+ preprocessor=None,
274
+ skip_pure_text=False,
275
+ ):
276
+
277
+ self.tokenizer = BUILDER.build(tokenizer)
278
+ if special_tokens is not None:
279
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
280
+
281
+ self.image_folder = image_folder
282
+ self.template = prompt_template
283
+ self.max_length = max_length
284
+
285
+ self._system = ''
286
+
287
+ self.arch_type = arch_type
288
+ self.min_dynamic_patch = 1
289
+ self.max_dynamic_patch = 12
290
+ self.downsample_ratio = 0.5
291
+ if self.arch_type == 'llava':
292
+ self.downsample_ratio = 1
293
+ self.image_size = 448
294
+ if self.arch_type == 'llava':
295
+ self.image_size = 336
296
+ self.use_thumbnail = True
297
+ patch_size = 14
298
+ self.patch_token = int(
299
+ (self.image_size // patch_size)**2 * (self.downsample_ratio**2))
300
+
301
+
302
+ if self.arch_type == 'qwen':
303
+ self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
304
+ self.IMG_START_TOKEN = '<|vision_start|>'
305
+ self.IMG_END_TOKEN = '<|vision_end|>'
306
+ elif self.arch_type == 'llava':
307
+ self.IMG_CONTEXT_TOKEN = '<image>'
308
+ self.IMG_START_TOKEN = ''
309
+ self.IMG_END_TOKEN = ''
310
+
311
+ if preprocessor is None:
312
+ self.transformer = T.Compose([
313
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
314
+ T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
315
+ T.ToTensor(),
316
+ T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
317
+ ])
318
+ self.preprocessor = None
319
+ else:
320
+ self.transformer = None
321
+ self.preprocessor = BUILDER.build(preprocessor)
322
+
323
+ self.data = self._load_annotations(data_path, image_folder)
324
+ self._max_refetch = 1000
325
+
326
+ self.skip_pure_text = skip_pure_text
327
+
328
+ def _load_annotations(self, data_path, image_folder=None):
329
+ data = json.load(open(data_path))
330
+ return data
331
+
332
+ def __getitem__(self, index):
333
+ for _ in range(self._max_refetch + 1):
334
+ data = self.prepare_data(index)
335
+ # Broken images may cause the returned data to be None
336
+ if data is None:
337
+ index = self._rand_another()
338
+ continue
339
+ return data
340
+
341
+ def __len__(self):
342
+ return len(self.data)
343
+
344
+ @property
345
+ def modality_length(self):
346
+ self.group_length = []
347
+ for data_dict in self.data:
348
+ self.group_length.append(100)
349
+ return self.group_length
350
+
351
+ @property
352
+ def length(self):
353
+ group_length = np.array(self.group_length)
354
+ group_length = np.abs(group_length).tolist()
355
+ return group_length
356
+
357
+ def prepare_data(self, index):
358
+ data_dict: dict = self.data[index]
359
+
360
+ if data_dict is None:
361
+ return None
362
+
363
+ out_data_dict = {}
364
+
365
+ if self.skip_pure_text and data_dict.get('image', None) is None:
366
+ return None
367
+
368
+ if data_dict.get('image', None) is not None:
369
+ image_file = os.path.join(self.image_folder, data_dict['image'])
370
+ try:
371
+ image = Image.open(image_file).convert('RGB')
372
+ except Exception as e:
373
+ print(f'Error: {e}', flush=True)
374
+ print_log(f'Error: {e}', logger='current')
375
+ return None
376
+ if self.preprocessor is not None:
377
+ # images = dynamic_preprocess(image, self.min_dynamic_patch,
378
+ # self.max_dynamic_patch,
379
+ # self.image_size, self.use_thumbnail)
380
+ images = [image]
381
+ if self.arch_type == 'qwen':
382
+ _data_dict = self.preprocessor(images, do_resize=True)
383
+ _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
384
+ _data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
385
+ num_image_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
386
+ elif self.arch_type == 'llava':
387
+ _data_dict = self.preprocessor(images, do_resize=True, size=(self.image_size, self.image_size))
388
+ _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
389
+ _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
390
+ num_image_tokens = _data_dict['pixel_values'].shape[0] * self.patch_token
391
+ else:
392
+ raise NotImplementedError
393
+ out_data_dict.update(_data_dict)
394
+ else:
395
+ images = dynamic_preprocess(image, self.min_dynamic_patch,
396
+ self.max_dynamic_patch,
397
+ self.image_size, self.use_thumbnail)
398
+ pixel_values = [self.transformer(image) for image in images]
399
+ pixel_values = torch.stack(pixel_values)
400
+ out_data_dict['pixel_values'] = pixel_values
401
+
402
+ num_image_tokens = pixel_values.shape[0] * self.patch_token
403
+ image_token_str = f'{self.IMG_START_TOKEN}' \
404
+ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
405
+ f'{self.IMG_END_TOKEN}'
406
+ token_dict = self.get_inputid_labels(
407
+ data_dict['conversations'], image_token_str)
408
+ out_data_dict.update(token_dict)
409
+ else:
410
+ token_dict = self.get_inputid_labels(
411
+ data_dict['conversations'], None)
412
+ out_data_dict.update(token_dict)
413
+ out_data_dict['pixel_values'] = torch.zeros(
414
+ 1, 3, self.image_size, self.image_size)
415
+ return out_data_dict
416
+
417
+ def _rand_another(self) -> int:
418
+ return np.random.randint(0, len(self.data))
419
+
420
+ def get_inputid_labels(self, conversations, image_token_str) -> dict:
421
+ input = ''
422
+ out_conversation = []
423
+ while conversations and conversations[0]['from'] == 'gpt':
424
+ # Skip the first one if it is from gpt
425
+ conversations = conversations[1:]
426
+ for msg in conversations:
427
+ if msg['from'] == 'human':
428
+ if image_token_str is None and '<image>' in msg['value']:
429
+ msg['value'] = msg['value'].replace('<image>', '')
430
+ if '<image>' in msg['value']:
431
+ msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
432
+ input += msg['value'].strip()
433
+ elif msg['from'] == 'gpt':
434
+ out_conversation.append({
435
+ 'input': input,
436
+ 'output': msg['value'].strip()
437
+ })
438
+ input = ''
439
+ else:
440
+ raise NotImplementedError
441
+
442
+ input_ids, labels = [], []
443
+ for i, single_turn_conversation in enumerate(out_conversation):
444
+ input = single_turn_conversation.get('input', '')
445
+ if input is None:
446
+ input = ''
447
+ input_text = self.template.INSTRUCTION.format(
448
+ input=input, round=i + 1)
449
+
450
+ if i == 0:
451
+ if self._system != '' and self._system is not None:
452
+ system = self.template.SYSTEM.format(system=self._system)
453
+ input_text = system + input_text
454
+ input_encode = self.tokenizer.encode(
455
+ input_text, add_special_tokens=True)
456
+ else:
457
+ input_encode = self.tokenizer.encode(
458
+ input_text, add_special_tokens=False)
459
+ input_ids += input_encode
460
+ labels += [IGNORE_INDEX] * len(input_encode)
461
+
462
+ output_text = single_turn_conversation.get('output', '')
463
+ if self.template.get('SUFFIX', None):
464
+ output_text += self.template.SUFFIX
465
+ output_encode = self.tokenizer.encode(
466
+ output_text, add_special_tokens=False)
467
+ input_ids += output_encode
468
+ labels += copy.deepcopy(output_encode)
469
+
470
+ if len(input_ids) > self.max_length:
471
+ input_ids = input_ids[:self.max_length]
472
+ labels = labels[:self.max_length]
473
+ print_log(
474
+ f'Warning: input_ids length({len(input_ids)}) '
475
+ f'is longer than max_length, cut to {self.max_length}',
476
+ logger='current')
477
+ return {'input_ids': input_ids, 'labels': labels}
478
+
479
+
480
+ if __name__ == '__main__':
481
+ from transformers import CLIPImageProcessor, AutoTokenizer
482
+ from third_parts.segment_anything.utils.transforms import ResizeLongestSide
483
+ pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
484
+ llm_name_or_path = 'lmsys/vicuna-7b-v1.5'
485
+
486
+ tokenizer = dict(
487
+ type=AutoTokenizer.from_pretrained,
488
+ pretrained_model_name_or_path=llm_name_or_path)
489
+ image_processor = dict(
490
+ type=CLIPImageProcessor.from_pretrained,
491
+ pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
492
+ extra_image_processor = dict(
493
+ type=ResizeLongestSide,
494
+ target_length=1024,
495
+ )
496
+ from xtuner.utils.templates import PROMPT_TEMPLATE
497
+ prompt_template = PROMPT_TEMPLATE.vicuna
498
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
499
+ from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
500
+
501
+ dataset = LLaVADataset(
502
+ tokenizer=tokenizer,
503
+ data_path='data/llava_data/LLaVA-Instruct-150K/llava_instruct_150k.json',
504
+ prompt_template=prompt_template,
505
+ special_tokens=['[SEG]'],
506
+ image_folder='data/coco/train2017/',
507
+ )
508
+ for i in range(1000):
509
+ dataset[i]
projects/llava_sam2/deepspeed_zero2_sam2.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "gradient_accumulation_steps": "auto",
3
+ "train_micro_batch_size_per_gpu": "auto",
4
+ "gradient_clipping": "auto",
5
+ "zero_allow_untested_optimizer": true,
6
+ "zero_force_ds_cpu_optimizer": false,
7
+ "zero_optimization": {
8
+ "stage": 2,
9
+ "overlap_comm": true,
10
+ "allgather_bucket_size": 5368709120,
11
+ "reduce_bucket_size": 5368709120,
12
+ "reduce_scatter": true,
13
+ "sub_group_size": 1e9,
14
+ "contiguous_gradients": true,
15
+ "allgather_partitions": true
16
+ },
17
+ "fp16": {
18
+ "enabled": false,
19
+ "initial_scale_power": 16
20
+ },
21
+ "bf16": {
22
+ "enabled": true
23
+ }
24
+ }
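As a hedged usage note (not prescribed by this repository): the ZeRO-2 JSON above is normally consumed by the training framework, which resolves the "auto" placeholders; if you hand it to DeepSpeed yourself, those fields must first be replaced with concrete values, for example:

import json
import torch
import deepspeed  # assumes DeepSpeed is installed and the script runs under its launcher

with open('projects/llava_sam2/deepspeed_zero2_sam2.json') as f:
    ds_config = json.load(f)
ds_config.update({
    'train_micro_batch_size_per_gpu': 2,   # illustrative values replacing the "auto" entries
    'gradient_accumulation_steps': 4,
    'gradient_clipping': 1.0,
})

model = torch.nn.Linear(16, 16)            # toy model for the sketch
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, optimizer=optimizer, config=ds_config)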
projects/llava_sam2/gradio/app.py ADDED
@@ -0,0 +1,151 @@
1
+ import gradio as gr
2
+ import sys
3
+ from projects.llava_sam2.gradio.app_utils import\
4
+ process_markdown, show_mask_pred, description, preprocess_video,\
5
+ show_mask_pred_video, image2video_and_save
6
+
7
+ import torch
8
+ from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel, GenerationConfig)
11
+ import argparse
12
+ import os
13
+
14
+ TORCH_DTYPE_MAP = dict(
15
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
16
+
17
+ def parse_args(args):
18
+ parser = argparse.ArgumentParser(description="Sa2VA Demo")
19
+ parser.add_argument('hf_path', help='Sa2VA hf path.')
20
+ return parser.parse_args(args)
21
+
22
+ def inference(image, video, follow_up, input_str):
23
+ input_image = image
24
+ if image is not None and (video is not None and os.path.exists(video)):
25
+ return image, video, "Error: Please only input a image or a video !!!"
26
+ if image is None and (video is None or not os.path.exists(video)) and not follow_up:
27
+ return image, video, "Error: Please input a image or a video !!!"
28
+
29
+ if not follow_up:
30
+ # reset
31
+ print('Log: History responses have been removed!')
32
+ global_infos.n_turn = 0
33
+ global_infos.inputs = ''
34
+ text = input_str
35
+
36
+ image = input_image
37
+ global_infos.image_for_show = image
38
+ global_infos.image = image
39
+ video = video
40
+ global_infos.video = video
41
+
42
+ if image is not None:
43
+ global_infos.input_type = "image"
44
+ else:
45
+ global_infos.input_type = "video"
46
+
47
+ else:
48
+ text = input_str
49
+ image = global_infos.image
50
+ video = global_infos.video
51
+
52
+ input_type = global_infos.input_type
53
+ if input_type == "video":
54
+ video = preprocess_video(video, global_infos.inputs+input_str)
55
+
56
+ past_text = global_infos.inputs
57
+
58
+ if past_text == "" and "<image>" not in text:
59
+ text = "<image>" + text
60
+ if input_type == "image":
61
+ input_dict = {
62
+ 'image': image,
63
+ 'text': text,
64
+ 'past_text': past_text,
65
+ 'mask_prompts': None,
66
+ 'tokenizer': tokenizer,
67
+ }
68
+ else:
69
+ input_dict = {
70
+ 'video': video,
71
+ 'text': text,
72
+ 'past_text': past_text,
73
+ 'mask_prompts': None,
74
+ 'tokenizer': tokenizer,
75
+ }
76
+
77
+ return_dict = sa2va_model.predict_forward(**input_dict)
78
+ global_infos.inputs = return_dict["past_text"]
79
+ print(return_dict['past_text'])
80
+ if 'prediction_masks' in return_dict.keys() and return_dict['prediction_masks'] and len(
81
+ return_dict['prediction_masks']) != 0:
82
+ if input_type == "image":
83
+ image_mask_show, selected_colors = show_mask_pred(global_infos.image_for_show, return_dict['prediction_masks'],)
84
+ video_mask_show = global_infos.video
85
+ else:
86
+ image_mask_show = None
87
+ video_mask_show, selected_colors = show_mask_pred_video(video, return_dict['prediction_masks'],)
88
+ video_mask_show = image2video_and_save(video_mask_show, save_path="./ret_video.mp4")
89
+ else:
90
+ image_mask_show = global_infos.image_for_show
91
+ video_mask_show = global_infos.video
92
+ selected_colors = []
93
+
94
+ predict = return_dict['prediction'].strip()
95
+ global_infos.n_turn += 1
96
+
97
+ predict = process_markdown(predict, selected_colors)
98
+ return image_mask_show, video_mask_show, predict
99
+
100
+ def init_models(args):
101
+ model_path = args.hf_path
102
+ model = AutoModel.from_pretrained(
103
+ model_path,
104
+ torch_dtype=torch.bfloat16,
105
+ low_cpu_mem_usage=True,
106
+ use_flash_attn=True,
107
+ trust_remote_code=True,
108
+ ).eval().cuda()
109
+
110
+ tokenizer = AutoTokenizer.from_pretrained(
111
+ model_path,
112
+ trust_remote_code=True,
113
+ )
114
+ return model, tokenizer
115
+
116
+ class global_infos:
117
+ inputs = ''
118
+ n_turn = 0
119
+ image_width = 0
120
+ image_height = 0
121
+
122
+ image_for_show = None
123
+ image = None
124
+ video = None
125
+
126
+ input_type = "image" # "image" or "video"
127
+
128
+ if __name__ == "__main__":
129
+ # get parse args and set models
130
+ args = parse_args(sys.argv[1:])
131
+
132
+ sa2va_model, tokenizer = \
133
+ init_models(args)
134
+
135
+ demo = gr.Interface(
136
+ inference,
137
+ inputs=[
138
+ gr.Image(type="pil", label="Upload Image", height=360),
139
+ gr.Video(sources=["upload", "webcam"], label="Upload mp4 video", height=360),
140
+ gr.Checkbox(label="Follow up Question"),
141
+ gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),],
142
+ outputs=[
143
+ gr.Image(type="pil", label="Output Image"),
144
+ gr.Video(label="Output Video", show_download_button=True, format='mp4'),
145
+ gr.Markdown()],
146
+ theme=gr.themes.Soft(), allow_flagging="auto", description=description,
147
+ title='Sa2VA'
148
+ )
149
+
150
+ demo.queue()
151
+ demo.launch(share=True)
projects/llava_sam2/gradio/app_utils.py ADDED
@@ -0,0 +1,293 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import cv2
4
+
5
+ markdown_default = """
6
+ <link href="https://fonts.googleapis.com/css2?family=Montserrat:wght@400;700&display=swap" rel="stylesheet">
7
+ <style>
8
+ .highlighted-text {
9
+ font-family: 'Montserrat', sans-serif;
10
+ font-weight: 600;
11
+ font-size: 14px;
12
+ color: rgb(255, 255, 239);
13
+ background-color: rgb(225, 231, 254);
14
+ border-radius: 7px;
15
+ padding: 5px 7px;
16
+ display: inline-block;
17
+ }
18
+ .regular-text {
19
+ font-family: 'Montserrat', sans-serif;
20
+ font-weight: 400;
21
+ font-size: 14px;
22
+ }
23
+ .highlighted-response {
24
+ font-family: 'Montserrat', sans-serif;
25
+ font-weight: 600;
26
+ font-size: 14px;
27
+ border-radius: 6px;
28
+ padding: 3px 4px;
29
+ display: inline-block;
30
+ }
31
+ </style>
32
+ <span class="highlighted-text" style='color:rgb(107, 100, 239)'>Sa2VA</span>
33
+ """
34
+
35
+ description = """
36
+ **Usage**: <br>
37
+ &ensp;(1) For **Grounded Caption Generation** with interleaved segmentation, input a prompt like: *"Could you provide me with a detailed analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer."* <br>
38
+ &ensp;(2) For **Segmentation Output**, input a prompt like: *"Can you please segment xxx in the given image"* <br>
39
+ &ensp;(3) For **Image Captioning** / VQA, input a prompt like: *"Could you please give me a detailed description of the image?"* <br>
40
+ &ensp;(4) For **Image Conversation**, input an arbitrary text instruction. <br>
41
+ """
42
+
43
+ ONE_THIRD = 1.0/3.0
44
+ ONE_SIXTH = 1.0/6.0
45
+ TWO_THIRD = 2.0/3.0
46
+
47
+ def desaturate(rgb, factor=0.65):
48
+ """
49
+ Desaturate an RGB color by a given factor.
50
+
51
+ :param rgb: A tuple of (r, g, b) where each value is in [0, 255].
52
+ :param factor: The target lightness in [0, 1]; hue and saturation are preserved
53
+ while the lightness is set to this value (lower values give darker, more muted colors).
54
+ :return: A tuple of desaturated (r, g, b) values in [0, 255].
55
+ """
56
+ r, g, b = [x / 255.0 for x in rgb]
57
+ h, l, s = rgb_to_hls(r, g, b)
58
+ l = factor
59
+ new_r, new_g, new_b = hls_to_rgb(h, l, s)
60
+ return (int(new_r * 255), int(new_g * 255), int(new_b * 255))
61
+
62
+ def rgb_to_hls(r, g, b):
63
+ maxc = max(r, g, b)
64
+ minc = min(r, g, b)
65
+ sumc = (maxc+minc)
66
+ rangec = (maxc-minc)
67
+ l = sumc/2.0
68
+ if minc == maxc:
69
+ return 0.0, l, 0.0
70
+ if l <= 0.5:
71
+ s = rangec / sumc
72
+ else:
73
+ s = rangec / (2.0-sumc)
74
+ rc = (maxc-r) / rangec
75
+ gc = (maxc-g) / rangec
76
+ bc = (maxc-b) / rangec
77
+ if r == maxc:
78
+ h = bc-gc
79
+ elif g == maxc:
80
+ h = 2.0+rc-bc
81
+ else:
82
+ h = 4.0+gc-rc
83
+ h = (h/6.0) % 1.0
84
+ return h, l, s
85
+
86
+ def hls_to_rgb(h, l, s):
87
+ if s == 0.0:
88
+ return l, l, l
89
+ if l <= 0.5:
90
+ m2 = l * (1.0+s)
91
+ else:
92
+ m2 = l+s-(l*s)
93
+ m1 = 2.0*l - m2
94
+ return (_v(m1, m2, h+ONE_THIRD), _v(m1, m2, h), _v(m1, m2, h-ONE_THIRD))
95
+
96
+ def _v(m1, m2, hue):
97
+ hue = hue % 1.0
98
+ if hue < ONE_SIXTH:
99
+ return m1 + (m2-m1)*hue*6.0
100
+ if hue < 0.5:
101
+ return m2
102
+ if hue < TWO_THIRD:
103
+ return m1 + (m2-m1)*(TWO_THIRD-hue)*6.0
104
+ return m1
105
+
106
+ def process_markdown(output_str, colors):
107
+ output_str = output_str.replace("\n", "").replace(" ", " ").replace("<s>", "")\
108
+ .replace("<|im_end|>", '').replace("<|end|>", "")
109
+ output_str = output_str.split("ASSISTANT: ")[-1]
110
+
111
+ # markdown_out = output_str.replace('[SEG]', '')
112
+ markdown_out = output_str
113
+ markdown_out = markdown_out.replace(
114
+ "<p>", "<span class='highlighted-response' style='background-color:rgb[COLOR]'>"
115
+ )
116
+ markdown_out = markdown_out.replace("</p>", "</span>")
117
+
118
+ for color in colors:
119
+ markdown_out = markdown_out.replace("[COLOR]", str(desaturate(tuple(color))), 1)
120
+
121
+ markdown_out = f"""
122
+ {markdown_out}
123
+ """
124
+ markdown_out = markdown_default + "<p><span class='regular-text'>" + markdown_out
125
+ return markdown_out
126
+
127
+ def show_mask_pred(image, masks):
128
+ masks = [mask[:1] for mask in masks]
129
+ masks = np.concatenate(masks, axis=0) # (n, h, w)
130
+
131
+ selected_colors = []
132
+
133
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
134
+ (255, 255, 0), (255, 0, 255), (0, 255, 255),
135
+ (128, 128, 255), [255, 192, 203], # Pink
136
+ [165, 42, 42], # Brown
137
+ [255, 165, 0], # Orange
138
+ [128, 0, 128], # Purple
139
+ [0, 0, 128], # Navy
140
+ [128, 0, 0], # Maroon
141
+ [128, 128, 0], # Olive
142
+ [70, 130, 180], # Steel Blue
143
+ [173, 216, 230], # Light Blue
144
+ [255, 192, 0], # Gold
145
+ [255, 165, 165], # Light Salmon
146
+ [255, 20, 147], # Deep Pink
147
+ ]
148
+
149
+ _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8)
150
+
151
+ for i, mask in enumerate(masks):
152
+ color = colors[i % len(colors)]
153
+ selected_colors.append(color)
154
+ _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0]
155
+ _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1]
156
+ _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2]
157
+
158
+
159
+ image = np.array(image)
160
+ image = image * 0.5 + _mask_image * 0.5
161
+ image = image.astype(np.uint8)
162
+ return image, selected_colors
163
+
164
+ def show_mask_pred_video(video, masks):
165
+ ret_video = []
166
+ selected_colors = []
167
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
168
+ (255, 255, 0), (255, 0, 255), (0, 255, 255),
169
+ (128, 128, 255), [255, 192, 203], # Pink
170
+ [165, 42, 42], # Brown
171
+ [255, 165, 0], # Orange
172
+ [128, 0, 128], # Purple
173
+ [0, 0, 128], # Navy
174
+ [128, 0, 0], # Maroon
175
+ [128, 128, 0], # Olive
176
+ [70, 130, 180], # Steel Blue
177
+ [173, 216, 230], # Light Blue
178
+ [255, 192, 0], # Gold
179
+ [255, 165, 165], # Light Salmon
180
+ [255, 20, 147], # Deep Pink
181
+ ]
182
+ for i_frame in range(len(video)):
183
+ frame_masks = [mask[i_frame:i_frame+1] for mask in masks]
184
+ frame_masks = np.concatenate(frame_masks, axis=0)
185
+ _mask_image = np.zeros((frame_masks.shape[1], frame_masks.shape[2], 3), dtype=np.uint8)
186
+
187
+ for i, mask in enumerate(frame_masks):
188
+ if i_frame == 0:
189
+ color = colors[i % len(colors)]
190
+ selected_colors.append(color)
191
+ else:
192
+ color = selected_colors[i]
193
+ _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0]
194
+ _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1]
195
+ _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2]
196
+
197
+ image = np.array(video[i_frame])
198
+ image = image * 0.5 + _mask_image * 0.5
199
+ image = image.astype(np.uint8)
200
+ ret_video.append(image)
201
+ return ret_video, selected_colors
202
+
203
+ def parse_visual_prompts(points):
204
+ ret = {'points': [], 'boxes': []}
205
+ for item in points:
206
+ if item[2] == 1.0:
207
+ ret['points'].append([item[0], item[1]])
208
+ elif item[2] == 2.0 or item[2] == 3.0:
209
+ ret['boxes'].append([item[0], item[1], item[3], item[4]])
210
+ else:
211
+ raise NotImplementedError
212
+ return ret
213
+
214
+ def get_video_frames(video_path):
215
+ cap = cv2.VideoCapture(video_path)
216
+
217
+ if not cap.isOpened():
218
+ print("Error: Cannot open video file.")
219
+ return
220
+
221
+ frames = []
222
+
223
+ frame_id = 0
224
+ while True:
225
+ ret, frame = cap.read()
226
+
227
+ if not ret:
228
+ break
229
+
230
+ frames.append(frame)
231
+
232
+ frame_id += 1
233
+
234
+ cap.release()
235
+ return frames
236
+
237
+ def get_frames_from_video(video_path, n_frames=5, sample_type="uniform"):
238
+ frames = get_video_frames(video_path)
239
+ if sample_type == "uniform":
240
+ stride = len(frames) / (n_frames + 1e-4)
241
+ ret = []
242
+ for i in range(n_frames):
243
+ idx = int(i * stride)
244
+ frame = frames[idx]
245
+ frame = frame[:, :, ::-1]
246
+ frame_image = Image.fromarray(frame).convert('RGB')
247
+ ret.append(frame_image)
248
+ else:
249
+ ret = []
250
+ for frame in frames[:500]:
251
+ frame = frame[:, :, ::-1]
252
+ frame_image = Image.fromarray(frame).convert('RGB')
253
+ ret.append(frame_image)
254
+ return ret
255
+
256
+ def preprocess_video(video_path, text):
257
+ if "Segment" in text or "segment" in text:
258
+ sample_type = 'begin'
259
+ else:
260
+ sample_type = 'uniform'
261
+ return get_frames_from_video(video_path, sample_type=sample_type)
262
+
263
+ def image2video_and_save(frames, save_path):
264
+ success = frames_to_video(frames, save_path)
265
+ return save_path
266
+
267
+
268
+ def frames_to_video(
269
+ frames,
270
+ output_path: str,
271
+ fps: int = 24,
272
+ ) -> bool:
273
+ try:
274
+ frames = [frame[:, :, ::-1] for frame in frames]
275
+ # Use provided frame size or get from first frame
276
+ height, width = frames[0].shape[:2]
277
+
278
+ # Initialize video writer
279
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
280
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
281
+
282
+ # Process each frame
283
+ for frame in frames:
284
+ out.write(frame)
285
+
286
+ # Release video writer
287
+ out.release()
288
+ print(f"Video saved successfully to {output_path}")
289
+ return True
290
+
291
+ except Exception as e:
292
+ print(f"Error converting frames to video: {str(e)}")
293
+ return False
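A small hedged check of the mask-overlay helper above, using placeholder data instead of real model outputs (each per-object mask is assumed to have shape (num_frames, H, W); only the first frame is used for images):

import numpy as np
from PIL import Image
from projects.llava_sam2.gradio.app_utils import show_mask_pred

h, w = 64, 64
masks = [np.zeros((1, h, w), dtype=bool) for _ in range(2)]  # two dummy objects
masks[0][0, :32, :32] = True
masks[1][0, 32:, 32:] = True

image = Image.fromarray(np.full((h, w, 3), 200, dtype=np.uint8))
overlay, colors = show_mask_pred(image, masks)
print(overlay.shape, colors)  # (64, 64, 3) blended overlay and the colors assigned per object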
projects/llava_sam2/models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .llava_sam2 import VideoLLaVASAMModel, VideoLLaVASAMModel_zero3
2
+ from .sam2 import SAM2
3
+ from .sam2_train import SAM2TrainRunner
projects/llava_sam2/models/extension/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sam2_base import SAM2Base
projects/llava_sam2/models/extension/sam2_base.py ADDED
@@ -0,0 +1,281 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from third_parts.sam2.modeling.sam2_base import SAM2Base as _SAM2Base
5
+ from third_parts.sam2.modeling.sam2_base import NO_OBJ_SCORE
6
+
7
+
8
+ class SAM2Base(_SAM2Base):
9
+
10
+ def track_step(
11
+ self,
12
+ frame_idx,
13
+ is_init_cond_frame,
14
+ current_vision_feats,
15
+ current_vision_pos_embeds,
16
+ feat_sizes,
17
+ point_inputs,
18
+ mask_inputs,
19
+ output_dict,
20
+ num_frames,
21
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
22
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
23
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
24
+ # in demo we might call `track_step` multiple times for each user click,
25
+ # and only encode the memory when the user finalizes their clicks. And in ablation
26
+ # settings like SAM training on static images, we don't need the memory encoder.
27
+ run_mem_encoder=True,
28
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
29
+ prev_sam_mask_logits=None,
30
+ ## Extension: LLM prompt
31
+ language_embd=None,
32
+ ):
33
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
34
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
35
+ if len(current_vision_feats) > 1:
36
+ high_res_features = [
37
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
38
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
39
+ ]
40
+ else:
41
+ high_res_features = None
42
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
43
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
44
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
45
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
46
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
47
+ sam_outputs = self._use_mask_as_output(
48
+ pix_feat, high_res_features, mask_inputs
49
+ )
50
+ else:
51
+ # fused the visual feature with previous memory features in the memory bank
52
+ pix_feat_with_mem = self._prepare_memory_conditioned_features(
53
+ frame_idx=frame_idx,
54
+ is_init_cond_frame=is_init_cond_frame,
55
+ current_vision_feats=current_vision_feats[-1:],
56
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
57
+ feat_sizes=feat_sizes[-1:],
58
+ output_dict=output_dict,
59
+ num_frames=num_frames,
60
+ track_in_reverse=track_in_reverse,
61
+ )
62
+ # apply SAM-style segmentation head
63
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
64
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
65
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
66
+ if prev_sam_mask_logits is not None:
67
+ assert point_inputs is not None and mask_inputs is None
68
+ mask_inputs = prev_sam_mask_logits
69
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
70
+ sam_outputs = self._forward_sam_heads(
71
+ backbone_features=pix_feat_with_mem,
72
+ point_inputs=point_inputs,
73
+ mask_inputs=mask_inputs,
74
+ high_res_features=high_res_features,
75
+ multimask_output=multimask_output,
76
+ # Inject language Embed if possible
77
+ language_embd=language_embd,
78
+ )
79
+ (
80
+ _,
81
+ _,
82
+ _,
83
+ low_res_masks,
84
+ high_res_masks,
85
+ obj_ptr,
86
+ _,
87
+ ) = sam_outputs
88
+
89
+ current_out["pred_masks"] = low_res_masks
90
+ current_out["pred_masks_high_res"] = high_res_masks
91
+ current_out["obj_ptr"] = obj_ptr
92
+
93
+ # Finally run the memory encoder on the predicted mask to encode
94
+ # it into a new memory feature (that can be used in future frames)
95
+ if run_mem_encoder and self.num_maskmem > 0:
96
+ high_res_masks_for_mem_enc = high_res_masks
97
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
98
+ current_vision_feats=current_vision_feats,
99
+ feat_sizes=feat_sizes,
100
+ pred_masks_high_res=high_res_masks_for_mem_enc,
101
+ is_mask_from_pts=(point_inputs is not None),
102
+ )
103
+ current_out["maskmem_features"] = maskmem_features
104
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
105
+ else:
106
+ current_out["maskmem_features"] = None
107
+ current_out["maskmem_pos_enc"] = None
108
+
109
+ return current_out
110
+
111
+
112
+ def _forward_sam_heads(
113
+ self,
114
+ backbone_features,
115
+ point_inputs=None,
116
+ mask_inputs=None,
117
+ high_res_features=None,
118
+ multimask_output=False,
119
+ ## Extension: LLM prompt
120
+ language_embd=None,
121
+ ):
122
+ """
123
+ Forward SAM prompt encoders and mask heads.
124
+
125
+ Inputs:
126
+ - backbone_features: image features of [B, C, H, W] shape
127
+ - point_inputs: a dictionary with "point_coords" and "point_labels", where
128
+ 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
129
+ absolute pixel-unit coordinate in (x, y) format of the P input points
130
+ 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
131
+ positive clicks, 0 means negative clicks, and -1 means padding
132
+ - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
133
+ same spatial size as the image.
134
+ - high_res_features: either 1) None or 2) or a list of length 2 containing
135
+ two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
136
+ which will be used as high-resolution feature maps for SAM decoder.
137
+ - multimask_output: if it's True, we output 3 candidate masks and their 3
138
+ corresponding IoU estimates, and if it's False, we output only 1 mask and
139
+ its corresponding IoU estimate.
140
+
141
+ Outputs:
142
+ - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
143
+ `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
144
+ output mask logits (before sigmoid) for the low-resolution masks, with 4x
145
+ the resolution (1/4 stride) of the input backbone_features.
146
+ - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
147
+ if `multimask_output=True` and M = 1 if `multimask_output=False`),
148
+ upsampled from the low-resolution masks, with shape size as the image
149
+ (stride is 1 pixel).
150
+ - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1
151
+ if `multimask_output=False`), the estimated IoU of each output mask.
152
+ - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
153
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
154
+ If `multimask_output=False`, it's the same as `low_res_multimasks`.
155
+ - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
156
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
157
+ If `multimask_output=False`, it's the same as `high_res_multimasks`.
158
+ - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
159
+ based on the output token from the SAM mask decoder.
160
+ """
161
+ B = backbone_features.size(0)
162
+ device = backbone_features.device
163
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
164
+ assert backbone_features.size(2) == self.sam_image_embedding_size
165
+ assert backbone_features.size(3) == self.sam_image_embedding_size
166
+
167
+ # a) Handle point prompts
168
+ if point_inputs is not None:
169
+ sam_point_coords = point_inputs["point_coords"]
170
+ sam_point_labels = point_inputs["point_labels"]
171
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
172
+ else:
173
+ # If no points are provide, pad with an empty point (with label -1)
174
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
175
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
176
+
177
+ # b) Handle mask prompts
178
+ if mask_inputs is not None:
179
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
180
+ # and feed it as a dense mask prompt into the SAM mask encoder
181
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
182
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
183
+ sam_mask_prompt = F.interpolate(
184
+ mask_inputs.float(),
185
+ size=self.sam_prompt_encoder.mask_input_size,
186
+ align_corners=False,
187
+ mode="bilinear",
188
+ antialias=True, # use antialias for downsampling
189
+ )
190
+ else:
191
+ sam_mask_prompt = mask_inputs
192
+ else:
193
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
194
+ # a learned `no_mask_embed` to indicate no mask input in this case).
195
+ sam_mask_prompt = None
196
+
197
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
198
+ points=(sam_point_coords, sam_point_labels),
199
+ boxes=None,
200
+ masks=sam_mask_prompt,
201
+ )
202
+
203
+ ## Extension: LLM prompt
204
+ if language_embd is not None:
205
+ # B N C
206
+ assert sparse_embeddings.size(0) == language_embd.size(0)
207
+ assert sparse_embeddings.size(2) == language_embd.size(2)
208
+ sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)
209
+
210
+ (
211
+ low_res_multimasks,
212
+ ious,
213
+ sam_output_tokens,
214
+ object_score_logits,
215
+ ) = self.sam_mask_decoder(
216
+ image_embeddings=backbone_features,
217
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
218
+ sparse_prompt_embeddings=sparse_embeddings,
219
+ dense_prompt_embeddings=dense_embeddings,
220
+ multimask_output=multimask_output,
221
+ repeat_image=False, # the image is already batched
222
+ high_res_features=high_res_features,
223
+ )
224
+ if self.pred_obj_scores:
225
+ is_obj_appearing = object_score_logits > 0
226
+
227
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
228
+ # consistent with the actual mask prediction
229
+ # print('Do torch.where !!!')
230
+ # low_res_multimasks = torch.where(
231
+ # is_obj_appearing[:, None, None],
232
+ # low_res_multimasks,
233
+ # NO_OBJ_SCORE,
234
+ # )
235
+
236
+ # convert masks from possibly bfloat16 (or float16) to float32
237
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
238
+ low_res_multimasks = low_res_multimasks.float()
239
+ high_res_multimasks = F.interpolate(
240
+ low_res_multimasks,
241
+ size=(self.image_size, self.image_size),
242
+ mode="bilinear",
243
+ align_corners=False,
244
+ )
245
+
246
+ sam_output_token = sam_output_tokens[:, 0]
247
+ if multimask_output:
248
+ # take the best mask prediction (with the highest IoU estimation)
249
+ best_iou_inds = torch.argmax(ious, dim=-1)
250
+ batch_inds = torch.arange(B, device=device)
251
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
252
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
253
+ if sam_output_tokens.size(1) > 1:
254
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
255
+ else:
256
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
257
+
258
+ # Extract object pointer from the SAM output token (with occlusion handling)
259
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
260
+ if self.pred_obj_scores:
261
+ # Allow *soft* no obj ptr, unlike for masks
262
+ if self.soft_no_obj_ptr:
263
+ # Only hard possible with gt
264
+ assert not self.teacher_force_obj_scores_for_mem
265
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
266
+ else:
267
+ lambda_is_obj_appearing = is_obj_appearing.float()
268
+
269
+ if self.fixed_no_obj_ptr:
270
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
271
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
272
+
273
+ return (
274
+ low_res_multimasks,
275
+ high_res_multimasks,
276
+ ious,
277
+ low_res_masks,
278
+ high_res_masks,
279
+ obj_ptr,
280
+ object_score_logits,
281
+ )
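This subclass keeps the stock SAM2 tracking flow but threads an extra `language_embd` argument from `track_step` into `_forward_sam_heads`, where it is concatenated with the sparse prompt embeddings before the mask decoder runs. Below is a minimal, self-contained sketch of the two extension points, with a stand-in decoder output and assumed dimensions; the `text_hidden_to_prompt` projection, `llm_dim`, and the tensor sizes are illustrative and not taken from this file:

```python
import torch
import torch.nn as nn

B, P, C = 2, 1, 256                   # batch, point-prompt tokens, SAM prompt dim
num_seg_tokens, llm_dim = 1, 4096     # e.g. one [SEG]-style token from the LLM

sparse_embeddings = torch.randn(B, P, C)              # stand-in for sam_prompt_encoder output
llm_hidden = torch.randn(B, num_seg_tokens, llm_dim)  # stand-in for LLM hidden states

# (1) Project the LLM hidden states into the SAM prompt space and append them,
#     mirroring `sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)`.
text_hidden_to_prompt = nn.Linear(llm_dim, C)         # assumed projection; not defined in this file
language_embd = text_hidden_to_prompt(llm_hidden)     # [B, N, C]
sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)

# (2) Select the best of M candidate masks by predicted IoU, as `_forward_sam_heads`
#     does when `multimask_output=True`.
M, H, W = 3, 64, 64
low_res_multimasks = torch.randn(B, M, H, W)          # stand-in for decoder mask logits
ious = torch.rand(B, M)                               # stand-in for decoder IoU estimates
best_iou_inds = torch.argmax(ious, dim=-1)
batch_inds = torch.arange(B)
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)  # [B, 1, H, W]

print(sparse_embeddings.shape)   # torch.Size([2, 2, 256])
print(low_res_masks.shape)       # torch.Size([2, 1, 64, 64])
```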