praeclarumjj3 committed on
Commit
d3cee44
·
1 Parent(s): 0bd6903

:zap: Build space

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
38
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: VCoder
3
- emoji: 🏃
4
- colorFrom: indigo
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.8.0
8
  app_file: app.py
 
1
  ---
2
  title: VCoder
3
+ emoji: ✌️
4
+ colorFrom: yellow
5
+ colorTo: orange
6
  sdk: gradio
7
  sdk_version: 4.8.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+
7
+ import gradio as gr
8
+ import requests
9
+ import hashlib
10
+
11
+ from vcoder_llava.vcoder_conversation import (default_conversation, conv_templates,
12
+ SeparatorStyle)
13
+ from vcoder_llava.constants import LOGDIR
14
+ from vcoder_llava.utils import (build_logger, server_error_msg,
15
+ violates_moderation, moderation_msg)
16
+ from .chat import Chat
17
+
18
+
19
logger = build_logger("gradio_app", "gradio_web_server.log")

# NOTE(review): `headers` and `priority` are not referenced by any code visible
# in this file; they are kept so external users of the module keep working.
headers = {"User-Agent": "VCoder Client"}

# Shared button updates returned by the event handlers below.
# `gr.update(...)` replaces `gr.Button.update(...)`: the per-component
# `.update()` helpers were removed in Gradio 4, the SDK version this Space is
# configured with (sdk_version 4.8.0), while `gr.update` exists in both 3.x and 4.x.
no_change_btn = gr.update()
enable_btn = gr.update(interactive=True)
disable_btn = gr.update(interactive=False)

priority = {
    "vicuna-13b": "aaaaaaa",
    "koala-13b": "aaaaaab",
}
31
+
32
+
33
def get_conv_log_filename():
    """Return the path of today's conversation log inside ``LOGDIR``.

    The file name has the form ``YYYY-MM-DD-conv.json`` and is built from the
    local date at call time.
    """
    now = datetime.datetime.now()
    return os.path.join(LOGDIR, f"{now.year}-{now.month:02d}-{now.day:02d}-conv.json")
37
+
38
+
39
# JavaScript snippet that reads the page's query string and returns it as a
# plain object. `url_params` is declared with `const` so the snippet no longer
# creates an implicit global variable in the browser.
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""
47
+
48
+
49
def load_demo_refresh_model_list(request: gr.Request):
    """Set up a newly loaded page.

    Logs the client's host, then returns a fresh copy of the default
    conversation together with an update for the model dropdown.

    Args:
        request: The Gradio request for the page load; only
            ``request.client.host`` is read (for logging).

    Returns:
        A ``(state, dropdown_update)`` pair. The dropdown lists the
        module-level ``models`` and preselects the first one, or ``""`` when
        the list is empty.
    """
    logger.info(f"load_demo. ip: {request.client.host}")
    state = default_conversation.copy()
    # `gr.update` instead of `gr.Dropdown.update`: the per-component `.update()`
    # helpers were removed in Gradio 4; `gr.update` works on 3.x and 4.x alike.
    dropdown_update = gr.update(
        choices=models,
        value=models[0] if models else "",
    )
    return state, dropdown_update
57
+
58
+
59
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one vote record to today's conversation log.

    The record is a single JSON line holding the current timestamp, the vote
    type, the selected model and the serialised conversation state.
    ``request`` is accepted but not read here.
    """
    with open(get_conv_log_filename(), "a") as log_file:
        record = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            model=model_selector,
            state=state.dict(),
        )
        log_file.write(f"{json.dumps(record)}\n")
68
+
69
+
70
def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an "upvote" for the current conversation.

    Returns an empty string (for the textbox) followed by three disabled-button
    updates, matching the outputs this handler is registered with.
    """
    vote_last_response(state, "upvote", model_selector, request)
    return "", disable_btn, disable_btn, disable_btn
73
+
74
+
75
def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a "downvote" for the current conversation.

    Returns an empty string (for the textbox) followed by three disabled-button
    updates, matching the outputs this handler is registered with.
    """
    vote_last_response(state, "downvote", model_selector, request)
    return "", disable_btn, disable_btn, disable_btn
78
+
79
+
80
def flag_last_response(state, model_selector, request: gr.Request):
    """Log a "flag" for the current conversation.

    Returns an empty string (for the textbox) followed by three disabled-button
    updates, matching the outputs this handler is registered with.
    """
    vote_last_response(state, "flag", model_selector, request)
    return "", disable_btn, disable_btn, disable_btn
83
+
84
def regenerate(state, image_process_mode, seg_process_mode):
    """Prepare the conversation so the last reply can be generated again.

    Clears the last message's content and, when the preceding human message
    carries a tuple/list payload, rebuilds that payload with the currently
    selected image and seg preprocessing modes.

    Args:
        state: Conversation object; ``state.messages`` is mutated in place.
        image_process_mode: Preprocessing mode to store for the image.
        seg_process_mode: Preprocessing mode to store for the seg map.

    Returns:
        ``(state, chatbot_rows, "", None, None)`` followed by five
        disabled-button updates.
    """
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    payload = prev_human_msg[1]
    # isinstance (rather than an exact type() match) also accepts tuple/list subclasses.
    if isinstance(payload, (tuple, list)):
        # Payload layout, as built by add_text:
        # (text, image, image_mode, seg, seg_mode, None, None)
        prev_human_msg[1] = (*payload[:2], image_process_mode, payload[3], seg_process_mode, None, None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
91
+
92
+
93
def clear_history(request: gr.Request):
    """Reset the session to a fresh copy of the default conversation.

    Returns the new state, its chatbot rows, an empty textbox value, ``None``
    for both image inputs, and five disabled-button updates.
    """
    fresh_state = default_conversation.copy()
    cleared_outputs = (fresh_state, fresh_state.to_gradio_chatbot(), "", None, None)
    return cleared_outputs + (disable_btn,) * 5
96
+
97
+
98
def add_text(state, text, image, image_process_mode, seg, seg_process_mode,
             depth=None, depth_process_mode=None, request: gr.Request = None):
    """Append the user's message (and optional image / seg map) to the conversation.

    Args:
        state: Current conversation object.
        text: Text typed by the user.
        image: Uploaded image, or ``None``.
        image_process_mode: Preprocessing mode stored alongside the image.
        seg: Uploaded segmentation map, or ``None``.
        seg_process_mode: Preprocessing mode stored alongside the seg map.
        depth: Unused. Defaults to ``None`` because the UI listeners that call
            this handler supply no depth inputs; without a default the call
            could not bind its arguments.
        depth_process_mode: Unused; see ``depth``.
        request: Gradio request object (not read here).

    Returns:
        ``(state, chatbot_rows, textbox_value, None, None)`` followed by five
        button updates.
    """
    logger.info(f"add_text. len: {len(text)}")

    # Nothing to send: mark the turn so the generation step is skipped.
    if not text and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None, None) + (no_change_btn,) * 5

    if args.moderate and violates_moderation(text):
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), moderation_msg, None, None) + (
            no_change_btn,) * 5

    text = text[:1576]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        if '<image>' not in text:
            text = '<image>\n' + text
        if seg is not None and '<seg>' not in text:
            text = '<seg>\n' + text

        # The last two slots are always None here (no depth input is wired up).
        text = (text, image, image_process_mode, seg, seg_process_mode, None, None)
        # A conversation that already holds an image is restarted from scratch.
        if state.get_images(return_pil=True):
            state = default_conversation.copy()

    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
126
+
127
+
128
def _save_logged_images(pil_images, digests, subdir):
    """Save each image as ``LOGDIR/<subdir>/<YYYY-MM-DD>/<digest>.jpg`` unless that file exists."""
    for pil_image, digest in zip(pil_images, digests):
        now = datetime.datetime.now()
        filename = os.path.join(LOGDIR, subdir, f"{now.year}-{now.month:02d}-{now.day:02d}", f"{digest}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            pil_image.save(filename)


def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    """Stream the model's reply into the last message of the conversation.

    This is a generator: every yielded value is
    ``(state, chatbot_rows, *five_button_updates)``.

    Args:
        state: Conversation object whose last message is the pending reply.
        model_selector: Name of the selected model.
        temperature: Sampling temperature (converted to ``float``).
        top_p: Nucleus sampling value (converted to ``float``).
        max_new_tokens: Requested generation length, capped at 1536.
        request: Gradio request; ``request.client.host`` is written to the log.

    Side effects:
        Saves the conversation's images and seg maps under ``LOGDIR`` and
        appends one "chat" record to today's conversation log.
    """
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    # First round of conversation: switch to the LLaVA template when the model
    # name asks for it. Any other model keeps the conversation it already has
    # (previously this path hit an unbound `template_name`).
    if len(state.messages) == state.offset + 2 and "llava" in model_name.lower():
        new_state = conv_templates["llava_v1"].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Construct prompt
    prompt = state.get_prompt()

    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    _save_logged_images(all_images, all_image_hash, "serve_images")

    all_segs = state.get_segs(return_pil=True)
    all_seg_hash = [hashlib.md5(seg.tobytes()).hexdigest() for seg in all_segs]
    _save_logged_images(all_segs, all_seg_hash, "serve_segs")

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
        "segs": f'List of {len(state.get_segs())} segs: {all_seg_hash}',
    }
    # Log the payload with image summaries only; the real image data is
    # attached after logging.
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()
    pload['segs'] = state.get_segs()

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    # Button updates used when generation fails: only the last two buttons are re-enabled.
    failure_btns = (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
    output = ""  # stays defined even if the stream produces no chunks
    try:
        # Stream output
        for chunk in chat.generate_stream_gate(pload):
            if not chunk:
                continue
            data = json.loads(chunk.decode())
            if data["error_code"] != 0:
                output = data["text"] + f" (error_code: {data['error_code']})"
                state.messages[-1][-1] = output
                yield (state, state.to_gradio_chatbot()) + failure_btns
                return
            output = data["text"][len(prompt):].strip()
            state.messages[-1][-1] = output + "▌"
            yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
            time.sleep(0.03)
    except Exception as exc:
        # `Exception` (not a bare except) so that GeneratorExit raised at the
        # yield above can propagate when the generator is closed early.
        logger.info(f"generation failed: {exc!r}")
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + failure_btns
        return

    # Drop the trailing cursor character.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),  # was logging the start time
            "state": state.dict(),
            "images": all_image_hash,
            "segs": all_seg_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
227
+
228
# Markdown shown at the top of the page when the demo is not in embed mode.
# NOTE(review): the heading and links still point at LLaVA although the app is
# labelled VCoder elsewhere — confirm whether this text should be updated.
title_markdown = ("""
# 🌋 LLaVA: Large Language and Vision Assistant
[[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)
""")

# Terms-of-use text shown below the chat when the demo is not in embed mode.
tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")


# License notice shown below the terms of use.
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")

# Extra CSS passed to gr.Blocks; targets buttons inside the element with id "buttons".
block_css = """

#buttons button {
    min-width: min(120px,100%);
}

"""
254
+
255
def build_demo(embed_mode):
    """Build the Gradio Blocks UI and wire up its event handlers.

    Args:
        embed_mode: When truthy, the title, terms-of-use and license markdown
            sections are left out.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """

    # Created before the Blocks context and placed later via textbox.render().
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
        # Per-session state; filled by load_demo_refresh_model_list on page load.
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            # Left column: model choice, image inputs, examples, sampling parameters.
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                # with gr.Row():
                imagebox = gr.Image(type="pil", label="Image Input")
                # Hidden control: always reports its default value "Default".
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                segbox = gr.Image(type="pil", label="Seg Map")
                # Hidden control: always reports its default value "Default".
                seg_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square Seg Map", visible=False)

                # Example files are resolved relative to this file's directory.
                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/3.jpg", f"{cur_dir}/examples/3_pan.png", "What objects can be seen in the image?"],
                    [f"{cur_dir}/examples/3.jpg", f"{cur_dir}/examples/3_ins.png", "What objects can be seen in the image?"],
                ], inputs=[imagebox, segbox, textbox])

                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            # Right column: chat display, text input and action buttons.
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(elem_id="chatbot", label="VCoder Chatbot", height=550)
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)

        # Register listeners
        # The handlers return five button updates in exactly this order.
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(upvote_last_response,
            [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
        downvote_btn.click(downvote_last_response,
            [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
        flag_btn.click(flag_last_response,
            [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
        # Regenerate: reset the last reply, then stream a new one.
        regenerate_btn.click(regenerate, [state, image_process_mode, seg_process_mode],
            [state, chatbot, textbox, imagebox, segbox] + btn_list).then(
            http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list)
        clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, segbox] + btn_list)

        # Pressing ENTER and clicking "Send" run the same two-step chain:
        # add_text records the message, then http_bot streams the reply.
        textbox.submit(add_text, [state, textbox, imagebox, image_process_mode, segbox, seg_process_mode], [state, chatbot, textbox, imagebox, segbox] + btn_list
            ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
                   [state, chatbot] + btn_list)
        submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode, segbox, seg_process_mode], [state, chatbot, textbox, imagebox, segbox] + btn_list
            ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
                   [state, chatbot] + btn_list)

        # Initialise state and the model dropdown on every page load.
        demo.load(load_demo_refresh_model_list, None, [state, model_selector])

    return demo
341
+
342
+
343
if __name__ == "__main__":
    # Script entry point: parse CLI options, load the model, build and launch the UI.
    # NOTE(review): this file imports `Chat` with a relative import
    # (`from .chat import Chat`); that form fails when the file is executed
    # directly as a script — confirm how the app is started.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.model_name is None:
        # Derive a display name from the path. Trailing slashes are stripped
        # first; otherwise "dir/model/" would yield an empty name.
        model_paths = args.model_path.rstrip("/").split("/")
        if model_paths[-1].startswith('checkpoint-'):
            model_name = model_paths[-2] + "_" + model_paths[-1]
        else:
            model_name = model_paths[-1]
    else:
        model_name = args.model_name

    # Module-level names read by the UI builders and event handlers.
    models = [model_name]
    chat = Chat(
        args.model_path,
        args.model_base,
        args.model_name,
        args.load_8bit,
        args.load_4bit,
        args.device,
        logger
    )

    logger.info(args)
    demo = build_demo(args.embed)
    # Gradio 4 (the SDK version this Space is configured with) rejects
    # `queue(concurrency_count=...)`; `default_concurrency_limit` is its
    # replacement for setting the per-event concurrency.
    demo.queue(
        default_concurrency_limit=args.concurrency_count,
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )
chat.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import json
6
+ import torch
7
+
8
+ from vcoder_llava.utils import server_error_msg
9
+ from vcoder_llava.model.builder import load_pretrained_model
10
+ from vcoder_llava.mm_utils import process_images, load_image_from_base64, tokenizer_seg_token, tokenizer_depth_seg_token, tokenizer_image_token, KeywordsStoppingCriteria
11
+ from vcoder_llava.constants import (
12
+ IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
13
+ SEG_TOKEN_INDEX, DEFAULT_SEG_TOKEN,
14
+ DEPTH_TOKEN_INDEX, DEFAULT_DEPTH_TOKEN
15
+ )
16
+ from transformers import TextIteratorStreamer
17
+
18
class Chat:
    """Hold one loaded model and turn request dictionaries into generated text.

    The model, tokenizer and image processors are loaded once in ``__init__``
    through ``load_pretrained_model``.  ``generate_stream`` yields UTF-8 encoded
    JSON chunks of the form ``{"text": ..., "error_code": ...}``;
    ``generate_stream_gate`` wraps it and converts exceptions into an error
    chunk.
    """

    def __init__(self, model_path, model_base, model_name,
                 load_8bit, load_4bit, device, logger):
        """Load the model.

        Args:
            model_path: Checkpoint path or hub id; one trailing "/" is dropped.
            model_base: Base model forwarded to ``load_pretrained_model``.
            model_name: Explicit name, or None to derive it from ``model_path``
                (last component, prefixed with the parent directory for
                ``checkpoint-*`` folders).
            load_8bit: Forwarded to ``load_pretrained_model``.
            load_4bit: Forwarded to ``load_pretrained_model``.
            device: Device used for the tokenized prompt and forwarded to the
                loader.
            logger: Object with an ``info`` method used for the load message.
        """
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        self.device = device
        logger.info(f"Loading the model {self.model_name} ...")
        (self.tokenizer, self.model, self.image_processor,
         self.seg_image_processor, self.depth_image_processor,
         self.context_len) = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
        # Capabilities are inferred purely from substrings of the model name.
        lowered_name = self.model_name.lower()
        self.is_multimodal = 'llava' in lowered_name
        self.is_seg = "seg_llava" in lowered_name
        # Depth input is switched off unconditionally in this class.
        self.is_depth = False

    def _load_visual_inputs(self, encoded_items, processor):
        """Decode base64 images, preprocess them and move them to the model device as fp16."""
        decoded = [load_image_from_base64(item) for item in encoded_items]
        processed = process_images(decoded, processor, self.model.config)
        # process_images returns either one stacked tensor or a list of tensors.
        if isinstance(processed, list):
            return [tensor.to(self.model.device, dtype=torch.float16) for tensor in processed]
        return processed.to(self.model.device, dtype=torch.float16)

    @staticmethod
    def _error_chunk():
        """Return the encoded JSON chunk that reports a server-side failure."""
        return json.dumps({"text": server_error_msg, "error_code": 1}).encode()

    @torch.inference_mode()
    def generate_stream(self, params):
        """Generate a reply for ``params`` and yield it as encoded JSON chunks.

        Args:
            params: Dict with a required ``"prompt"`` and optional ``"images"``,
                ``"segs"``, ``"depths"`` (lists of base64 images),
                ``"temperature"``, ``"top_p"``, ``"max_new_tokens"`` and
                ``"stop"``.

        Yields:
            ``bytes``: JSON objects whose ``"text"`` is the original prompt
            followed by the text generated so far.

        Raises:
            ValueError: If the number of images/segs/depths differs from the
                number of matching placeholder tokens in the prompt.
        """
        tokenizer, model = self.tokenizer, self.model

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        segs = params.get("segs", None)
        depths = params.get("depths", None)
        num_image_tokens = 0
        num_seg_tokens = 0
        num_depth_tokens = 0
        image_args = {}

        # Seg maps are only considered when images are present, and depth maps
        # only when seg maps are present; anything else is passed as None.
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                raise ValueError("Number of images does not match number of <image> tokens in prompt")
            images = self._load_visual_inputs(images, self.image_processor)
            num_patches = model.get_vision_tower().num_patches
            num_image_tokens = prompt.count(DEFAULT_IMAGE_TOKEN) * num_patches

            seg_inputs = None
            depth_inputs = None
            if segs is not None and len(segs) > 0 and self.is_seg:
                if len(segs) != prompt.count(DEFAULT_SEG_TOKEN):
                    raise ValueError("Number of segs does not match number of <seg> tokens in prompt")
                seg_inputs = self._load_visual_inputs(segs, self.seg_image_processor)
                num_seg_tokens = prompt.count(DEFAULT_SEG_TOKEN) * num_patches

                if depths is not None and len(depths) > 0 and self.is_depth:
                    if len(depths) != prompt.count(DEFAULT_DEPTH_TOKEN):
                        raise ValueError("Number of depths does not match number of <depth> tokens in prompt")
                    depth_inputs = self._load_visual_inputs(depths, self.depth_image_processor)
                    num_depth_tokens = prompt.count(DEFAULT_DEPTH_TOKEN) * num_patches

            image_args = {"images": images, "segs": seg_inputs, "depths": depth_inputs}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        # Near-zero temperatures fall back to greedy decoding.
        do_sample = temperature > 0.001

        if self.is_seg:
            if self.is_depth:
                input_ids = tokenizer_depth_seg_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, SEG_TOKEN_INDEX, DEPTH_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
            else:
                input_ids = tokenizer_seg_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, SEG_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        else:
            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)

        # A missing or empty stop string means "no keyword stop"; building the
        # criteria from None would fail and "" would match immediately.
        stopping_criteria = []
        if stop_str:
            stopping_criteria.append(KeywordsStoppingCriteria([stop_str], tokenizer, input_ids))
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        # Each visual placeholder is budgeted as `num_patches` positions.
        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens - num_seg_tokens - num_depth_tokens)

        if max_new_tokens < 1:
            # Emitted without a trailing NUL so it has the same framing as every
            # other chunk produced by this class.
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode()
            return

        # generate() is called synchronously, so the loop below only starts
        # once generation has finished and the streamer already holds the text.
        model.generate(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            stopping_criteria=stopping_criteria,
            use_cache=True,
            **image_args
        )

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if stop_str and generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode()

    def generate_stream_gate(self, params):
        """Yield the chunks of ``generate_stream``; on any exception yield one error chunk instead of raising."""
        try:
            yield from self.generate_stream(params)
        except ValueError as e:
            print("Caught ValueError:", e)
            yield self._error_chunk()
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            yield self._error_chunk()
        except Exception as e:
            print("Caught Unknown Error", e)
            yield self._error_chunk()
185
+
186
+
187
if __name__ == "__main__":
    # Stand-alone invocation only parses the worker-style options below;
    # no further action is taken with the parsed values in this block.
    parser = argparse.ArgumentParser()
    option_specs = [
        ("--host", {"type": str, "default": "localhost"}),
        ("--port", {"type": int, "default": 21002}),
        ("--worker-address", {"type": str, "default": "http://localhost:21002"}),
        ("--controller-address", {"type": str, "default": "http://localhost:21001"}),
        ("--model-path", {"type": str, "default": "facebook/opt-350m"}),
        ("--model-base", {"type": str, "default": None}),
        ("--model-name", {"type": str}),
        ("--device", {"type": str, "default": "cuda"}),
        ("--multi-modal", {"action": "store_true", "help": "Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path."}),
        ("--limit-model-concurrency", {"type": int, "default": 5}),
        ("--stream-interval", {"type": int, "default": 1}),
        ("--no-register", {"action": "store_true"}),
        ("--load-8bit", {"action": "store_true"}),
        ("--load-4bit", {"action": "store_true"}),
    ]
    # Registered in table order so the generated --help listing is unchanged.
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
examples/3.jpg ADDED

Git LFS Details

  • SHA256: 721367369f53ecefddeeb16383eceab43835e143fd1d9aeed05d2f3ad9356410
  • Pointer size: 131 Bytes
  • Size of remote file: 268 kB
examples/3_ins.png ADDED

Git LFS Details

  • SHA256: 817b7679286d4079fd0a165d8c7689b5c7a89c0217a6da27be728576cc7a04d8
  • Pointer size: 129 Bytes
  • Size of remote file: 8.92 kB
examples/3_pan.png ADDED

Git LFS Details

  • SHA256: f2e392734be1a44aee7609459ab50ea01d6a7866d5170b46c15f7c00512d1701
  • Pointer size: 130 Bytes
  • Size of remote file: 15.3 kB
requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu117
2
+ torch==2.0.0+cu117
3
+ packaging
4
+ Pillow
5
+ huggingface_hub
6
+ matplotlib
7
+ flash-attn
8
+ gradio
9
+ fastapi
10
+ numpy
11
+ requests
12
+ sentencepiece
13
+ tokenizers>=0.12.1
14
+ uvicorn
15
+ chardet
16
+ shortuuid
17
+ httpx==0.24.0
18
+ spacy
19
+ inflect
20
+ peft==0.4.0
21
+ num2words
22
+ transformers==4.31.0
23
+ accelerate==0.21.0
24
+ bitsandbytes==0.41.0
25
+ scikit-learn==1.2.2
26
+ sentencepiece==0.1.99
27
+ einops==0.6.1
28
+ einops-exts==0.0.4
29
+ timm==0.6.13
30
+ gradio_client==0.2.9
vcoder_llava/.DS_Store ADDED
Binary file (6.15 kB). View file
 
vcoder_llava/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import LlavaLlamaForCausalLM, VCoderLlavaLlamaForCausalLM, VCoderDSLlavaLlamaForCausalLM
vcoder_llava/constants.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE(review): presumably the directory log files are written to; no user of
# LOGDIR is visible here - confirm.
LOGDIR = "."

# Model Constants
# NOTE(review): -100 is presumably the label value ignored by the training
# loss - confirm against the training code.
IGNORE_INDEX = -100
# Negative sentinel id that the tokenizer_*_token helpers in mm_utils.py splice
# into the token id list where the prompt contains the image placeholder.
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"

# Sentinel id / placeholder string for segmentation-map inputs.
SEG_TOKEN_INDEX = -300
DEFAULT_SEG_TOKEN = "<seg>"

# Sentinel id / placeholder string for depth-map inputs.
DEPTH_TOKEN_INDEX = -400
DEFAULT_DEPTH_TOKEN = "<depth>"
vcoder_llava/data_utils.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ import spacy
3
+ from word2number import w2n
4
+ import inflect
5
+ from num2words import num2words
6
+ p = inflect.engine()
7
+ import numpy as np
8
+ import random
9
+
10
+ nltk.download('punkt')
11
+ nltk.download('averaged_perceptron_tagger')
12
+ nlp = spacy.load('en_core_web_sm')
13
+
14
# Two-word object names.  _get_nouns matches these by substring and removes
# them from the text before tagging the remaining words.
SPECIAL_WORDS = [
    'baseball bat',
    'baseball glove',
    'cell phone',
    'dining table',
    'fire hydrant',
    'french fries',
    'hair drier',
    'hot dog',
    'parking meter',
    'potted plant',
    'soccer ball',
    'soccer player',
    'sports ball',
    'stop sign',
    'teddy bear',
    'tennis racket',
    'toy figure',
    'traffic light',
    'wine glass',
]
34
+
35
def _get_nouns(lines):
    """Count the nouns mentioned in ``lines``.

    Args:
        lines: Text to analyse (the caller passes it lower-cased).

    Returns:
        dict: Noun -> number of occurrences.  Nouns seen once are singularised
        and nouns seen several times are pluralised (via ``inflect``), except
        "bus" and "skis", which are left untouched.  Number words and tokens
        shorter than three characters are dropped.  Every entry of
        ``SPECIAL_WORDS`` found in the text is reported with a count of 1,
        however often it occurs.
    """
    # Two-word names are cut out first so the tagger never sees their halves.
    present_words = [s for s in SPECIAL_WORDS if s in lines]
    for w in present_words:
        lines = lines.replace(w, "")

    tokenized = nltk.word_tokenize(lines)
    # Every tag that starts with "NN" counts as a noun.
    nouns = [word for (word, pos) in nltk.pos_tag(tokenized) if pos[:2] == 'NN']
    # Remove one occurrence each of "objects" and "image".
    # NOTE(review): presumably these come from the fixed sentence template
    # ("The objects present in the image are ...") - confirm.
    for filler in ("objects", "image"):
        if filler in nouns:
            nouns.remove(filler)

    noun_dict = {}
    for n in nouns:
        noun_dict[n] = noun_dict.get(n, 0) + 1

    nouns = {}
    for k, v in noun_dict.items():
        if k not in ("bus", "skis"):
            # singular_noun() returns a falsy value when k is already singular.
            singular = p.singular_noun(k)
            if v == 1:
                if singular:
                    k = singular
            elif not singular:
                k = p.plural(k)
        try:
            w2n.word_to_num(k)
        except Exception:
            # word_to_num raising means k is not a number word, so it is kept.
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
            # propagate.
            if len(k) >= 3:
                # Repair two known bad singularisations.
                if k == "ski":
                    k = "skis"
                elif k == "gras":
                    k = "grass"
                nouns[k] = v
    for w in present_words:
        nouns[w] = 1
    return nouns
81
+
82
def _get_num_nouns(lines):
    """Extract "<number word> <noun>" phrases from ``lines``.

    Args:
        lines: Text to analyse; ":" and "." are removed before parsing.

    Returns:
        dict: Noun phrase (everything after the first word of the phrase) ->
        integer value of the first word.  Phrases whose first word cannot be
        converted to a number are skipped.
    """
    lines = lines.replace(":", "").replace(".", "")
    doc = nlp(lines)
    # Keep only the noun chunks that contain a numeral token.
    num_nouns = [chunk.text for chunk in doc.noun_chunks if any(token.pos_ == 'NUM' for token in chunk)]

    num_noun_dict = {}
    for chunk_text in num_nouns:
        # One chunk may hold several comma-separated "<number> <noun>" items.
        for phrase in chunk_text.split(", "):
            first_word, *rest = phrase.split(' ')
            w = " ".join(rest)
            if w == "ski":
                w = "skis"
            try:
                count = w2n.word_to_num(first_word)
            except Exception:
                # Best effort: the first word is not a number, skip the phrase.
                continue
            num_noun_dict[w] = count

    return num_noun_dict
100
+
101
+
102
def _obtain_nouns(gt):
    """Merge explicit counts and plain noun mentions found in ``gt``.

    Args:
        gt: Sentence listing the objects of an image.

    Returns:
        dict: Object name -> count.  Counts taken from number phrases come
        first; a plain noun is dropped when its plural already has such a
        count.
    """
    # Lower-case first so capitalised spellings ("Hair Dryer") are normalised
    # to the "hair drier" spelling used in SPECIAL_WORDS as well.
    gt = gt.lower().replace("hair dryer", "hair drier")
    nouns_gt = _get_nouns(gt)
    num_nouns_gt = _get_num_nouns(gt)

    # Plain nouns whose plural form was already counted explicitly.
    com_keys = [k for k in nouns_gt if p.plural(k) in num_nouns_gt]
    for k in com_keys:
        del nouns_gt[k]

    return {**num_nouns_gt, **nouns_gt}
118
+
119
def generate_qa_pairs(text):
    """Build random counting question/answer pairs from an object list sentence.

    Args:
        text: Sentence listing the objects in an image.

    Returns:
        list[tuple[str, str]]: Between one and six randomly sampled
        ``(question, answer)`` pairs (fewer when fewer were generated).  Each
        object contributes a "How many ...?" pair and usually a yes/no pair
        about a specific count.
    """
    num_nouns = _obtain_nouns(text)
    qa_pairs = []

    for obj, count in num_nouns.items():
        # Count question
        plural_obj = p.plural(obj) if count == 1 else obj
        count_question = f"How many {plural_obj} are there in the image?"
        count_answer = f"There {'is' if count == 1 else 'are'} {num2words(count)} {obj} in the image."
        qa_pairs.append((count_question, count_answer))

        prob_positive = np.random.uniform(0,1.)

        if prob_positive > 0.7 or count == 1:
            numeric_presence_question = f"{'Is' if count == 1 else 'Are'} there {num2words(count)} {obj} in the image?"
            numeric_presence_answer = "Yes."
        elif count > 1:
            # Ask about a wrong count drawn from 2 .. count + 5.
            numbers = [i for i in range(2, count + 6) if i != count]
            cnt = random.choice(numbers)
            numeric_presence_question = f"{'Is' if cnt == 1 else 'Are'} there {num2words(cnt)} {obj} in the image?"
            numeric_presence_answer = "No."
        else:
            # count < 1 without the positive draw: there is no presence
            # question to add.  Skipping avoids an unbound variable on the
            # first object and a stale duplicate on later ones.
            continue

        qa_pairs.append((numeric_presence_question, numeric_presence_answer))
    random.shuffle(qa_pairs)

    return random.sample(qa_pairs, min(len(qa_pairs), random.choice([1, 2, 3, 4, 5, 6])))
149
+
150
if __name__ == "__main__":
    # Manual smoke run: generate QA pairs for one sample sentence and print
    # them with icecream.
    from icecream import ic

    text = "The objects present in the image are: wall, ceiling, shelf, cabinet, counter, dining table, two people, eighteen bottles, two wine glasses, refrigerator, tv, bowl"
    qa = generate_qa_pairs(text)
    ic(qa)
157
+
vcoder_llava/mm_utils.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+
5
+ import torch
6
+ from transformers import StoppingCriteria
7
+ from vcoder_llava.constants import IMAGE_TOKEN_INDEX, SEG_TOKEN_INDEX, DEPTH_TOKEN_INDEX
8
+
9
+
10
def load_image_from_base64(image):
    """Decode the base64 payload ``image`` and open the bytes as a PIL image."""
    raw_bytes = base64.b64decode(image)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)
12
+
13
+
14
def expand2square(pil_img, background_color):
    """Pad ``pil_img`` to a square canvas, centring it along the shorter side.

    Args:
        pil_img: PIL image to pad.
        background_color: Fill colour for the new canvas.

    Returns:
        ``pil_img`` itself when it is already square, otherwise a new image
        whose side equals the longer side of the input.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # The offset along the longer axis is zero; the shorter axis is centred.
    canvas.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return canvas
26
+
27
+
28
def process_images(images, image_processor, model_cfg):
    """Preprocess ``images`` with ``image_processor``.

    Args:
        images: Iterable of PIL images.
        image_processor: Processor that is callable and offers ``preprocess``
            and ``image_mean``.
        model_cfg: Config object; its optional ``image_aspect_ratio`` selects
            the mode.

    Returns:
        Without ``image_aspect_ratio == 'pad'``: the processor's
        ``pixel_values`` for the whole batch.  With it: each image is padded
        to a square using the processor's mean colour and processed
        separately; the results are stacked into one tensor when all shapes
        agree and returned as a list otherwise.
    """
    if getattr(model_cfg, "image_aspect_ratio", None) != 'pad':
        return image_processor(images, return_tensors='pt')['pixel_values']

    padded = []
    for img in images:
        squared = expand2square(img, tuple(int(x*255) for x in image_processor.image_mean))
        pixel_values = image_processor.preprocess(squared, return_tensors='pt')['pixel_values'][0]
        padded.append(pixel_values)

    same_shape = all(item.shape == padded[0].shape for item in padded)
    if same_shape:
        return torch.stack(padded, dim=0)
    return padded
41
+
42
+
43
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, replacing each ``<image>`` with ``image_token_index``.

    Args:
        prompt: Text that may contain ``<image>`` placeholders.
        tokenizer: Callable returning an object with ``input_ids``; its
            ``bos_token_id`` is used to detect a leading BOS token.
        image_token_index: Id inserted in place of every placeholder.
        return_tensors: None for a plain list, ``'pt'`` for a 1-D long tensor.

    Returns:
        list[int] or torch.Tensor: The token ids.

    Raises:
        ValueError: If ``return_tensors`` is neither None nor ``'pt'``.
    """
    pieces = [tokenizer(piece).input_ids for piece in prompt.split('<image>')]

    starts_with_bos = (
        len(pieces) > 0
        and len(pieces[0]) > 0
        and pieces[0][0] == tokenizer.bos_token_id
    )
    # When the tokenizer prepends BOS, keep it once at the very front and drop
    # the first id of every piece.
    skip = 1 if starts_with_bos else 0
    input_ids = [pieces[0][0]] if starts_with_bos else []

    for position, piece in enumerate(pieces):
        if position > 0:
            input_ids.append(image_token_index)
        input_ids.extend(piece[skip:])

    if return_tensors is None:
        return input_ids
    if return_tensors == 'pt':
        return torch.tensor(input_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
63
+
64
+
65
def tokenizer_seg_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, seg_token_index=SEG_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, replacing each ``<seg>\\n<image>`` group with sentinel ids.

    Args:
        prompt: Text; it is split on the exact string ``'<seg>\\n<image>'``.
        tokenizer: Callable returning an object with ``input_ids``; its
            ``bos_token_id`` is used to detect a leading BOS token.
        image_token_index: Sentinel id for the image placeholder.
        seg_token_index: Sentinel id for the seg placeholder.
        return_tensors: None for a plain list, ``'pt'`` for a 1-D long tensor.

    Returns:
        list[int] or torch.Tensor: The token ids.

    Raises:
        ValueError: If ``return_tensors`` is neither None nor ``'pt'``.
    """
    # NOTE(review): a prompt holding "<image>" without the preceding "<seg>\n"
    # is not split here, so that "<image>" is tokenized as ordinary text -
    # confirm callers always send the combined placeholder.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<seg>\n<image>')]

    def insert_separator(X, sep):
        # Interleave: X[0], sep, X[1], sep, ..., X[-1] (trailing sep removed).
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        # Keep BOS once at the front; every chunk then drops its first id.
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [seg_token_index, image_token_index] * (offset + 1)):
        if seg_token_index in x:
            # Separator entry.  With offset == 1 it is [seg, image, seg, image]
            # and the slice keeps [image, seg]; with offset == 0 it is
            # [seg, image] and the slice keeps only [seg].
            # NOTE(review): the emitted order and the offset == 0 case look
            # unintentional - confirm against the model's input handling
            # before changing.
            input_ids.extend(x[offset:-1])
        else:
            input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
88
+
89
def _tokenizer_depth_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, seg_token_index=SEG_TOKEN_INDEX, depth_token_index=DEPTH_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, replacing each ``<depth>\\n<seg>\\n<image>`` group with sentinel ids.

    Args:
        prompt: Text; it is split on the exact string
            ``'<depth>\\n<seg>\\n<image>'``.
        tokenizer: Callable returning an object with ``input_ids``; its
            ``bos_token_id`` is used to detect a leading BOS token.
        image_token_index: Sentinel id for the image placeholder.
        seg_token_index: Sentinel id for the seg placeholder.
        depth_token_index: Sentinel id for the depth placeholder.
        return_tensors: None for a plain list, ``'pt'`` for a 1-D long tensor.

    Returns:
        list[int] or torch.Tensor: The token ids.

    Raises:
        ValueError: If ``return_tensors`` is neither None nor ``'pt'``.
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<depth>\n<seg>\n<image>')]

    def insert_separator(X, sep):
        # Interleave: X[0], sep, X[1], sep, ..., X[-1] (trailing sep removed).
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        # Keep BOS once at the front; every chunk then drops its first id.
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index, depth_token_index, seg_token_index] * (offset + 1)):
        if depth_token_index in x and seg_token_index in x:
            # Separator entry: always emits [image, depth, seg], whatever the
            # offset.
            input_ids.extend(x[:3])
        else:
            input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
112
+
113
def tokenizer_depth_seg_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, seg_token_index=SEG_TOKEN_INDEX, depth_token_index=DEPTH_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt`` with the depth-aware helper when it contains ``<depth>``, else with ``tokenizer_seg_token``."""
    if "<depth>" not in prompt:
        return tokenizer_seg_token(prompt, tokenizer, image_token_index, seg_token_index, return_tensors)
    return _tokenizer_depth_token(prompt, tokenizer, image_token_index, seg_token_index, depth_token_index, return_tensors)
118
+
119
+
120
def get_model_name_from_path(model_path):
    """Derive a model name from ``model_path``.

    Args:
        model_path: Path or hub id; leading and trailing "/" are ignored.

    Returns:
        str: The last path component.  For a ``checkpoint-*`` directory the
        parent directory name is prepended as ``"<parent>_<checkpoint>"``;
        a bare ``checkpoint-*`` path without a parent is returned unchanged.
    """
    parts = model_path.strip("/").split("/")
    last = parts[-1]
    # Guard on len(parts): a lone "checkpoint-N" has no parent to prepend.
    if last.startswith('checkpoint-') and len(parts) > 1:
        return parts[-2] + "_" + last
    return last
127
+
128
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once one of the given keywords has been produced.

    A keyword matches when its token ids equal the tail of the output, or
    when its text occurs in the decoded last (up to three) generated tokens.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """Prepare the keyword token ids.

        Args:
            keywords: Iterable of stop strings; None and empty entries are
                ignored.
            tokenizer: Tokenizer used to encode keywords and decode output.
            input_ids: Prompt ids; the size of dimension 1 marks where the
                generated part starts.
        """
        # None cannot be tokenized and "" would be contained in every decoded
        # string, so such entries are dropped.
        self.keywords = [keyword for keyword in keywords if keyword]
        self.keyword_ids = []
        for keyword in self.keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Strip the BOS token the tokenizer may prepend.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop (batch size 1 only)."""
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # torch.equal yields one bool even for multi-token keywords or a
            # length mismatch; a plain `==` in an `if` raises in those cases.
            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)
vcoder_llava/model/.DS_Store ADDED
Binary file (6.15 kB). View file
 
vcoder_llava/model/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
2
+ from .language_model.vcoder_llava_llama import VCoderLlavaLlamaForCausalLM, VCoderLlavaConfig
3
+ from .language_model.vcoder_ds_llava_llama import VCoderDSLlavaLlamaForCausalLM, VCoderDSLlavaConfig
vcoder_llava/model/apply_delta.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m vcoder_llava.model.apply_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/vicuna-7b --delta-path lmsys/vicuna-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from vcoder_llava import LlavaLlamaForCausalLM
11
+
12
+
13
def apply_delta(base_model_path, target_model_path, delta_path):
    """Add the base model weights onto a delta checkpoint and save the result.

    Args:
        base_model_path: Path or hub id of the base model.
        target_model_path: Directory the merged model and the delta's
            tokenizer are saved to.
        delta_path: Path or hub id of the delta checkpoint.

    Raises:
        AssertionError: If the delta has a parameter missing from the base
            that is not an ``mm_projector`` weight/bias, or a shape mismatch
            on anything other than the token embedding / LM head.
    """
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading delta")
    delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

    print("Applying delta")
    # Build the base state dict once instead of on every loop iteration.
    base_state = base.state_dict()
    for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
        if name not in base_state:
            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
            continue
        bparam = base_state[name]
        if param.data.shape == bparam.shape:
            param.data += bparam
        else:
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
                f'{name} dimension mismatch: {param.data.shape} vs {bparam.shape}'
            # Shapes differ: only the slice covered by the base tensor is
            # updated.
            param.data[:bparam.shape[0], :bparam.shape[1]] += bparam

    print("Saving target model")
    delta.save_pretrained(target_model_path)
    delta_tokenizer.save_pretrained(target_model_path)
38
+
39
+
40
if __name__ == "__main__":
    # CLI wrapper: all three paths are mandatory and forwarded to apply_delta.
    parser = argparse.ArgumentParser()
    for flag in ("--base-model-path", "--target-model-path", "--delta-path"):
        parser.add_argument(flag, type=str, required=True)

    args = parser.parse_args()

    apply_delta(
        base_model_path=args.base_model_path,
        target_model_path=args.target_model_path,
        delta_path=args.delta_path,
    )
vcoder_llava/model/builder.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from vcoder_llava.model import *
23
+
24
+
25
def _llava_model_class(model_name):
    """Return the LLaVA-family model class selected by substrings of ``model_name``."""
    lowered = model_name.lower()
    if 'vcoder_ds_llava' in lowered:
        return VCoderDSLlavaLlamaForCausalLM
    if 'vcoder_llava' in lowered:
        return VCoderLlavaLlamaForCausalLM
    return LlavaLlamaForCausalLM


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
    """Load a tokenizer, a model and the matching image processors.

    The loading route is chosen from substrings of ``model_name`` ("llava",
    "lora", "vcoder_ds_llava", "vcoder_llava", "mpt") and from whether
    ``model_base`` is given.

    Args:
        model_path: Checkpoint directory or hub id.
        model_base: Base model for LoRA / projector-only checkpoints, or None.
        model_name: Name whose substrings select the loading route.
        load_8bit: Load with ``load_in_8bit=True``.
        load_4bit: Load with an NF4 ``BitsAndBytesConfig`` (ignored when
            ``load_8bit`` is set).
        device_map: Forwarded to ``from_pretrained``.
        device: Device the vision tower is moved to.

    Returns:
        tuple: ``(tokenizer, model, image_processor, seg_image_processor,
        depth_image_processor, context_len)``.  The processors are None for
        models that do not match the corresponding name substring.
    """
    kwargs = {"device_map": device_map}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    lowered_name = model_name.lower()
    is_llava = 'llava' in lowered_name
    is_lora = 'lora' in lowered_name

    if is_llava:
        # Load LLaVA model
        if is_lora and model_base is None:
            warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
        if is_lora and model_base is not None:
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print('Loading LLaVA from base model...')
            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
            token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
            if model.lm_head.weight.shape[0] != token_num:
                # Re-allocate the head and embedding at the configured size;
                # the values are filled by the weights loaded below.
                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))

            print('Loading additional LLaVA weights...')
            # NOTE(review): torch.load unpickles the file - only load
            # checkpoints from trusted sources.
            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')
                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
            # Strip the 'base_model.' prefix and, if present, one extra
            # leading 'model.' so the keys match this model's state dict.
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, model_path)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Model is loaded...')
        elif model_base is not None:
            # this may be mm projector only
            print('Loading LLaVA from base model...')
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            cfg_pretrained = AutoConfig.from_pretrained(model_path)
            model = _llava_model_class(model_name).from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
            model = _llava_model_class(model_name).from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            # This route always loads fp16 with device_map="auto"; the
            # quantization kwargs built above are not applied here.
            model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print(f"Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        elif 'mpt' in lowered_name:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    image_processor = None

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    if is_llava:
        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device=device, dtype=torch.float16)
        image_processor = vision_tower.image_processor

    # Seg and depth inputs reuse the image processor of the vision tower.
    seg_image_processor = None
    if 'vcoder' in lowered_name:
        seg_image_processor = image_processor

    # NOTE(review): "ds" is a very loose substring test compared with the
    # 'vcoder_ds_llava' check used for class selection - confirm it is
    # intended.
    depth_image_processor = None
    if "ds" in lowered_name:
        depth_image_processor = image_processor

    model.requires_grad_(False)
    return tokenizer, model, image_processor, seg_image_processor, depth_image_processor, context_len
vcoder_llava/model/consolidate.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m vcoder_llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from vcoder_llava.model import *
10
+ from vcoder_llava.model.utils import auto_upgrade
11
+
12
+
13
def consolidate_ckpt(src_path, dst_path):
    """Load the checkpoint at ``src_path`` and re-save it under ``dst_path``.

    The model is loaded in float16 and the tokenizer with ``use_fast=False``;
    both are then written to the same destination directory.

    Args:
        src_path: Location of the checkpoint to read.
        dst_path: Directory that receives the saved model and tokenizer.
    """
    print("Loading model")
    # NOTE(review): auto_upgrade is defined in vcoder_llava.model.utils and is
    # not visible here; it is invoked on the source path before loading.
    auto_upgrade(src_path)

    load_options = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
    model = AutoModelForCausalLM.from_pretrained(src_path, **load_options)
    tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)

    # Model first, tokenizer second.
    for artifact in (model, tokenizer):
        artifact.save_pretrained(dst_path)
20
+
21
+
22
if __name__ == "__main__":
    # Command-line entry point: both paths are mandatory string options.
    cli = argparse.ArgumentParser()
    for flag in ("--src", "--dst"):
        cli.add_argument(flag, type=str, required=True)

    options = cli.parse_args()

    consolidate_ckpt(options.src, options.dst)
vcoder_llava/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ LlamaConfig, LlamaModel, LlamaForCausalLM
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
class LlavaConfig(LlamaConfig):
    """``LlamaConfig`` subclass identified by the ``"llava"`` model type."""

    model_type = "llava"
32
+
33
+
34
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """``LlamaModel`` combined with the ``LlavaMetaModel`` mixin."""

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        # Cooperative init: runs LlavaMetaModel.__init__, which forwards
        # ``config`` up the MRO.
        super().__init__(config)
39
+
40
+
41
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """Causal-LM head on top of :class:`LlavaLlamaModel`.

    Image inputs are merged into the token sequence by
    ``prepare_inputs_labels_for_multimodal`` (inherited from the meta mixin)
    before the backbone is run.
    """

    config_class = LlavaConfig

    def __init__(self, config):
        # Start the MRO walk *after* LlamaForCausalLM so that class's own
        # constructor is skipped; the backbone and head are built here instead.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the wrapped backbone model."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_cd: Optional[torch.FloatTensor] = None,
        cd_beta: Optional[torch.FloatTensor] = None,
        cd_alpha: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal model and optionally compute the LM loss.

        ``images_cd``, ``cd_beta`` and ``cd_alpha`` are accepted but not read
        by this method. The incoming ``inputs_embeds`` is replaced by the
        value returned from ``prepare_inputs_labels_for_multimodal``.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` resolves to a
            truthy value; otherwise a tuple ``(logits, *rest)`` prefixed with
            the loss when labels were supplied.
        """
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        (
            input_ids,
            attention_mask,
            past_key_values,
            inputs_embeds,
            labels,
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # Position i predicts token i + 1: drop the last logit and the
            # first label, then flatten both for the loss.
            predictions = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            targets = labels[..., 1:].contiguous().view(-1)
            # Enable model/pipeline parallelism
            targets = targets.to(predictions.device)
            loss = CrossEntropyLoss()(predictions, targets)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output

    def _generation_inputs(
        self, input_ids, past_key_values, attention_mask, inputs_embeds, kwargs, image_key
    ):
        """Assemble one generation step's inputs, reading images from ``kwargs[image_key]``."""
        if past_key_values:
            # With a cache present only the newest token is fed.
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs["past_key_values"] = past_key_values
        model_inputs["use_cache"] = kwargs.get("use_cache")
        model_inputs["attention_mask"] = attention_mask
        model_inputs["images"] = kwargs.get(image_key, None)
        return model_inputs

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build step inputs, taking the image tensor from ``kwargs["images"]``."""
        return self._generation_inputs(
            input_ids, past_key_values, attention_mask, inputs_embeds, kwargs, "images"
        )

    def prepare_inputs_for_generation_cd(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Same as ``prepare_inputs_for_generation`` but reads ``kwargs["images_cd"]``."""
        return self._generation_inputs(
            input_ids, past_key_values, attention_mask, inputs_embeds, kwargs, "images_cd"
        )


# Make the "llava" model type resolvable through the Auto* factories.
AutoConfig.register("llava", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
vcoder_llava/model/language_model/vcoder_ds_llava_llama.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ LlamaConfig, LlamaModel, LlamaForCausalLM
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+
27
+ from ..vcoder_ds_llava_arch import VCoderDSLlavaMetaModel, VCoderDSLlavaMetaForCausalLM
28
+
29
+
30
class VCoderDSLlavaConfig(LlamaConfig):
    """``LlamaConfig`` subclass identified by the ``"vcoder_ds_llava"`` model type."""

    model_type = "vcoder_ds_llava"
32
+
33
+
34
class VCoderDSLlavaLlamaModel(VCoderDSLlavaMetaModel, LlamaModel):
    """``LlamaModel`` combined with the ``VCoderDSLlavaMetaModel`` mixin."""

    config_class = VCoderDSLlavaConfig

    def __init__(self, config: LlamaConfig):
        # Cooperative init through the MRO (mixin first, then LlamaModel).
        super().__init__(config)
39
+
40
+
41
class VCoderDSLlavaLlamaForCausalLM(LlamaForCausalLM, VCoderDSLlavaMetaForCausalLM):
    """Causal-LM head on top of :class:`VCoderDSLlavaLlamaModel`.

    ``forward`` accepts ``images``, ``segs`` and ``depths`` and hands all three
    to ``prepare_inputs_labels_for_multimodal`` (inherited from the meta mixin).
    """

    config_class = VCoderDSLlavaConfig

    def __init__(self, config):
        # Skip LlamaForCausalLM's own constructor; backbone and head are
        # created explicitly below.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = VCoderDSLlavaLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the wrapped backbone model."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        segs: Optional[torch.FloatTensor] = None,
        depths: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal model and optionally compute the LM loss.

        The incoming ``inputs_embeds`` is replaced by the value returned from
        ``prepare_inputs_labels_for_multimodal``.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` resolves to a
            truthy value; otherwise a tuple ``(logits, *rest)`` prefixed with
            the loss when labels were supplied.
        """
        config = self.config
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if return_dict is None:
            return_dict = config.use_return_dict

        merged = self.prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images, segs, depths
        )
        input_ids, attention_mask, past_key_values, inputs_embeds, labels = merged

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n, then flatten.
            flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            flat_labels = labels[..., 1:].contiguous().view(-1)
            # Enable model/pipeline parallelism
            flat_labels = flat_labels.to(flat_logits.device)
            criterion = CrossEntropyLoss()
            loss = criterion(flat_logits, flat_labels)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build one generation step's inputs, forwarding images, segs and depths from ``kwargs``."""
        if past_key_values:
            # With a cache present only the newest token is fed.
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        first_step_with_embeds = inputs_embeds is not None and past_key_values is None
        if first_step_with_embeds:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs["past_key_values"] = past_key_values
        model_inputs["use_cache"] = kwargs.get("use_cache")
        model_inputs["attention_mask"] = attention_mask
        for modality in ("images", "segs", "depths"):
            model_inputs[modality] = kwargs.get(modality, None)
        return model_inputs


# Make the "vcoder_ds_llava" model type resolvable through the Auto* factories.
AutoConfig.register("vcoder_ds_llava", VCoderDSLlavaConfig)
AutoModelForCausalLM.register(VCoderDSLlavaConfig, VCoderDSLlavaLlamaForCausalLM)
vcoder_llava/model/language_model/vcoder_llava_llama.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ LlamaConfig, LlamaModel, LlamaForCausalLM
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+
27
+ from ..vcoder_llava_arch import VCoderLlavaMetaModel, VCoderLlavaMetaForCausalLM
28
+
29
+
30
class VCoderLlavaConfig(LlamaConfig):
    """``LlamaConfig`` subclass identified by the ``"vcoder_llava"`` model type."""

    model_type = "vcoder_llava"
32
+
33
+
34
class VCoderLlavaLlamaModel(VCoderLlavaMetaModel, LlamaModel):
    """``LlamaModel`` combined with the ``VCoderLlavaMetaModel`` mixin."""

    config_class = VCoderLlavaConfig

    def __init__(self, config: LlamaConfig):
        # Cooperative init through the MRO (mixin first, then LlamaModel).
        super().__init__(config)
39
+
40
+
41
class VCoderLlavaLlamaForCausalLM(LlamaForCausalLM, VCoderLlavaMetaForCausalLM):
    """Causal-LM head on top of :class:`VCoderLlavaLlamaModel`.

    ``forward`` accepts ``images`` and ``segs`` and hands both to
    ``prepare_inputs_labels_for_multimodal`` (inherited from the meta mixin).
    """

    config_class = VCoderLlavaConfig

    def __init__(self, config):
        # Skip LlamaForCausalLM's own constructor; backbone and head are
        # created explicitly below.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = VCoderLlavaLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the wrapped backbone model."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        segs: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal model and optionally compute the LM loss.

        The incoming ``inputs_embeds`` is replaced by the value returned from
        ``prepare_inputs_labels_for_multimodal``.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` resolves to a
            truthy value; otherwise a tuple ``(logits, *rest)`` prefixed with
            the loss when labels were supplied.
        """
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        prepared = self.prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images, segs
        )
        input_ids, attention_mask, past_key_values, inputs_embeds, labels = prepared

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n, then flatten.
            next_token_logits = logits[..., :-1, :].contiguous()
            next_token_labels = labels[..., 1:].contiguous()
            next_token_logits = next_token_logits.view(-1, self.config.vocab_size)
            # Enable model/pipeline parallelism
            next_token_labels = next_token_labels.view(-1).to(next_token_logits.device)
            loss = CrossEntropyLoss()(next_token_logits, next_token_labels)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build one generation step's inputs, forwarding images and segs from ``kwargs``."""
        if past_key_values:
            # With a cache present only the newest token is fed.
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        extras = {
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "images": kwargs.get("images", None),
            "segs": kwargs.get("segs", None),
        }
        for key, value in extras.items():
            model_inputs[key] = value
        return model_inputs


# Make the "vcoder_llava" model type resolvable through the Auto* factories.
AutoConfig.register("vcoder_llava", VCoderLlavaConfig)
AutoModelForCausalLM.register(VCoderLlavaConfig, VCoderLlavaLlamaForCausalLM)
vcoder_llava/model/llava_arch.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_projector.builder import build_vision_projector
23
+
24
+ from vcoder_llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
25
+
26
+
27
class LlavaMetaModel:
    """Mixin that attaches a vision tower and a multimodal projector to a model.

    It is meant to be combined with a base class whose ``__init__`` accepts
    ``config``; this class forwards ``config`` up the MRO.
    """

    def __init__(self, config):
        super().__init__(config)

        # Vision modules are only built when the config names a tower.
        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

    def get_vision_tower(self):
        """Return the vision tower, unwrapping a list wrapper, or ``None`` if unset."""
        tower = getattr(self, 'vision_tower', None)
        return tower[0] if type(tower) is list else tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Create or reload the vision tower and projector from ``model_args``.

        Args:
            model_args: Object exposing ``vision_tower``,
                ``mm_vision_select_layer``, ``mm_vision_select_feature`` and
                ``pretrain_mm_mlp_adapter``; ``mm_projector_type`` is optional
                and defaults to ``'linear'``.
            fsdp: When non-empty, the tower is stored wrapped in a one-element
                list (and read back from index 0).
        """
        # Read every required option before touching any state.
        tower_name = model_args.vision_tower
        select_layer = model_args.mm_vision_select_layer
        select_feature = model_args.mm_vision_select_feature
        adapter_path = model_args.pretrain_mm_mlp_adapter

        def fsdp_enabled():
            return fsdp is not None and len(fsdp) > 0

        self.config.mm_vision_tower = tower_name

        if self.get_vision_tower() is None:
            tower = build_vision_tower(model_args)
            self.vision_tower = [tower] if fsdp_enabled() else tower
        else:
            tower = self.vision_tower[0] if fsdp_enabled() else self.vision_tower
            tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = tower.hidden_size
        self.config.mm_vision_select_layer = select_layer
        self.config.mm_vision_select_feature = select_feature

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)
        else:
            # In case it is frozen by LoRA
            for weight in self.mm_projector.parameters():
                weight.requires_grad = True

        if adapter_path is not None:
            checkpoint = torch.load(adapter_path, map_location='cpu')
            keyword = 'mm_projector'
            # Keep only projector entries and strip everything up to and
            # including the "mm_projector." prefix from their names.
            projector_state = {
                key.split(keyword + '.')[1]: value
                for key, value in checkpoint.items()
                if keyword in key
            }
            self.mm_projector.load_state_dict(projector_state)
83
+
84
+
85
class LlavaMetaForCausalLM(ABC):
    """Mixin that splices projected image features into a token sequence.

    Concrete subclasses supply :meth:`get_model`; the returned object must
    expose ``get_vision_tower()``, ``mm_projector`` and ``embed_tokens``.
    """

    @abstractmethod
    def get_model(self):
        """Return the backbone model that owns the vision modules."""
        pass

    def get_vision_tower(self):
        """Return the backbone's vision tower."""
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run ``images`` through the vision tower and then the projector."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
    ):
        """Replace image placeholder tokens with image features.

        Every ``IMAGE_TOKEN_INDEX`` entry in ``input_ids`` is substituted by
        the feature rows of the next image; the matching label positions are
        filled with ``IGNORE_INDEX``. Samples of different resulting length
        are right-padded to the longest one.

        Returns:
            ``(input_ids, attention_mask, past_key_values, inputs_embeds,
            labels)``. On the multimodal path ``input_ids`` is ``None`` and
            ``inputs_embeds`` carries the merged sequence; on the pass-through
            path ``inputs_embeds`` is ``None``.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            # Pass-through: nothing to merge. For a single-token step with a
            # cache, rebuild an all-ones mask sized from the cached tensor.
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            # NOTE(review): presumably several images per sample; they are
            # encoded in one concatenated call and split back per entry.
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                # FIXME: this is a hacky fix, for deepspeed zero3 to work
                half_len = cur_input_ids.shape[0] // 2
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
                cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
                # The empty slice contributes no rows but keeps the image
                # features in the concatenation.
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            # Consume the sequence left to right, one placeholder at a time.
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]

                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)
                if labels is not None:
                    cur_new_labels.append(cur_labels[:image_token_start])
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                    cur_labels = cur_labels[image_token_start+1:]
                cur_image_idx += 1
                cur_input_ids = cur_input_ids[image_token_start+1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            if cur_input_ids.numel() > 0:
                cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            # Ragged batch: right-pad everything to the longest sample.
            # Per-sample lengths come from the embeddings (not the labels) so
            # the mask can also be rebuilt when ``labels`` is None.
            unpadded_lens = [x.shape[0] for x in new_input_embeds]
            max_len = max(unpadded_lens)

            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            if labels is not None:
                new_labels_align = []
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            if attention_mask is not None:
                new_attention_mask = []
                for cur_attention_mask, cur_len in zip(attention_mask, unpadded_lens):
                    # Left: positions gained by expanding image placeholders.
                    new_attn_mask_pad_left = torch.full((cur_len - input_ids.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                    # Right: alignment padding, masked out.
                    new_attn_mask_pad_right = torch.full((max_len - cur_len,), False, dtype=attention_mask.dtype, device=attention_mask.device)
                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_input_embeds.shape[:2]
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels
vcoder_llava/model/make_delta.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m vcoder_llava.model.make_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from vcoder_llava.model.utils import auto_upgrade
11
+
12
+
13
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
    """Subtract the base weights from the target weights and save the result.

    The subtraction happens in place on the target model's parameters, after
    which the (now delta) model and the target tokenizer are saved.

    Args:
        base_model_path: Checkpoint whose weights are subtracted.
        target_model_path: Checkpoint the delta is computed for.
        delta_path: Output directory for the delta model and tokenizer.
        hub_repo_id: When truthy, ``push_to_hub=True`` and this ``repo_id``
            are passed to both ``save_pretrained`` calls.

    Raises:
        AssertionError: If a target parameter is missing from the base model
            and is not an ``mm_projector`` weight/bias, or if shapes differ
            for anything other than the embedding / LM-head matrices.
    """
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading target model")
    # NOTE(review): auto_upgrade comes from vcoder_llava.model.utils and is
    # not visible here; it runs on the target path before loading.
    auto_upgrade(target_model_path)
    target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Calculating delta")
    # state_dict() assembles a fresh mapping on every call, so build the base
    # mapping once instead of several times per parameter.
    base_state = base.state_dict()
    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
        if name not in base_state:
            # Only the projector may be absent from the base model.
            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
            continue
        bparam = base_state[name]
        if param.data.shape == bparam.shape:
            param.data -= bparam
        else:
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {bparam.shape}'
            # Subtract only over the region the base matrix covers; rows and
            # columns beyond it are left untouched.
            param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam

    print("Saving delta")
    if hub_repo_id:
        kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
    else:
        kwargs = {}
    target.save_pretrained(delta_path, **kwargs)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
    target_tokenizer.save_pretrained(delta_path, **kwargs)
42
+
43
+
44
if __name__ == "__main__":
    # Command-line entry point: three mandatory paths plus an optional Hub repo id.
    cli = argparse.ArgumentParser()
    for required_flag in ("--base-model-path", "--target-model-path", "--delta-path"):
        cli.add_argument(required_flag, type=str, required=True)
    cli.add_argument("--hub-repo-id", type=str, default=None)
    options = cli.parse_args()

    make_delta(
        options.base_model_path,
        options.target_model_path,
        options.delta_path,
        options.hub_repo_id,
    )
vcoder_llava/model/multimodal_adapter/builder.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import re
3
+
4
class IdentityMap(nn.Module):
    """Projector stand-in that returns its input unchanged."""

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        return x

    @property
    def config(self):
        """Config fragment naming this projector type."""
        return dict(seg_mm_projector_type='identity')
14
+
15
+
16
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a two-layer GELU MLP with a residual connection.

    The residual is taken from the *normalised* input, i.e. the output is
    ``norm(x) + proj(norm(x))``.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        layers = [
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        return normed + self.proj(normed)
29
+
30
+
31
def build_seg_projector(config, delay_load=False, **kwargs):
    """Build the segmentation projector selected by ``config.seg_mm_projector_type``.

    Supported types (default ``'linear'``):
      * ``'linear'``: one ``nn.Linear`` from ``seg_mm_hidden_size`` to ``hidden_size``.
      * ``'mlp<N>x_gelu'``: a first linear layer followed by ``N - 1`` GELU + linear pairs.
      * ``'identity'``: an :class:`IdentityMap`.

    ``delay_load`` and ``**kwargs`` are accepted but not used.

    Raises:
        ValueError: For any other projector type.
    """
    kind = getattr(config, 'seg_mm_projector_type', 'linear')
    in_dim_attr = 'seg_mm_hidden_size'

    if kind == 'linear':
        return nn.Linear(getattr(config, in_dim_attr), config.hidden_size)

    mlp_spec = re.match(r'^mlp(\d+)x_gelu$', kind)
    if mlp_spec:
        n_linear = int(mlp_spec.group(1))
        layers = [nn.Linear(getattr(config, in_dim_attr), config.hidden_size)]
        for _ in range(n_linear - 1):
            layers += [nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)]
        return nn.Sequential(*layers)

    if kind == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown seg projector type: {kind}')
vcoder_llava/model/multimodal_depth_adapter/builder.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import re
3
+
4
class IdentityMap(nn.Module):
    """Pass-through projector: `forward` hands its first argument back untouched.

    No `__init__` is defined because the inherited `nn.Module` constructor is
    all this parameter-free module needs.
    """

    def forward(self, x, *args, **kwargs):
        # Extra positional / keyword arguments are accepted and ignored.
        return x

    @property
    def config(self):
        """Dict tagging this module as the 'identity' depth projector type."""
        return dict(depth_mm_projector_type='identity')
14
+
15
+
16
+
17
class SimpleResBlock(nn.Module):
    """Residual MLP block: LayerNorm followed by a Linear-GELU-Linear branch.

    The skip connection carries the *normalized* input, not the raw input.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        branch_layers = [
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.proj = nn.Sequential(*branch_layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        branch_out = self.proj(normed)
        return normed + branch_out
30
+
31
+
32
def build_depth_projector(config, delay_load=False, **kwargs):
    """Build the projector mapping depth features to the LM hidden size.

    The type is read from `config.depth_mm_projector_type` (default 'linear'):
      * 'linear'      -> one `nn.Linear(depth_mm_hidden_size, hidden_size)`
      * 'mlp<N>x_gelu' -> N linear layers with GELU between them
      * 'identity'    -> `IdentityMap()`

    `delay_load` and `**kwargs` are accepted but not used here.

    Raises:
        ValueError: if the projector type matches none of the above.
    """
    kind = getattr(config, 'depth_mm_projector_type', 'linear')

    if kind == 'identity':
        return IdentityMap()

    if kind == 'linear':
        return nn.Linear(config.depth_mm_hidden_size, config.hidden_size)

    matched = re.match(r'^mlp(\d+)x_gelu$', kind)
    if matched is None:
        raise ValueError(f'Unknown depth projector type: {kind}')

    n_linear = int(matched.group(1))
    layers = [nn.Linear(config.depth_mm_hidden_size, config.hidden_size)]
    # Every additional linear layer is preceded by a GELU.
    for _ in range(n_linear - 1):
        layers += [nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)]
    return nn.Sequential(*layers)
vcoder_llava/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower
3
+
4
+
5
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Instantiate the CLIP vision tower named by the config.

    The tower name is read from `vision_tower_cfg.mm_vision_tower`, falling
    back to `vision_tower_cfg.vision_tower`. It is accepted when it is an
    existing filesystem path or starts with "openai" or "laion".

    Args:
        vision_tower_cfg: object carrying the attributes above; it is also
            forwarded to `CLIPVisionTower` as `args`.
        **kwargs: forwarded unchanged to `CLIPVisionTower`.

    Raises:
        ValueError: if no tower name is configured or it is not recognised.
    """
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))

    # Without this guard a missing name reaches os.path.exists(None), which
    # raises TypeError instead of the ValueError below.
    if vision_tower is None:
        raise ValueError(f'Unknown vision tower: {vision_tower}')

    if os.path.exists(vision_tower) or vision_tower.startswith(("openai", "laion")):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')
vcoder_llava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
class CLIPVisionTower(nn.Module):
    """Wrapper around a CLIP vision model that exposes one hidden layer's features.

    Args:
        vision_tower: name or path passed to the `from_pretrained` loaders.
        args: config object; `mm_vision_select_layer` is required and
            `mm_vision_select_feature` defaults to 'patch'.
        delay_load: when true, only the CLIP config is fetched in `__init__`
            and `load_model()` must be called later to load the weights.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            # Config only, so `config`, `hidden_size` and `num_patches` work
            # before the weights are loaded.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        """Load the image processor and the vision model, with gradients disabled."""
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick `hidden_states[select_layer]` and apply the feature selection mode.

        'patch' drops the first token along dim 1 (presumably the CLS token,
        judging by the 'cls_patch' mode name); 'cls_patch' keeps every token.

        Raises:
            ValueError: for any other `select_feature` value.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            return image_features[:, 1:]
        if self.select_feature == 'cls_patch':
            return image_features
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode a batch tensor, or a list of single-image tensors.

        A list is processed one image at a time (each gets a leading batch
        dim) and a list of features is returned; otherwise one tensor is
        returned. Features are cast back to the dtype of the input image(s).
        """
        # isinstance (rather than an exact type check) also accepts list subclasses.
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        """A zero tensor of shape (1, hidden_size) on the tower's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """dtype of the wrapped vision model (requires `load_model()` to have run)."""
        return self.vision_tower.dtype

    @property
    def device(self):
        """Device of the wrapped vision model (requires `load_model()` to have run)."""
        return self.vision_tower.device

    @property
    def config(self):
        """The loaded model's config, or the config-only object when not yet loaded."""
        if self.is_loaded:
            return self.vision_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        """Number of patches per image: (image_size // patch_size) squared."""
        return (self.config.image_size // self.config.patch_size) ** 2
vcoder_llava/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+
6
class IdentityMap(nn.Module):
    """Pass-through projector: `forward` hands its first argument back untouched.

    No `__init__` is defined because the inherited `nn.Module` constructor is
    all this parameter-free module needs.
    """

    def forward(self, x, *args, **kwargs):
        # Extra positional / keyword arguments are accepted and ignored.
        return x

    @property
    def config(self):
        """Dict tagging this module as the 'identity' projector type."""
        return dict(mm_projector_type='identity')
16
+
17
+
18
class SimpleResBlock(nn.Module):
    """Residual MLP block: LayerNorm followed by a Linear-GELU-Linear branch.

    The skip connection carries the *normalized* input, not the raw input.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        branch_layers = [
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.proj = nn.Sequential(*branch_layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        branch_out = self.proj(normed)
        return normed + branch_out
31
+
32
+
33
def build_vision_projector(config, delay_load=False, **kwargs):
    """Build the projector mapping vision features to the LM hidden size.

    The type is read from `config.mm_projector_type` (default 'linear'):
      * 'linear'      -> one `nn.Linear(mm_hidden_size, hidden_size)`
      * 'mlp<N>x_gelu' -> N linear layers with GELU between them
      * 'identity'    -> `IdentityMap()`

    `delay_load` and `**kwargs` are accepted but not used here.

    Raises:
        ValueError: if the projector type matches none of the above.
    """
    kind = getattr(config, 'mm_projector_type', 'linear')

    if kind == 'identity':
        return IdentityMap()

    if kind == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    matched = re.match(r'^mlp(\d+)x_gelu$', kind)
    if matched is None:
        raise ValueError(f'Unknown projector type: {kind}')

    n_linear = int(matched.group(1))
    layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
    # Every additional linear layer is preceded by a GELU.
    for _ in range(n_linear - 1):
        layers += [nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)]
    return nn.Sequential(*layers)
vcoder_llava/model/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig
2
+
3
+
4
def auto_upgrade(config):
    """Interactively upgrade an old-code-base LLaVA checkpoint config in place.

    `config` is the checkpoint path/name given to `AutoConfig.from_pretrained`.
    Nothing happens unless that string contains 'llava' while the loaded
    config's `model_type` does not. In that case the user is prompted; on
    "y"/"yes" the config is rewritten and saved back to `config`, otherwise
    the process exits with status 1.
    """
    cfg = AutoConfig.from_pretrained(config)
    if 'llava' not in config or 'llava' in cfg.model_type:
        return

    assert cfg.model_type == 'llama'
    print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
    print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
    confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
    if confirm.lower() not in ["y", "yes"]:
        print("Checkpoint upgrade aborted.")
        # Raise SystemExit directly: the `exit` builtin is injected by the
        # `site` module and is not guaranteed to exist (e.g. `python -S`).
        raise SystemExit(1)

    print("Upgrading checkpoint...")
    assert len(cfg.architectures) == 1
    # NOTE: model_type is set on the config *class*, not just this instance.
    setattr(cfg.__class__, "model_type", "llava")
    cfg.architectures[0] = 'LlavaLlamaForCausalLM'
    cfg.save_pretrained(config)
    print("Checkpoint upgraded.")
vcoder_llava/model/vcd/vcd_add_noise.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def add_diffusion_noise(image_tensor, noise_step):
    """Return a noised copy of `image_tensor` at a given diffusion step.

    A fixed 1000-step schedule is built with betas following a sigmoid ramp
    rescaled to [1e-5, 0.5e-2]. With a_bar[t] the cumulative product of
    (1 - beta) up to step t, the result is
        sqrt(a_bar[t]) * image + sqrt(1 - a_bar[t]) * noise
    where `noise` is standard Gaussian with the image's shape.

    Args:
        image_tensor: tensor to corrupt; it is not modified.
        noise_step: step index into the schedule (0-999). It is converted
            with `int()`, so integer-valued floats or numeric strings work.

    Returns:
        A new tensor with the same shape as `image_tensor`.
    """
    num_steps = 1000  # Number of diffusion steps

    # decide beta in each step
    betas = torch.linspace(-6, 6, num_steps)
    betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5

    # cumulative product of alphas = 1 - betas
    alphas_prod = torch.cumprod(1 - betas, dim=0)

    # Index with the int-converted step; the original computed this value but
    # then indexed with the raw argument, which fails for float steps.
    step = int(noise_step)
    signal_scale = torch.sqrt(alphas_prod[step])
    noise_scale = torch.sqrt(1 - alphas_prod[step])

    noise = torch.randn_like(image_tensor)
    return signal_scale * image_tensor + noise_scale * noise
vcoder_llava/model/vcd/vcd_sample.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import inspect
3
+ import warnings
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch import nn
10
+
11
+ from transformers.generation.logits_process import (
12
+ LogitsProcessorList,
13
+ )
14
+ from transformers.generation.stopping_criteria import (
15
+ StoppingCriteria,
16
+ StoppingCriteriaList,
17
+ validate_stopping_criteria,
18
+ )
19
+ import transformers
20
+ from transformers.generation.utils import SampleOutput
21
+
22
+
23
+
24
def sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
    """Multinomial sampling loop with optional contrastive decoding.

    Meant to replace `GenerationMixin.sample` (see `evolve_vcd_sampling`).
    When `model_kwargs` carries a non-None "images_cd" entry, each step runs a
    second forward pass on inputs from `self.prepare_inputs_for_generation_cd`
    and samples from

        (1 + cd_alpha) * logits - cd_alpha * logits_cd

    with every token whose original logit is below
    `log(cd_beta) + max(logits)` masked to -inf. `cd_alpha` (default 0.5) and
    `cd_beta` (default 0.1) are read from `model_kwargs`. Without "images_cd"
    the loop is ordinary processor -> warper -> softmax -> multinomial sampling.

    Returns:
        `input_ids` extended with the sampled tokens, or a
        `SampleEncoderDecoderOutput` / `SampleDecoderOnlyOutput` when
        `return_dict_in_generate` is true.

    Raises:
        ValueError: if `eos_token_id` is set but `pad_token_id` is not.
    """
    # The module-level imports only provide `SampleOutput`; without these two
    # names the `return_dict_in_generate` path fails with NameError.
    from transformers.generation.utils import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput

    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )

    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

    this_peer_finished = False  # used by synced_gpus only

    # auto-regressive generation
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        ## For contrastive decoding initial
        # `is not None` rather than `!= None`: the value may be a tensor, for
        # which `!=` is an overloaded comparison.
        use_cd = model_kwargs.get("images_cd") is not None
        output_attentions_wo_img = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states_wo_img = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        # NOTE(review): re-copied from model_kwargs every step, so the
        # model_kwargs_cd update at the end of the loop body is not carried
        # into the next iteration -- kept as in the original; confirm intent.
        model_kwargs_cd = model_kwargs.copy()

        if use_cd:
            ## cd_comments: forward pass of the model with distorted image input
            model_inputs_cd = self.prepare_inputs_for_generation_cd(input_ids, **model_kwargs_cd)
            outputs_cd = self(
                **model_inputs_cd,
                return_dict=True,
                output_attentions=output_attentions_wo_img,
                output_hidden_states=output_hidden_states_wo_img,
            )
            next_token_logits_cd = outputs_cd.logits[:, -1, :]

            ## cd_comments: pre-process logits from contrastive inputs
            cd_alpha = model_kwargs.get("cd_alpha") if model_kwargs.get("cd_alpha") is not None else 0.5
            cd_beta = model_kwargs.get("cd_beta") if model_kwargs.get("cd_beta") is not None else 0.1

            # version 1  set cutoff for Adaptive Plausibility Constraints
            # probs = nn.functional.softmax(next_token_logits, dim=-1)
            # cutoff = cd_beta * probs.max(dim=-1, keepdim=True).values

            # version 2 set cutoff for Adaptive Plausibility Constraints
            cutoff = torch.log(torch.tensor(cd_beta)) + next_token_logits.max(dim=-1, keepdim=True).values

            diffs = (1 + cd_alpha) * next_token_logits - cd_alpha * next_token_logits_cd
            cd_logits = diffs.masked_fill(next_token_logits < cutoff, -float("inf"))

            ## cd_comments: apply temperature warping and top-k filtering in contrastive decoding
            cd_logits = logits_processor(input_ids, cd_logits)
            cd_logits = logits_warper(input_ids, cd_logits)

            next_token_scores = cd_logits
            cd_probs = nn.functional.softmax(cd_logits, dim=-1)
            next_tokens = torch.multinomial(cd_probs, num_samples=1).squeeze(1)
        else:
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        ## cd_comments: update model_kwargs_cd for contrastive decoding
        if use_cd:
            model_kwargs_cd = self._update_model_kwargs_for_generation(
                outputs_cd, model_kwargs_cd, is_encoder_decoder=self.config.is_encoder_decoder
            )

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )

            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        # stop if we exceed the maximum length
        if stopping_criteria(input_ids, scores):
            this_peer_finished = True

        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return SampleEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return SampleDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return input_ids
248
+
249
def evolve_vcd_sampling():
    """Monkey-patch `GenerationMixin.sample` with this module's `sample`."""
    generation_mixin = transformers.generation.utils.GenerationMixin
    generation_mixin.sample = sample
vcoder_llava/model/vcoder_ds_llava_arch.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_projector.builder import build_vision_projector
23
+ from .multimodal_adapter.builder import build_seg_projector
24
+ from .multimodal_depth_adapter.builder import build_depth_projector
25
+
26
+ from vcoder_llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, SEG_TOKEN_INDEX, DEPTH_TOKEN_INDEX
27
+
28
class VCoderDSLlavaMetaModel:
    """Mixin that attaches the vision tower and the image / seg / depth projectors.

    Intended to be combined with a base model class whose `__init__` accepts
    `config` (it is reached through `super().__init__(config)`).
    """

    def __init__(self, config):
        super(VCoderDSLlavaMetaModel, self).__init__(config)
        self.config = config

        # Each sub-module is built only when the config carries the matching key.
        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

        if hasattr(config, "seg_mm_projector_type"):
            self.seg_mm_projector = build_seg_projector(config)

        if hasattr(config, "use_mm2_proj"):
            if config.use_mm2_proj:
                self.mm2_projector = build_vision_projector(config)

        if hasattr(config, "depth_mm_projector_type"):
            self.depth_mm_projector = build_depth_projector(config)

        if hasattr(config, "mm_vcoder_lm_emb"):
            self.vcoder_lm_emb = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)

    def get_vision_tower(self):
        """Return the vision tower (first element if stored as a list), or None."""
        vision_tower = getattr(self, 'vision_tower', None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_seg_modules(self, model_args, fsdp=None):
        """Record seg settings on the config and build the seg projector and embedding.

        Reads `mm_seg_select_layer`, `mm_seg_select_feature` and
        `pretrain_mm2_mlp_adapter` from `model_args`; `seg_mm_projector_type`
        defaults to 'linear' and `use_mm2_proj` to False. `fsdp` is unused.
        """
        mm_seg_select_layer = model_args.mm_seg_select_layer
        mm_seg_select_feature = model_args.mm_seg_select_feature

        self.config.seg_mm_hidden_size = self.vision_tower.hidden_size

        self.config.seg_use_mm_proj = True
        self.config.seg_mm_projector_type = getattr(model_args, 'seg_mm_projector_type', 'linear')
        self.config.mm_seg_select_layer = mm_seg_select_layer
        self.config.mm_seg_select_feature = mm_seg_select_feature

        self.seg_mm_projector = build_seg_projector(self.config)
        self.vcoder_lm_emb = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.config.pad_token_id)

        # use MLP from pretraining stage
        pretrain_mm2_mlp_adapter = model_args.pretrain_mm2_mlp_adapter
        # Default to False so a model_args without the flag does not raise AttributeError.
        if getattr(model_args, "use_mm2_proj", False):
            self.config.use_mm2_proj = model_args.use_mm2_proj
            self.mm2_projector = build_vision_projector(self.config)

        if pretrain_mm2_mlp_adapter is not None:
            mm2_projector_weights = torch.load(pretrain_mm2_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                # Keep entries whose key contains `keyword`, stripping everything
                # up to and including "<keyword>.".
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            # The checkpoint stores these weights under the 'mm_projector' prefix.
            self.mm2_projector.load_state_dict(get_w(mm2_projector_weights, 'mm_projector'))

    def initialize_depth_modules(self, model_args, fsdp=None):
        """Record depth settings on the config and build the depth projector.

        Reads `mm_depth_select_layer` and `mm_depth_select_feature` from
        `model_args`; `depth_mm_projector_type` defaults to 'linear'.
        `fsdp` is unused.
        """
        mm_depth_select_layer = model_args.mm_depth_select_layer
        mm_depth_select_feature = model_args.mm_depth_select_feature

        self.config.depth_mm_hidden_size = self.vision_tower.hidden_size

        self.config.depth_use_mm_proj = True
        self.config.depth_mm_projector_type = getattr(model_args, 'depth_mm_projector_type', 'linear')
        self.config.mm_depth_select_layer = mm_depth_select_layer
        self.config.mm_depth_select_feature = mm_depth_select_feature

        self.depth_mm_projector = build_depth_projector(self.config)
96
+
97
+ class VCoderDSLlavaMetaForCausalLM(ABC):
98
+
99
@abstractmethod
def get_model(self):
    """Return the underlying model object; concrete subclasses must override this."""
102
+
103
def get_vision_tower(self):
    """Delegate to the wrapped model's `get_vision_tower()`."""
    wrapped_model = self.get_model()
    return wrapped_model.get_vision_tower()
105
+
106
def encode_seg_images(self, seg_images):
    """Run `seg_images` through the vision tower, then the seg projector."""
    tower_output = self.get_model().get_vision_tower()(seg_images)
    return self.get_model().seg_mm_projector(tower_output)
110
+
111
def encode_depth_images(self, depth_images):
    """Run `depth_images` through the vision tower, then the depth projector.

    The original projected depth features with `seg_mm_projector`, leaving
    the separately built `depth_mm_projector` unused. The depth projector is
    now used whenever the model has one; the seg projector remains a fallback
    so models without a depth projector behave as before.
    """
    model = self.get_model()
    depth_features = model.get_vision_tower()(depth_images)

    projector = getattr(model, 'depth_mm_projector', None)
    if projector is None:
        projector = model.seg_mm_projector
    return projector(depth_features)
115
+
116
def encode_images(self, images):
    """Run `images` through the vision tower, then `mm_projector`."""
    tower_output = self.get_model().get_vision_tower()(images)
    return self.get_model().mm_projector(tower_output)
120
+
121
def encode_images_w_seg(self, images):
    """Run `images` through the vision tower, then `mm2_projector`."""
    tower_output = self.get_model().get_vision_tower()(images)
    return self.get_model().mm2_projector(tower_output)
125
+
126
+ def prepare_inputs_labels_for_multimodal(
127
+ self, input_ids, attention_mask, past_key_values, labels, images, seg_images, depth_images
128
+ ):
129
+ vision_tower = self.get_vision_tower()
130
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
131
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
132
+ attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
133
+ return input_ids, attention_mask, past_key_values, None, labels
134
+
135
+ if type(images) is list or images.ndim == 5:
136
+ concat_images = torch.cat([image for image in images], dim=0)
137
+ if seg_images is not None and hasattr(self, 'mm2_projector'):
138
+ image_features = self.encode_images_w_seg(concat_images)
139
+ else:
140
+ image_features = self.encode_images(concat_images)
141
+ split_sizes = [image.shape[0] for image in images]
142
+ image_features = torch.split(image_features, split_sizes, dim=0)
143
+ image_features = [x.flatten(0, 1) for x in image_features]
144
+ else:
145
+ if seg_images is not None and hasattr(self, 'mm2_projector'):
146
+ image_features = self.encode_images_w_seg(images)
147
+ else:
148
+ image_features = self.encode_images(images)
149
+
150
+ if seg_images is not None:
151
+ if type(seg_images) is list or seg_images.ndim == 5:
152
+ concat_seg_images = torch.cat([image for image in seg_images], dim=0)
153
+ seg_features = self.encode_seg_images(concat_seg_images)
154
+ split_sizes = [image.shape[0] for image in seg_images]
155
+ seg_features = torch.split(seg_features, split_sizes, dim=0)
156
+ seg_features = [x.flatten(0, 1) for x in seg_features]
157
+ else:
158
+ seg_features = self.encode_seg_images(seg_images)
159
+
160
+ if depth_images is not None:
161
+ try:
162
+ for p in self.get_model().depth_mm_projector.parameters():
163
+ p.requires_grad = True
164
+ if type(depth_images) is list or depth_images.ndim == 5:
165
+ concat_depth_images = torch.cat([image for image in depth_images], dim=0)
166
+ depth_features = self.encode_depth_images(concat_depth_images)
167
+ split_sizes = [image.shape[0] for image in depth_images]
168
+ depth_features = torch.split(depth_features, split_sizes, dim=0)
169
+ depth_features = [x.flatten(0, 1) for x in depth_features]
170
+ else:
171
+ depth_features = self.encode_depth_images(depth_images)
172
+ except:
173
+ depth_images = None
174
+ mask = input_ids != DEPTH_TOKEN_INDEX # drop depth indices
175
+ input_ids = input_ids[mask]
176
+ for p in self.get_model().depth_mm_projector.parameters():
177
+ p.requires_grad = False
178
+ else:
179
+ for p in self.get_model().depth_mm_projector.parameters():
180
+ p.requires_grad = False
181
+
182
+ self.get_model().vcoder_lm_emb.weight.data = self.get_model().get_input_embeddings().weight.data.clone()
183
+
184
+ new_input_embeds = []
185
+ new_labels = [] if labels is not None else None
186
+ cur_image_idx = 0
187
+ cur_seg_idx = 0
188
+ cur_depth_idx = 0
189
+ for batch_idx, cur_input_ids in enumerate(input_ids):
190
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0 and (cur_input_ids == SEG_TOKEN_INDEX).sum() == 0:
191
+ # FIXME: this is a hacky fix, for deepspeed zero3 to work
192
+ cur_image_features = image_features[cur_image_idx]
193
+ half_len = cur_input_ids.shape[0] // 2
194
+ if seg_images is not None:
195
+ cur_seg_features = seg_features[cur_seg_idx]
196
+ if depth_images is not None:
197
+ cur_depth_features = depth_features[cur_depth_idx]
198
+ cur_input_embeds_1 = self.get_model().vcoder_lm_emb(cur_input_ids[:half_len])
199
+ cur_input_embeds_2 = self.get_model().vcoder_lm_emb(cur_input_ids[half_len:])
200
+ else:
201
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
202
+ cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
203
+ if seg_images is not None:
204
+ if depth_images is not None:
205
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_depth_features[0:0], cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
206
+ else:
207
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
208
+ else:
209
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
210
+ new_input_embeds.append(cur_input_embeds)
211
+ if labels is not None:
212
+ new_labels.append(labels[batch_idx])
213
+ cur_image_idx += 1
214
+ cur_seg_idx += 1
215
+ cur_depth_idx += 1
216
+ continue
217
+
218
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
219
+
220
+ cur_new_input_embeds = []
221
+ if labels is not None:
222
+ cur_labels = labels[batch_idx]
223
+ cur_new_labels = []
224
+ assert cur_labels.shape == cur_input_ids.shape
225
+ while image_token_indices.numel() > 0:
226
+ cur_image_features = image_features[cur_image_idx]
227
+ image_token_start = image_token_indices[0]
228
+ if seg_images is None:
229
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
230
+ else:
231
+ cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:image_token_start]))
232
+ cur_new_input_embeds.append(cur_image_features)
233
+ if labels is not None:
234
+ cur_new_labels.append(cur_labels[:image_token_start])
235
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
236
+ cur_labels = cur_labels[image_token_start+1:]
237
+ cur_image_idx += 1
238
+ cur_input_ids = cur_input_ids[image_token_start+1:]
239
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
240
+
241
+ if seg_images is not None:
242
+ seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]
243
+ while seg_token_indices.numel() > 0:
244
+ cur_seg_features = seg_features[cur_seg_idx]
245
+ seg_token_start = seg_token_indices[0]
246
+ if depth_images is None:
247
+ cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:seg_token_start]))
248
+ cur_new_input_embeds.append(cur_seg_features)
249
+ if labels is not None:
250
+ if depth_images is None:
251
+ cur_new_labels.append(cur_labels[:seg_token_start])
252
+ cur_new_labels.append(torch.full((cur_seg_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
253
+ cur_labels = cur_labels[seg_token_start+1:]
254
+ cur_seg_idx += 1
255
+ cur_input_ids = cur_input_ids[seg_token_start+1:]
256
+ seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]
257
+
258
+ if depth_images is not None:
259
+ depth_token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
260
+ while depth_token_indices.numel() > 0:
261
+ cur_depth_features = depth_features[cur_depth_idx]
262
+ depth_token_start = depth_token_indices[0]
263
+ cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:depth_token_start]))
264
+ cur_new_input_embeds.append(cur_depth_features)
265
+ if labels is not None:
266
+ cur_new_labels.append(cur_labels[:depth_token_start])
267
+ cur_new_labels.append(torch.full((cur_depth_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
268
+ cur_labels = cur_labels[depth_token_start+1:]
269
+ cur_depth_idx += 1
270
+ cur_input_ids = cur_input_ids[depth_token_start+1:]
271
+ depth_token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
272
+
273
+ if cur_input_ids.numel() > 0:
274
+ if seg_images is None:
275
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
276
+ else:
277
+ cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids))
278
+ if labels is not None:
279
+ cur_new_labels.append(cur_labels)
280
+ cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
281
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
282
+ new_input_embeds.append(cur_new_input_embeds)
283
+ if labels is not None:
284
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
285
+ new_labels.append(cur_new_labels)
286
+
287
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
288
+ max_len = max(x.shape[0] for x in new_input_embeds)
289
+
290
+ new_input_embeds_align = []
291
+ for cur_new_embed in new_input_embeds:
292
+ cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
293
+ new_input_embeds_align.append(cur_new_embed)
294
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
295
+
296
+ if labels is not None:
297
+ new_labels_align = []
298
+ _new_labels = new_labels
299
+ for cur_new_label in new_labels:
300
+ cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
301
+ new_labels_align.append(cur_new_label)
302
+ new_labels = torch.stack(new_labels_align, dim=0)
303
+
304
+ if attention_mask is not None:
305
+ new_attention_mask = []
306
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
307
+ new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
308
+ new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
309
+ cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
310
+ new_attention_mask.append(cur_new_attention_mask)
311
+ attention_mask = torch.stack(new_attention_mask, dim=0)
312
+ assert attention_mask.shape == new_labels.shape
313
+ else:
314
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
315
+ if labels is not None:
316
+ new_labels = torch.stack(new_labels, dim=0)
317
+
318
+ if attention_mask is not None:
319
+ new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
320
+ attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
321
+ assert attention_mask.shape == new_input_embeds.shape[:2]
322
+
323
+ return None, attention_mask, past_key_values, new_input_embeds, new_labels
vcoder_llava/model/vcoder_llava_arch.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_projector.builder import build_vision_projector
23
+ from .multimodal_adapter.builder import build_seg_projector
24
+
25
+ from vcoder_llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, SEG_TOKEN_INDEX
26
+
27
class VCoderLlavaMetaModel:
    """Mixin that attaches the VCoder vision/segmentation modules to a base model.

    ``__init__`` forwards ``config`` to the next class in the MRO, so this
    class is meant to be combined with a base model class rather than
    instantiated on its own.
    """

    def __init__(self, config):
        """Build whichever multimodal sub-modules ``config`` declares.

        Each sub-module is only created when the config carries the
        corresponding attribute.
        """
        super(VCoderLlavaMetaModel, self).__init__(config)
        self.config = config

        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

        if hasattr(config, "seg_mm_projector_type"):
            self.seg_mm_projector = build_seg_projector(config)

        if hasattr(config, "use_mm2_proj"):
            if config.use_mm2_proj:
                self.mm2_projector = build_vision_projector(config)

        if hasattr(config, "mm_vcoder_lm_emb"):
            self.vcoder_lm_emb = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)

    def get_vision_tower(self):
        """Return the vision tower, or ``None`` when none has been attached.

        When the tower is stored wrapped in a list, the first element is
        returned.
        """
        vision_tower = getattr(self, 'vision_tower', None)
        # isinstance (rather than an exact type check) also unwraps list subclasses.
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_seg_modules(self, model_args, fsdp=None):
        """Create the segmentation projector and the VCoder token embedding.

        Args:
            model_args: object providing ``mm_seg_select_layer``,
                ``mm_seg_select_feature`` and ``pretrain_mm2_mlp_adapter``;
                ``seg_mm_projector_type`` (default ``'linear'``) and
                ``use_mm2_proj`` (default disabled) are optional.
            fsdp: accepted for interface compatibility; not used here.
        """
        mm_seg_select_layer = model_args.mm_seg_select_layer
        mm_seg_select_feature = model_args.mm_seg_select_feature

        self.config.seg_mm_hidden_size = self.vision_tower.hidden_size

        pretrain_mm2_mlp_adapter = model_args.pretrain_mm2_mlp_adapter

        self.config.seg_use_mm_proj = True
        self.config.seg_mm_projector_type = getattr(model_args, 'seg_mm_projector_type', 'linear')
        self.config.mm_seg_select_layer = mm_seg_select_layer
        self.config.mm_seg_select_feature = mm_seg_select_feature

        self.seg_mm_projector = build_seg_projector(self.config)
        self.vcoder_lm_emb = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.config.pad_token_id)

        # A missing flag is treated as "disabled" instead of raising AttributeError.
        if getattr(model_args, "use_mm2_proj", False):
            self.config.use_mm2_proj = model_args.use_mm2_proj
            self.mm2_projector = build_vision_projector(self.config)

            if pretrain_mm2_mlp_adapter is not None:
                # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
                mm2_projector_weights = torch.load(pretrain_mm2_mlp_adapter, map_location='cpu')

                def get_w(weights, keyword):
                    # Filter on the same "<keyword>." prefix that is used for
                    # splitting, so a key that merely contains the keyword
                    # cannot trigger an IndexError.
                    prefix = keyword + '.'
                    return {k.split(prefix)[1]: v for k, v in weights.items() if prefix in k}

                self.mm2_projector.load_state_dict(get_w(mm2_projector_weights, 'mm_projector'))
79
+
80
class VCoderLlavaMetaForCausalLM(ABC):
    """Abstract mixin that turns token ids plus images/segmentation maps into input embeddings.

    Subclasses provide ``get_model()``; the object it returns is expected to
    expose the vision tower, the projectors and the embedding tables used
    below.
    """

    @abstractmethod
    def get_model(self):
        """Return the underlying model that owns the multimodal sub-modules."""
        pass

    def get_vision_tower(self):
        """Return the vision tower of the underlying model."""
        return self.get_model().get_vision_tower()

    def encode_seg_images(self, seg_images):
        """Run ``seg_images`` through the vision tower and ``seg_mm_projector``."""
        seg_features = self.get_model().get_vision_tower()(seg_images)
        seg_features = self.get_model().seg_mm_projector(seg_features)
        return seg_features

    def encode_images(self, images):
        """Run ``images`` through the vision tower and ``mm_projector``."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def encode_images_w_seg(self, images):
        """Run ``images`` through the vision tower and ``mm2_projector``."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm2_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images, seg_images,
    ):
        """Replace image/seg placeholder tokens with projected visual features.

        Returns a 5-tuple ``(input_ids, attention_mask, past_key_values,
        inputs_embeds, labels)``. On the early-return path ``input_ids`` is
        passed through and ``inputs_embeds`` is ``None``; otherwise
        ``input_ids`` is ``None`` and ``inputs_embeds`` holds the spliced
        embeddings, with ``labels`` and ``attention_mask`` extended to match.
        """
        vision_tower = self.get_vision_tower()
        # Nothing to splice: no tower, no images, or a single-token step.
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                # Single-token step with a cache: rebuild an all-ones mask whose
                # length is the cached length plus the new token.
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        # NOTE(review): hasattr(self, 'mm2_projector') inspects this wrapper,
        # while encode_images_w_seg reads mm2_projector from get_model();
        # confirm the check resolves as intended.
        if type(images) is list or images.ndim == 5:
            # List / 5-D input: encode everything in one pass, then split the
            # features back per entry and merge the first two dims of each.
            concat_images = torch.cat([image for image in images], dim=0)
            if seg_images is not None and hasattr(self, 'mm2_projector'):
                image_features = self.encode_images_w_seg(concat_images)
            else:
                image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            if seg_images is not None and hasattr(self, 'mm2_projector'):
                image_features = self.encode_images_w_seg(images)
            else:
                image_features = self.encode_images(images)

        if seg_images is not None:
            # Same list / 5-D handling as for the images above.
            if type(seg_images) is list or seg_images.ndim == 5:
                concat_seg_images = torch.cat([image for image in seg_images], dim=0)
                seg_features = self.encode_seg_images(concat_seg_images)
                split_sizes = [image.shape[0] for image in seg_images]
                seg_features = torch.split(seg_features, split_sizes, dim=0)
                seg_features = [x.flatten(0, 1) for x in seg_features]
            else:
                seg_features = self.encode_seg_images(seg_images)

        # Overwrite the VCoder embedding table with a copy of the model's
        # input-embedding weights on every call.
        # NOTE(review): the nesting of this statement could not be recovered
        # from the flattened source; it is placed at function level here.
        self.get_model().vcoder_lm_emb.weight.data = self.get_model().get_input_embeddings().weight.data.clone()

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        cur_seg_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            # NOTE(review): because of the `or`, this branch is also taken when
            # only one placeholder kind is missing; placeholders of the other
            # kind are then fed to the embedding lookup as ordinary ids.
            # Confirm callers always supply both kinds or neither.
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0 or (cur_input_ids == SEG_TOKEN_INDEX).sum() == 0:
                # FIXME: this is a hacky fix, for deepspeed zero3 to work
                # The zero-length slices ([0:0]) add no rows to the result;
                # presumably they only keep the feature tensors referenced.
                cur_image_features = image_features[cur_image_idx]
                if seg_images is not None:
                    cur_seg_features = seg_features[cur_seg_idx]
                half_len = cur_input_ids.shape[0] // 2
                if seg_images is not None:
                    cur_input_embeds_1 = self.get_model().vcoder_lm_emb(cur_input_ids[:half_len])
                    cur_input_embeds_2 = self.get_model().vcoder_lm_emb(cur_input_ids[half_len:])
                    cur_input_embeds = torch.cat([cur_input_embeds_1, cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
                else:
                    cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
                    cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
                    cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                cur_seg_idx += 1
                continue

            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]

            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            # Consume the sequence left to right: embed the text before each
            # image placeholder, append the image features, drop the
            # placeholder and continue with the remainder.
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                # Text is embedded with vcoder_lm_emb whenever seg images are given.
                if seg_images is None:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                else:
                    cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)
                if labels is not None:
                    cur_new_labels.append(cur_labels[:image_token_start])
                    # Feature rows get IGNORE_INDEX as their label.
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                    cur_labels = cur_labels[image_token_start+1:]
                cur_image_idx += 1
                cur_input_ids = cur_input_ids[image_token_start+1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]

            if seg_images is not None:
                # Same splice for seg placeholders, applied to what is left
                # after the last image placeholder.
                seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]
                while seg_token_indices.numel() > 0:
                    cur_seg_features = seg_features[cur_seg_idx]
                    seg_token_start = seg_token_indices[0]
                    cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:seg_token_start]))
                    cur_new_input_embeds.append(cur_seg_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:seg_token_start])
                        cur_new_labels.append(torch.full((cur_seg_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_labels = cur_labels[seg_token_start+1:]
                    cur_seg_idx += 1
                    cur_input_ids = cur_input_ids[seg_token_start+1:]
                    seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]

            # Trailing text after the last placeholder.
            if cur_input_ids.numel() > 0:
                if seg_images is None:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                else:
                    cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids))
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            # Ragged batch: right-pad every sample to the longest one.
            max_len = max(x.shape[0] for x in new_input_embeds)

            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            if labels is not None:
                new_labels_align = []
                # Keep the unpadded labels; their lengths are needed for the mask below.
                _new_labels = new_labels
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            if attention_mask is not None:
                # NOTE(review): _new_labels is only assigned when labels is not
                # None; confirm this branch is never reached without labels.
                new_attention_mask = []
                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
                    # Left pad (True) covers the rows added by splicing; right
                    # pad (False) covers the alignment padding.
                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_labels.shape
        else:
            # All samples already have the same length: stack directly.
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                # Extend the mask on the left by the number of rows added per sample.
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels
vcoder_llava/questions.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# Paraphrased prompts asking for the objects in an image from a semantic
# segmentation point of view; exposed through QUESTIONS['semantic'].
SEMANTIC_QUESTIONS = [
    "What objects can be seen in the image? Perceive as done for semantic segmentation.",
    "What items are depicted in the picture? Consider in terms of semantic segmentation.",
    "Which elements are present in the visual? Analyze as you would for semantic segmentation.",
    "Can you identify the objects in the image? Think from a semantic segmentation perspective.",
    "What are the components visible in the graphic? Examine as if segmenting semantically.",
    "Which entities can be spotted in the photo? View through the lens of semantic segmentation.",
    "What are the discernible objects in the snapshot? Envision in relation to semantic segmentation.",
    "What elements stand out in the illustration? Reflect upon it as for semantic segmentation.",
    "Can you spot any items within the visual representation? Contemplate in a semantic segmentation context.",
    "What features are evident in this visual content? Analyze with semantic segmentation in mind.",
    "Which objects are noticeable in the image? Think of it in terms of semantic layers.",
    "How would you categorize the objects in this picture? As if you're doing semantic segmentation.",
    "What constituents can you recognize in the image? Ponder considering semantic segmentation.",
    "Which components can be distinguished in the photo? Evaluate as per semantic segmentation guidelines.",
    "What items in the image can you point out? Interpret with a semantic segmentation approach.",
    "Can you enumerate the objects present in this visual? Think semantically.",
    "What do you observe in the graphic? Consider its semantic segments.",
    "How many distinct objects can you identify in the visual? Keeping semantic segmentation in perspective.",
    "Which items are apparent in this depiction? Assess as one would for semantic segmentation.",
    "What are the visible entities within this image? Delve into it semantically.",
    "Can you discern specific objects in the portrayal? Approach it from a semantic segmentation standpoint.",
]
25
+
26
# Paraphrased prompts asking for the objects in an image from an instance
# segmentation point of view; exposed through QUESTIONS['instance'].
INSTANCE_QUESTIONS = [
    "What objects can be seen in the image? Perceive as done for instance segmentation",
    "What items are visible in the picture? Analyze as you would for instance segmentation.",
    "Which elements are present in the visual? Consider from an instance segmentation perspective.",
    "What are the distinguishable objects in the image? Think in terms of instance segmentation.",
    "Can you identify the entities in the graphic? Approach it with instance segmentation in mind.",
    "What components are apparent in the photo? Examine as if performing instance segmentation.",
    "Which items can be detected in the snapshot? View it through the lens of instance segmentation.",
    "What features stand out in the illustration? Reflect upon it as for instance segmentation.",
    "How would you describe the objects in this image? Keeping instance segmentation as a reference.",
    "What constituents are evident in the visual content? Think from an instance segmentation standpoint.",
    "Which objects can you spot in the depiction? Evaluate as per instance segmentation guidelines.",
    "What do you observe in the graphic? Contemplate with instance segmentation considerations.",
    "Can you discern specific entities in the visual? Approach it in the context of instance segmentation.",
    "Which components in the image catch your eye? Think of it in relation to instance layers.",
    "How many distinct items can you pinpoint in the photo? With an instance segmentation approach.",
    "What elements are noticeable in this portrayal? Analyze while considering instance segmentation.",
    "Can you list the objects present in the visual representation? Reflecting on instance segmentation.",
    "What items in the snapshot can you recognize? Interpret with an instance segmentation perspective.",
    "Which entities are discernible in this depiction? Delve into it from an instance segmentation angle.",
    "What are the components you can spot within the image? Think instance-wise.",
    "Can you detail the objects in the visual? Assess as one would for instance segmentation.",
]
49
+
50
# Paraphrased prompts asking for the objects in an image from a panoptic
# segmentation point of view; exposed through QUESTIONS['panoptic'].
PANOPTIC_QUESTIONS = [
    "What objects can be seen in the image? Perceive as done for panoptic segmentation",
    "What items are evident in the picture? Analyze with a panoptic segmentation perspective.",
    "Which elements emerge in the visual? Think in terms of panoptic segmentation.",
    "What are the discernible objects in the graphic? Approach it from a panoptic segmentation viewpoint.",
    "Can you identify the entities within the image? Consider it as you would for panoptic segmentation.",
    "What components stand out in the photo? Examine with panoptic segmentation in mind.",
    "Which items are detectable in the snapshot? Reflect upon it with panoptic segmentation considerations.",
    "What features can be observed in the illustration? View through the lens of panoptic segmentation.",
    "How would you describe the objects in this depiction? Keeping panoptic segmentation as a reference.",
    "What constituents are visible in the visual content? Think from a panoptic segmentation standpoint.",
    "Which objects can you pinpoint in the image? Evaluate as per panoptic segmentation guidelines.",
    "What do you perceive in the graphic? Delve into it with panoptic segmentation insights.",
    "Can you spot specific components in the visual? Contextualize with panoptic segmentation.",
    "What items in the portrayal catch your attention? Think in relation to panoptic layers.",
    "How many distinct entities can you recognize in the photo? With a panoptic segmentation approach.",
    "What elements are present in this visual? Analyze while keeping panoptic segmentation in mind.",
    "Can you list the objects depicted in the visual representation? Reflecting on panoptic segmentation.",
    "Which features in the image can you discern? Interpret considering panoptic segmentation.",
    "What are the components evident in this depiction? Approach it using a panoptic segmentation angle.",
    "What items can you detect in the visual content? Think panoptically.",
    "Can you detail the entities present in the image? Assess as one would when considering panoptic segmentation.",
]
73
+
74
# Paraphrased prompts asking for the depth ordering of the objects in an
# image; exposed through QUESTIONS['depth'].
DEPTH_QUESTIONS = [
    "what is depth order of objects in the image?",
    "Can you describe the depth order of the objects in this image, from closest to farthest?",
    "Which objects in the image appear nearest to the viewer and which seem furthest away?",
    "Could you list the objects in the image in order of their perceived distance from the foreground to the background?",
    "In what order do the objects in this image appear based on their depth, starting from the closest?",
    "How would you rank the objects in this picture from the most proximal to the most distal?",
    "Can you arrange the objects seen here from those appearing closest to those appearing farthest?",
    "What is the sequence of objects in this image based on their distance from the front to the back?",
    "Please identify the order of objects in terms of depth perspective in this image.",
    "Which objects in the picture seem to be in the front, and which ones appear to be in the back?",
    "How are the objects in this image layered in depth, from the one nearest to the camera to the one farthest?",
    "Could you sort the objects in this photo from foreground to background?",
    "In this image, what is the spatial arrangement of objects from closest to furthest?",
    "Can you pinpoint the depth hierarchy of these objects, starting from the closest?",
    "What's the depth sequence of the objects displayed in this picture?",
    "From nearest to furthest, how would you order the objects in this image?",
    "How would you describe the spatial positioning of these objects in terms of their depth?",
    "Can you determine the depth placement of each object in this photo, starting with the nearest?",
    "What is the arrangement of objects in this scene by depth?",
    "Could you outline the depth profile of the objects in this image?",
    "In what depth order do the objects in this image align, from the frontmost to the rearmost?",
    "How are the objects in this image ordered in terms of their relative distance from the observer?",
]
98
+
99
# Lookup from task name to the list of prompt paraphrases for that task.
QUESTIONS = {
    'semantic': SEMANTIC_QUESTIONS,
    'instance': INSTANCE_QUESTIONS,
    'panoptic': PANOPTIC_QUESTIONS,
    'depth': DEPTH_QUESTIONS,
}
105
+
106
+ ### Depth Prompts
107
+ # Can you describe the depth order of the objects in this image, from closest to farthest? Return answer in the paragraph format: `The depth order for the objects present in the image is: ...' and then list the objects with their order number (if greater than 1) separated by a hyphen like `person-2'. For example, an acceptable response is "The depth order for objects present in the image is: bicycle, bicycle-2, bicycle-3, pavement, road, bus, tree, sky, building."
108
+
109
+ ### Seg Prompts
110
+ # What objects can be seen in the image? Return the answer in the paragraph format: `The objects present in the image are: ...' and then list the objects with their count in word format (if greater than 1) in front of them, like `two people'.
vcoder_llava/utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ import requests
8
+
9
+ from vcoder_llava.constants import LOGDIR
10
+
11
# Canned user-facing messages (presumably displayed by the serving/demo code; usage is not visible here).
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

# Module-wide file handler; created by the first build_logger() call and reused afterwards.
handler = None
15
+
16
+
17
def build_logger(logger_name, logger_filename):
    """Configure process-wide logging and return the logger ``logger_name``.

    Side effects:
      * the first root handler gets the shared formatter (a default root
        handler is created first if there is none);
      * ``sys.stdout`` and ``sys.stderr`` are replaced by ``StreamToLogger``
        objects writing to the "stdout" (INFO) and "stderr" (ERROR) loggers;
      * on the first call a daily-rotating file handler for
        ``LOGDIR/logger_filename`` is created and attached to every logger
        registered at that moment.

    Args:
        logger_name: name of the logger to return.
        logger_filename: file name (inside ``LOGDIR``) of the shared log file;
            only used by the call that creates the file handler.

    Returns:
        The ``logging.Logger`` named ``logger_name``, set to INFO level.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    root_logger = logging.getLogger()
    if not root_logger.handlers:
        logging.basicConfig(level=logging.INFO)
    root_logger.handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sys.stdout = StreamToLogger(stdout_logger, logging.INFO)

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sys.stderr = StreamToLogger(stderr_logger, logging.ERROR)

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        # Explicit UTF-8 so non-ASCII output does not depend on the locale's
        # default encoding.
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True, encoding='utf-8')
        handler.setFormatter(formatter)

        # loggerDict may also hold placeholder entries; only real loggers
        # accept handlers.
        for item in logging.root.manager.loggerDict.values():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger
58
+
59
+
60
class StreamToLogger(object):
    """
    File-like adapter that forwards every completed line to a logger.

    Text without a trailing newline is buffered until the rest of the line
    arrives or ``flush()`` is called. Attributes not defined here are looked
    up on the stream that was ``sys.stdout`` when the adapter was created.
    """
    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        # Only reached for names missing on this object: delegate to the
        # captured stream.
        return getattr(self.terminal, attr)

    def write(self, buf):
        """Log each newline-terminated line of ``buf``; keep the remainder buffered."""
        pending = self.linebuf + buf
        self.linebuf = ''
        # Callers write '\n' as the line terminator (text streams translate it
        # to the platform separator on output), so checking for '\n' alone is
        # portable.
        for piece in pending.splitlines(True):
            if not piece.endswith('\n'):
                self.linebuf += piece
                continue
            self.logger.log(self.log_level, piece.rstrip())

    def flush(self):
        """Emit any buffered partial line and clear the buffer."""
        if self.linebuf == '':
            return
        self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''
91
+
92
+
93
def disable_torch_init():
    """
    Turn ``reset_parameters`` of ``torch.nn.Linear`` and ``torch.nn.LayerNorm``
    into a no-op, so constructing these layers skips the default weight
    initialization and model creation is faster.
    """
    import torch

    def _skip_reset(self):
        return None

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", _skip_reset)
100
+
101
+
102
def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Newlines are stripped from ``text`` before it is sent. The check is
    best-effort: a network failure or an unexpected response yields ``False``.

    Raises:
        KeyError: if the ``OPENAI_API_KEY`` environment variable is not set.
    """
    import json

    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    # Serialize with json.dumps so quotes and backslashes in the text are
    # escaped instead of corrupting a hand-built JSON string.
    data = json.dumps({"input": text}, ensure_ascii=False).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except (KeyError, IndexError, TypeError, ValueError):
        # Response was not JSON or lacked the expected structure.
        flagged = False

    return flagged
121
+
122
+
123
def pretty_print_semaphore(semaphore):
    """Return a short description of ``semaphore``, or ``"None"`` if it is None.

    The description shows the semaphore's current ``_value`` and the result
    of ``locked()``.
    """
    return (
        "None"
        if semaphore is None
        else f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
    )
vcoder_llava/vcoder_conversation.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()   # one separator after the system text and after every turn
    TWO = auto()      # alternates between `sep` and `sep2` from turn to turn
    MPT = auto()      # like SINGLE, but the role is not followed by ": "
    PLAIN = auto()    # message text only, no role names
    LLAMA_2 = auto()  # [INST] ... [/INST] wrapping with a <<SYS>> block


@dataclasses.dataclass
class VCoderConversation:
    """A class that keeps all conversation history.

    ``messages`` is a sequence of ``[role, message]`` pairs. A message is
    either plain text, ``None`` (a turn that has not been filled in yet) or
    a 7-item tuple::

        (text, image, image_process_mode,
         seg, seg_process_mode,
         depth, depth_process_mode)

    where ``image``, ``seg`` and ``depth`` are picture objects or ``None``.
    The picture objects are used through ``size``, ``mode``, ``resize``,
    ``convert`` and ``save`` (presumably ``PIL.Image.Image`` -- PIL is
    imported for padding).

    Methods that look at pictures skip the first ``offset`` messages and
    treat every even position after that as a user turn.
    """
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None  # only needed by the TWO, PLAIN and LLAMA_2 styles
    version: str = "Unknown"

    skip_next: bool = False

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _text_of(message):
        """Return the text of a message (first item of a multimodal tuple)."""
        return message[0] if isinstance(message, tuple) else message

    @staticmethod
    def _display_size(size, max_len=800, min_len=400):
        """Return the ``(width, height)`` a picture of ``size`` is scaled to.

        The short side is capped at ``min_len`` and the long side at about
        ``max_len``; aspect ratio and orientation are kept and the picture
        is never enlarged.
        """
        max_hw, min_hw = max(size), min(size)
        aspect_ratio = max_hw / min_hw
        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
        longest_edge = int(shortest_edge * aspect_ratio)
        width, height = size
        if height > width:
            return shortest_edge, longest_edge
        return longest_edge, shortest_edge

    @staticmethod
    def _expand_to_square(pil_img, background_color=(122, 116, 104)):
        """Centre ``pil_img`` on a square canvas filled with ``background_color``."""
        # Imported lazily so that text-only conversations never need PIL.
        from PIL import Image

        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    @staticmethod
    def _encode_base64(picture, image_format):
        """Save ``picture`` in ``image_format`` and return it as base64 text."""
        import base64
        from io import BytesIO

        buffered = BytesIO()
        picture.save(buffered, format=image_format)
        return base64.b64encode(buffered.getvalue()).decode()

    def _prepare_picture(self, picture, process_mode):
        """Apply ``process_mode`` to ``picture`` and shrink it if it is too large.

        Raises:
            ValueError: If ``process_mode`` is not one of "Pad", "Default",
                "Crop" or "Resize".
        """
        if process_mode == "Pad":
            picture = self._expand_to_square(picture)
        elif process_mode == "Resize":
            picture = picture.resize((336, 336))
        elif process_mode not in ("Default", "Crop"):
            raise ValueError(f"Invalid image_process_mode: {process_mode}")

        target = self._display_size(picture.size)
        # Only resample when the computed size differs from the current one.
        if max(target) != max(picture.size):
            picture = picture.resize(target)
        return picture

    def _collect_pictures(self, picture_index, return_pil):
        """Gather the processed pictures stored at ``picture_index`` of user turns.

        ``picture_index`` is the position of the picture inside the 7-item
        message tuple; its process mode is stored right after it.
        """
        collected = []
        for i, (_role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 != 0 or not isinstance(msg, tuple):
                continue
            picture, process_mode = msg[picture_index], msg[picture_index + 1]
            if picture is None:
                continue
            picture = self._prepare_picture(picture, process_mode)
            if return_pil:
                collected.append(picture)
            else:
                collected.append(self._encode_base64(picture, "PNG"))
        return collected

    @classmethod
    def _jpeg_img_tag(cls, picture, alt):
        """Return an HTML ``<img>`` tag embedding a JPEG preview of ``picture``."""
        # Sized from the picture's own dimensions, so a seg or depth map can
        # be rendered even when the turn has no image attached.
        picture = picture.resize(cls._display_size(picture.size))
        try:
            encoded = cls._encode_base64(picture, "JPEG")
        except OSError:
            # JPEG cannot store every mode (e.g. RGBA); retry as RGB.
            encoded = cls._encode_base64(picture.convert("RGB"), "JPEG")
        return f'<img src="data:image/jpeg;base64,{encoded}" alt="{alt}" />'

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def get_prompt(self):
        """Render the system text and all messages as one prompt string.

        Raises:
            ValueError: If ``sep_style`` is not a known ``SeparatorStyle``.
        """
        messages = self.messages
        style = self.sep_style
        text_of = self._text_of

        if style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    ret += role + ": " + text_of(message) + self.sep
                else:
                    ret += role + ":"
        elif style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    ret += role + ": " + text_of(message) + seps[i % 2]
                else:
                    ret += role + ":"
        elif style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    ret += role + text_of(message) + self.sep
                else:
                    ret += role
        elif style == SeparatorStyle.LLAMA_2:
            ret = ""
            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if not message:
                    continue
                text = text_of(message)
                if i == 0:
                    # The system text travels inside the first user turn.
                    text = f"<<SYS>>\n{self.system}\n<</SYS>>\n\n" + text
                if i % 2 == 0:
                    ret += self.sep + f"[INST] {text} [/INST]"
                else:
                    ret += " " + text + " " + self.sep2
            # Drop the one leading separator. str.lstrip() would treat `sep`
            # as a set of characters instead of a prefix.
            if ret.startswith(self.sep):
                ret = ret[len(self.sep):]
        elif style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (_role, message) in enumerate(messages):
                if message:
                    ret += text_of(message) + seps[i % 2]
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Add a ``[role, message]`` pair to the end of the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Return the processed images of all user turns.

        Items are picture objects when ``return_pil`` is true, otherwise
        base64-encoded PNG strings.
        """
        return self._collect_pictures(1, return_pil)

    def get_segs(self, return_pil=False):
        """Return the processed seg pictures of all user turns.

        Items are picture objects when ``return_pil`` is true, otherwise
        base64-encoded PNG strings.
        """
        return self._collect_pictures(3, return_pil)

    def get_depths(self, return_pil=False):
        """Return the processed depth pictures of all user turns.

        Items are picture objects when ``return_pil`` is true, otherwise
        base64-encoded PNG strings.
        """
        return self._collect_pictures(5, return_pil)

    def to_gradio_chatbot(self):
        """Return the history as a list of ``[user_message, reply]`` pairs.

        Pictures attached to a user turn are embedded in the user message as
        inline JPEG ``<img>`` tags and the ``<image>``, ``<seg>`` and
        ``<depth>`` placeholders are removed from its text. ``reply`` is
        ``None`` until the following turn is seen.
        """
        ret = []
        for i, (_role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 != 0:
                # Odd position: the reply to the user turn added just before.
                ret[-1][-1] = msg
                continue
            if isinstance(msg, tuple):
                text, image, _, seg, _, depth, _ = msg
                attachments = (
                    (image, '<image>', "user upload image"),
                    (seg, '<seg>', "user upload seg"),
                    (depth, '<depth>', "user upload depth"),
                )
                for picture, placeholder, alt in attachments:
                    if picture is not None:
                        tag = self._jpeg_img_tag(picture, alt)
                        text = tag + text.replace(placeholder, '').strip()
                msg = text
            ret.append([msg, None])
        return ret

    def copy(self):
        """Return a new conversation with its own list of message pairs.

        ``skip_next`` is not carried over; the copy starts with the default.
        """
        return VCoderConversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """Return the conversation as a plain dict.

        Multimodal messages are reduced to their text, so the result never
        contains picture objects.
        """
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": [[role, self._text_of(msg)] for role, msg in self.messages],
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
338
+
339
+
340
# Vicuna v1 template: "USER"/"ASSISTANT" roles; turns are separated by a
# space after the user and "</s>" after the assistant.
conv_vicuna_v1 = VCoderConversation(
    version="v1",
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# LLaVA v1 template: same layout as the Vicuna one; only the wording of the
# system text differs ("human" instead of "user").
conv_llava_v1 = VCoderConversation(
    version="v1",
    system=(
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions."
    ),
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)


# Template used when none is requested by name.
default_conversation = conv_vicuna_v1

# Lookup table from template name to template object.
conv_templates = {
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llava_v1": conv_llava_v1,
}


if __name__ == "__main__":
    # Quick manual check: show the prompt produced by the default template.
    print(default_conversation.get_prompt())