Spaces:
Runtime error
Runtime error
praeclarumjj3
commited on
Commit
·
d3cee44
1
Parent(s):
0bd6903
:zap: Build space
Browse files- .DS_Store +0 -0
- .gitattributes +3 -0
- LICENSE +201 -0
- README.md +3 -3
- app.py +389 -0
- chat.py +205 -0
- examples/3.jpg +3 -0
- examples/3_ins.png +3 -0
- examples/3_pan.png +3 -0
- requirements.txt +30 -0
- vcoder_llava/.DS_Store +0 -0
- vcoder_llava/__init__.py +1 -0
- vcoder_llava/constants.py +12 -0
- vcoder_llava/data_utils.py +157 -0
- vcoder_llava/mm_utils.py +151 -0
- vcoder_llava/model/.DS_Store +0 -0
- vcoder_llava/model/__init__.py +3 -0
- vcoder_llava/model/apply_delta.py +48 -0
- vcoder_llava/model/builder.py +152 -0
- vcoder_llava/model/consolidate.py +29 -0
- vcoder_llava/model/language_model/llava_llama.py +165 -0
- vcoder_llava/model/language_model/vcoder_ds_llava_llama.py +145 -0
- vcoder_llava/model/language_model/vcoder_llava_llama.py +142 -0
- vcoder_llava/model/llava_arch.py +200 -0
- vcoder_llava/model/make_delta.py +52 -0
- vcoder_llava/model/multimodal_adapter/builder.py +49 -0
- vcoder_llava/model/multimodal_depth_adapter/builder.py +50 -0
- vcoder_llava/model/multimodal_encoder/builder.py +11 -0
- vcoder_llava/model/multimodal_encoder/clip_encoder.py +78 -0
- vcoder_llava/model/multimodal_projector/builder.py +51 -0
- vcoder_llava/model/utils.py +20 -0
- vcoder_llava/model/vcd/vcd_add_noise.py +28 -0
- vcoder_llava/model/vcd/vcd_sample.py +250 -0
- vcoder_llava/model/vcoder_ds_llava_arch.py +323 -0
- vcoder_llava/model/vcoder_llava_arch.py +254 -0
- vcoder_llava/questions.py +110 -0
- vcoder_llava/utils.py +126 -0
- vcoder_llava/vcoder_conversation.py +374 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
title: VCoder
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.8.0
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
title: VCoder
|
3 |
+
emoji: ✌️
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: orange
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.8.0
|
8 |
app_file: app.py
|
app.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import requests
|
9 |
+
import hashlib
|
10 |
+
|
11 |
+
from vcoder_llava.vcoder_conversation import (default_conversation, conv_templates,
|
12 |
+
SeparatorStyle)
|
13 |
+
from vcoder_llava.constants import LOGDIR
|
14 |
+
from vcoder_llava.utils import (build_logger, server_error_msg,
|
15 |
+
violates_moderation, moderation_msg)
|
16 |
+
from .chat import Chat
|
17 |
+
|
18 |
+
|
19 |
+
logger = build_logger("gradio_app", "gradio_web_server.log")
|
20 |
+
|
21 |
+
headers = {"User-Agent": "VCoder Client"}
|
22 |
+
|
23 |
+
no_change_btn = gr.Button.update()
|
24 |
+
enable_btn = gr.Button.update(interactive=True)
|
25 |
+
disable_btn = gr.Button.update(interactive=False)
|
26 |
+
|
27 |
+
priority = {
|
28 |
+
"vicuna-13b": "aaaaaaa",
|
29 |
+
"koala-13b": "aaaaaab",
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
def get_conv_log_filename():
|
34 |
+
t = datetime.datetime.now()
|
35 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
36 |
+
return name
|
37 |
+
|
38 |
+
|
39 |
+
get_window_url_params = """
|
40 |
+
function() {
|
41 |
+
const params = new URLSearchParams(window.location.search);
|
42 |
+
url_params = Object.fromEntries(params);
|
43 |
+
console.log(url_params);
|
44 |
+
return url_params;
|
45 |
+
}
|
46 |
+
"""
|
47 |
+
|
48 |
+
|
49 |
+
def load_demo_refresh_model_list(request: gr.Request):
|
50 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
51 |
+
state = default_conversation.copy()
|
52 |
+
dropdown_update = gr.Dropdown.update(
|
53 |
+
choices=models,
|
54 |
+
value=models[0] if len(models) > 0 else ""
|
55 |
+
)
|
56 |
+
return state, dropdown_update
|
57 |
+
|
58 |
+
|
59 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
60 |
+
with open(get_conv_log_filename(), "a") as fout:
|
61 |
+
data = {
|
62 |
+
"tstamp": round(time.time(), 4),
|
63 |
+
"type": vote_type,
|
64 |
+
"model": model_selector,
|
65 |
+
"state": state.dict(),
|
66 |
+
}
|
67 |
+
fout.write(json.dumps(data) + "\n")
|
68 |
+
|
69 |
+
|
70 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
71 |
+
vote_last_response(state, "upvote", model_selector, request)
|
72 |
+
return ("",) + (disable_btn,) * 3
|
73 |
+
|
74 |
+
|
75 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
76 |
+
vote_last_response(state, "downvote", model_selector, request)
|
77 |
+
return ("",) + (disable_btn,) * 3
|
78 |
+
|
79 |
+
|
80 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
81 |
+
vote_last_response(state, "flag", model_selector, request)
|
82 |
+
return ("",) + (disable_btn,) * 3
|
83 |
+
|
84 |
+
def regenerate(state, image_process_mode, seg_process_mode):
|
85 |
+
state.messages[-1][-1] = None
|
86 |
+
prev_human_msg = state.messages[-2]
|
87 |
+
if type(prev_human_msg[1]) in (tuple, list):
|
88 |
+
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode, prev_human_msg[1][3], seg_process_mode, None, None)
|
89 |
+
state.skip_next = False
|
90 |
+
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
|
91 |
+
|
92 |
+
|
93 |
+
def clear_history(request: gr.Request):
|
94 |
+
state = default_conversation.copy()
|
95 |
+
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
|
96 |
+
|
97 |
+
|
98 |
+
def add_text(state, text, image, image_process_mode, seg, seg_process_mode, depth, depth_process_mode, request: gr.Request):
|
99 |
+
logger.info(f"add_text. len: {len(text)}")
|
100 |
+
if len(text) <= 0 and image is None:
|
101 |
+
state.skip_next = True
|
102 |
+
return (state, state.to_gradio_chatbot(), "", None, None) + (no_change_btn,) * 5
|
103 |
+
if args.moderate:
|
104 |
+
flagged = violates_moderation(text)
|
105 |
+
if flagged:
|
106 |
+
state.skip_next = True
|
107 |
+
return (state, state.to_gradio_chatbot(), moderation_msg, None, None) + (
|
108 |
+
no_change_btn,) * 5
|
109 |
+
|
110 |
+
text = text[:1576] # Hard cut-off
|
111 |
+
if image is not None:
|
112 |
+
text = text[:1200] # Hard cut-off for images
|
113 |
+
if '<image>' not in text:
|
114 |
+
text = '<image>\n' + text
|
115 |
+
if seg is not None:
|
116 |
+
if '<seg>' not in text:
|
117 |
+
text = '<seg>\n' + text
|
118 |
+
|
119 |
+
text = (text, image, image_process_mode, seg, seg_process_mode, None, None)
|
120 |
+
if len(state.get_images(return_pil=True)) > 0:
|
121 |
+
state = default_conversation.copy()
|
122 |
+
state.append_message(state.roles[0], text)
|
123 |
+
state.append_message(state.roles[1], None)
|
124 |
+
state.skip_next = False
|
125 |
+
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
|
126 |
+
|
127 |
+
|
128 |
+
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
|
129 |
+
start_tstamp = time.time()
|
130 |
+
model_name = model_selector
|
131 |
+
|
132 |
+
if state.skip_next:
|
133 |
+
# This generate call is skipped due to invalid inputs
|
134 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
135 |
+
return
|
136 |
+
|
137 |
+
if len(state.messages) == state.offset + 2:
|
138 |
+
# First round of conversation
|
139 |
+
if "llava" in model_name.lower():
|
140 |
+
template_name = "llava_v1"
|
141 |
+
new_state = conv_templates[template_name].copy()
|
142 |
+
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
143 |
+
new_state.append_message(new_state.roles[1], None)
|
144 |
+
state = new_state
|
145 |
+
|
146 |
+
# Construct prompt
|
147 |
+
prompt = state.get_prompt()
|
148 |
+
|
149 |
+
all_images = state.get_images(return_pil=True)
|
150 |
+
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
|
151 |
+
for image, hash in zip(all_images, all_image_hash):
|
152 |
+
t = datetime.datetime.now()
|
153 |
+
filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
|
154 |
+
if not os.path.isfile(filename):
|
155 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
156 |
+
image.save(filename)
|
157 |
+
|
158 |
+
all_segs = state.get_segs(return_pil=True)
|
159 |
+
all_seg_hash = [hashlib.md5(seg.tobytes()).hexdigest() for seg in all_segs]
|
160 |
+
for seg, hash in zip(all_segs, all_seg_hash):
|
161 |
+
t = datetime.datetime.now()
|
162 |
+
filename = os.path.join(LOGDIR, "serve_segs", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
|
163 |
+
if not os.path.isfile(filename):
|
164 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
165 |
+
seg.save(filename)
|
166 |
+
|
167 |
+
# Make requests
|
168 |
+
pload = {
|
169 |
+
"model": model_name,
|
170 |
+
"prompt": prompt,
|
171 |
+
"temperature": float(temperature),
|
172 |
+
"top_p": float(top_p),
|
173 |
+
"max_new_tokens": min(int(max_new_tokens), 1536),
|
174 |
+
"stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
|
175 |
+
"images": f'List of {len(state.get_images())} images: {all_image_hash}',
|
176 |
+
"segs": f'List of {len(state.get_segs())} segs: {all_seg_hash}',
|
177 |
+
}
|
178 |
+
logger.info(f"==== request ====\n{pload}")
|
179 |
+
|
180 |
+
pload['images'] = state.get_images()
|
181 |
+
pload['segs'] = state.get_segs()
|
182 |
+
|
183 |
+
state.messages[-1][-1] = "▌"
|
184 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
185 |
+
|
186 |
+
|
187 |
+
try:
|
188 |
+
# Stream output
|
189 |
+
response = chat.generate_stream_gate(pload)
|
190 |
+
for chunk in response:
|
191 |
+
if chunk:
|
192 |
+
data = json.loads(chunk.decode())
|
193 |
+
if data["error_code"] == 0:
|
194 |
+
output = data["text"][len(prompt):].strip()
|
195 |
+
state.messages[-1][-1] = output + "▌"
|
196 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
197 |
+
else:
|
198 |
+
output = data["text"] + f" (error_code: {data['error_code']})"
|
199 |
+
state.messages[-1][-1] = output
|
200 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
201 |
+
return
|
202 |
+
time.sleep(0.03)
|
203 |
+
except:
|
204 |
+
state.messages[-1][-1] = server_error_msg
|
205 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
206 |
+
return
|
207 |
+
|
208 |
+
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
209 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
210 |
+
|
211 |
+
finish_tstamp = time.time()
|
212 |
+
logger.info(f"{output}")
|
213 |
+
|
214 |
+
with open(get_conv_log_filename(), "a") as fout:
|
215 |
+
data = {
|
216 |
+
"tstamp": round(finish_tstamp, 4),
|
217 |
+
"type": "chat",
|
218 |
+
"model": model_name,
|
219 |
+
"start": round(start_tstamp, 4),
|
220 |
+
"finish": round(start_tstamp, 4),
|
221 |
+
"state": state.dict(),
|
222 |
+
"images": all_image_hash,
|
223 |
+
"segs": all_seg_hash,
|
224 |
+
"ip": request.client.host,
|
225 |
+
}
|
226 |
+
fout.write(json.dumps(data) + "\n")
|
227 |
+
|
228 |
+
title_markdown = ("""
|
229 |
+
# 🌋 LLaVA: Large Language and Vision Assistant
|
230 |
+
[[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)
|
231 |
+
""")
|
232 |
+
|
233 |
+
tos_markdown = ("""
|
234 |
+
### Terms of use
|
235 |
+
By using this service, users are required to agree to the following terms:
|
236 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
237 |
+
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
238 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
239 |
+
""")
|
240 |
+
|
241 |
+
|
242 |
+
learn_more_markdown = ("""
|
243 |
+
### License
|
244 |
+
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
|
245 |
+
""")
|
246 |
+
|
247 |
+
block_css = """
|
248 |
+
|
249 |
+
#buttons button {
|
250 |
+
min-width: min(120px,100%);
|
251 |
+
}
|
252 |
+
|
253 |
+
"""
|
254 |
+
|
255 |
+
def build_demo(embed_mode):
|
256 |
+
|
257 |
+
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
|
258 |
+
with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
|
259 |
+
state = gr.State()
|
260 |
+
|
261 |
+
if not embed_mode:
|
262 |
+
gr.Markdown(title_markdown)
|
263 |
+
|
264 |
+
with gr.Row():
|
265 |
+
with gr.Column(scale=3):
|
266 |
+
with gr.Row(elem_id="model_selector_row"):
|
267 |
+
model_selector = gr.Dropdown(
|
268 |
+
choices=models,
|
269 |
+
value=models[0] if len(models) > 0 else "",
|
270 |
+
interactive=True,
|
271 |
+
show_label=False,
|
272 |
+
container=False)
|
273 |
+
|
274 |
+
# with gr.Row():
|
275 |
+
imagebox = gr.Image(type="pil", label="Image Input")
|
276 |
+
image_process_mode = gr.Radio(
|
277 |
+
["Crop", "Resize", "Pad", "Default"],
|
278 |
+
value="Default",
|
279 |
+
label="Preprocess for non-square image", visible=False)
|
280 |
+
|
281 |
+
segbox = gr.Image(type="pil", label="Seg Map")
|
282 |
+
seg_process_mode = gr.Radio(
|
283 |
+
["Crop", "Resize", "Pad", "Default"],
|
284 |
+
value="Default",
|
285 |
+
label="Preprocess for non-square Seg Map", visible=False)
|
286 |
+
|
287 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
288 |
+
gr.Examples(examples=[
|
289 |
+
[f"{cur_dir}/examples/3.jpg", f"{cur_dir}/examples/3_pan.png", "What objects can be seen in the image?"],
|
290 |
+
[f"{cur_dir}/examples/3.jpg", f"{cur_dir}/examples/3_ins.png", "What objects can be seen in the image?"],
|
291 |
+
], inputs=[imagebox, segbox, textbox])
|
292 |
+
|
293 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
294 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
|
295 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
296 |
+
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
297 |
+
|
298 |
+
with gr.Column(scale=8):
|
299 |
+
chatbot = gr.Chatbot(elem_id="chatbot", label="VCoder Chatbot", height=550)
|
300 |
+
with gr.Row():
|
301 |
+
with gr.Column(scale=8):
|
302 |
+
textbox.render()
|
303 |
+
with gr.Column(scale=1, min_width=50):
|
304 |
+
submit_btn = gr.Button(value="Send", variant="primary")
|
305 |
+
with gr.Row(elem_id="buttons") as button_row:
|
306 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
307 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
308 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
309 |
+
#stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
|
310 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
311 |
+
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
|
312 |
+
|
313 |
+
if not embed_mode:
|
314 |
+
gr.Markdown(tos_markdown)
|
315 |
+
gr.Markdown(learn_more_markdown)
|
316 |
+
|
317 |
+
# Register listeners
|
318 |
+
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
319 |
+
upvote_btn.click(upvote_last_response,
|
320 |
+
[state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
|
321 |
+
downvote_btn.click(downvote_last_response,
|
322 |
+
[state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
|
323 |
+
flag_btn.click(flag_last_response,
|
324 |
+
[state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
|
325 |
+
regenerate_btn.click(regenerate, [state, image_process_mode, seg_process_mode],
|
326 |
+
[state, chatbot, textbox, imagebox, segbox] + btn_list).then(
|
327 |
+
http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
|
328 |
+
[state, chatbot] + btn_list)
|
329 |
+
clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, segbox] + btn_list)
|
330 |
+
|
331 |
+
textbox.submit(add_text, [state, textbox, imagebox, image_process_mode, segbox, seg_process_mode], [state, chatbot, textbox, imagebox, segbox] + btn_list
|
332 |
+
).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
|
333 |
+
[state, chatbot] + btn_list)
|
334 |
+
submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode, segbox, seg_process_mode], [state, chatbot, textbox, imagebox, segbox] + btn_list
|
335 |
+
).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
|
336 |
+
[state, chatbot] + btn_list)
|
337 |
+
|
338 |
+
demo.load(load_demo_refresh_model_list, None, [state, model_selector])
|
339 |
+
|
340 |
+
return demo
|
341 |
+
|
342 |
+
|
343 |
+
if __name__ == "__main__":
|
344 |
+
parser = argparse.ArgumentParser()
|
345 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
346 |
+
parser.add_argument("--model-base", type=str, default=None)
|
347 |
+
parser.add_argument("--model-name", type=str)
|
348 |
+
parser.add_argument("--load-8bit", action="store_true")
|
349 |
+
parser.add_argument("--load-4bit", action="store_true")
|
350 |
+
parser.add_argument("--device", type=str, default="cuda")
|
351 |
+
parser.add_argument("--share", action="store_true")
|
352 |
+
parser.add_argument("--moderate", action="store_true")
|
353 |
+
parser.add_argument("--embed", action="store_true")
|
354 |
+
parser.add_argument("--concurrency-count", type=int, default=10)
|
355 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
356 |
+
parser.add_argument("--port", type=int)
|
357 |
+
args = parser.parse_args()
|
358 |
+
logger.info(f"args: {args}")
|
359 |
+
|
360 |
+
if args.model_name is None:
|
361 |
+
model_paths = args.model_path.split("/")
|
362 |
+
if model_paths[-1].startswith('checkpoint-'):
|
363 |
+
model_name = model_paths[-2] + "_" + model_paths[-1]
|
364 |
+
else:
|
365 |
+
model_name = model_paths[-1]
|
366 |
+
else:
|
367 |
+
model_name = args.model_name
|
368 |
+
|
369 |
+
models = [model_name]
|
370 |
+
chat = Chat(
|
371 |
+
args.model_path,
|
372 |
+
args.model_base,
|
373 |
+
args.model_name,
|
374 |
+
args.load_8bit,
|
375 |
+
args.load_4bit,
|
376 |
+
args.device,
|
377 |
+
logger
|
378 |
+
)
|
379 |
+
|
380 |
+
logger.info(args)
|
381 |
+
demo = build_demo(args.embed)
|
382 |
+
demo.queue(
|
383 |
+
concurrency_count=args.concurrency_count,
|
384 |
+
api_open=False
|
385 |
+
).launch(
|
386 |
+
server_name=args.host,
|
387 |
+
server_port=args.port,
|
388 |
+
share=args.share
|
389 |
+
)
|
chat.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A model worker executes the model.
|
3 |
+
"""
|
4 |
+
import argparse
|
5 |
+
import json
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from vcoder_llava.utils import server_error_msg
|
9 |
+
from vcoder_llava.model.builder import load_pretrained_model
|
10 |
+
from vcoder_llava.mm_utils import process_images, load_image_from_base64, tokenizer_seg_token, tokenizer_depth_seg_token, tokenizer_image_token, KeywordsStoppingCriteria
|
11 |
+
from vcoder_llava.constants import (
|
12 |
+
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
|
13 |
+
SEG_TOKEN_INDEX, DEFAULT_SEG_TOKEN,
|
14 |
+
DEPTH_TOKEN_INDEX, DEFAULT_DEPTH_TOKEN
|
15 |
+
)
|
16 |
+
from transformers import TextIteratorStreamer
|
17 |
+
|
18 |
+
class Chat:
|
19 |
+
def __init__(self, model_path, model_base, model_name,
|
20 |
+
load_8bit, load_4bit, device, logger):
|
21 |
+
if model_path.endswith("/"):
|
22 |
+
model_path = model_path[:-1]
|
23 |
+
if model_name is None:
|
24 |
+
model_paths = model_path.split("/")
|
25 |
+
if model_paths[-1].startswith('checkpoint-'):
|
26 |
+
self.model_name = model_paths[-2] + "_" + model_paths[-1]
|
27 |
+
else:
|
28 |
+
self.model_name = model_paths[-1]
|
29 |
+
else:
|
30 |
+
self.model_name = model_name
|
31 |
+
|
32 |
+
self.device = device
|
33 |
+
logger.info(f"Loading the model {self.model_name} ...")
|
34 |
+
self.tokenizer, self.model, self.image_processor, self.seg_image_processor, self.depth_image_processor, self.context_len = load_pretrained_model(
|
35 |
+
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
|
36 |
+
self.is_multimodal = 'llava' in self.model_name.lower()
|
37 |
+
self.is_seg = "seg_llava" in self.model_name.lower()
|
38 |
+
self.is_depth = False
|
39 |
+
|
40 |
+
@torch.inference_mode()
|
41 |
+
def generate_stream(self, params):
|
42 |
+
tokenizer, model, image_processor, seg_image_processor, depth_image_processor = self.tokenizer, self.model, self.image_processor, self.seg_image_processor, self.depth_image_processor
|
43 |
+
|
44 |
+
prompt = params["prompt"]
|
45 |
+
ori_prompt = prompt
|
46 |
+
images = params.get("images", None)
|
47 |
+
segs = params.get("segs", None)
|
48 |
+
depths = params.get("depths", None)
|
49 |
+
num_image_tokens = 0
|
50 |
+
num_seg_tokens = 0
|
51 |
+
num_depth_tokens = 0
|
52 |
+
if images is not None and len(images) > 0 and self.is_multimodal:
|
53 |
+
if len(images) > 0:
|
54 |
+
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
|
55 |
+
raise ValueError("Number of images does not match number of <image> tokens in prompt")
|
56 |
+
|
57 |
+
images = [load_image_from_base64(image) for image in images]
|
58 |
+
images = process_images(images, image_processor, model.config)
|
59 |
+
|
60 |
+
if type(images) is list:
|
61 |
+
images = [image.to(self.model.device, dtype=torch.float16) for image in images]
|
62 |
+
else:
|
63 |
+
images = images.to(self.model.device, dtype=torch.float16)
|
64 |
+
|
65 |
+
replace_token = DEFAULT_IMAGE_TOKEN
|
66 |
+
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
67 |
+
num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
|
68 |
+
|
69 |
+
if segs is not None and len(segs) > 0 and self.is_seg:
|
70 |
+
if len(segs) != prompt.count(DEFAULT_SEG_TOKEN):
|
71 |
+
raise ValueError("Number of segs does not match number of <seg> tokens in prompt")
|
72 |
+
|
73 |
+
segs = [load_image_from_base64(seg) for seg in segs]
|
74 |
+
segs = process_images(segs, seg_image_processor, model.config)
|
75 |
+
|
76 |
+
if type(segs) is list:
|
77 |
+
segs = [seg.to(self.model.device, dtype=torch.float16) for seg in segs]
|
78 |
+
else:
|
79 |
+
segs = segs.to(self.model.device, dtype=torch.float16)
|
80 |
+
|
81 |
+
replace_seg_token = DEFAULT_SEG_TOKEN
|
82 |
+
prompt = prompt.replace(DEFAULT_SEG_TOKEN, replace_seg_token)
|
83 |
+
num_seg_tokens = prompt.count(replace_seg_token) * model.get_vision_tower().num_patches
|
84 |
+
|
85 |
+
if depths is not None and len(depths) > 0 and self.is_depth:
|
86 |
+
if len(depths) != prompt.count(DEFAULT_DEPTH_TOKEN):
|
87 |
+
raise ValueError("Number of depths does not match number of <depth> tokens in prompt")
|
88 |
+
|
89 |
+
depths = [load_image_from_base64(depth) for depth in depths]
|
90 |
+
depths = process_images(depths, depth_image_processor, model.config)
|
91 |
+
|
92 |
+
if type(depths) is list:
|
93 |
+
depths = [depth.to(self.model.device, dtype=torch.float16) for depth in depths]
|
94 |
+
else:
|
95 |
+
depths = depths.to(self.model.device, dtype=torch.float16)
|
96 |
+
|
97 |
+
replace_depth_token = DEFAULT_DEPTH_TOKEN
|
98 |
+
prompt = prompt.replace(DEFAULT_DEPTH_TOKEN, replace_depth_token)
|
99 |
+
num_depth_tokens = prompt.count(replace_depth_token) * model.get_vision_tower().num_patches
|
100 |
+
else:
|
101 |
+
depths = None
|
102 |
+
else:
|
103 |
+
segs = None
|
104 |
+
depths = None
|
105 |
+
else:
|
106 |
+
images = None
|
107 |
+
segs = None
|
108 |
+
depths = None
|
109 |
+
image_args = {"images": images, "segs": segs, "depths": depths}
|
110 |
+
else:
|
111 |
+
images = None
|
112 |
+
segs = None
|
113 |
+
depths = None
|
114 |
+
image_args = {}
|
115 |
+
|
116 |
+
temperature = float(params.get("temperature", 1.0))
|
117 |
+
top_p = float(params.get("top_p", 1.0))
|
118 |
+
max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
|
119 |
+
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
|
120 |
+
stop_str = params.get("stop", None)
|
121 |
+
do_sample = True if temperature > 0.001 else False
|
122 |
+
|
123 |
+
if self.is_seg:
|
124 |
+
if self.is_depth:
|
125 |
+
input_ids = tokenizer_depth_seg_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, SEG_TOKEN_INDEX, DEPTH_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|
126 |
+
else:
|
127 |
+
input_ids = tokenizer_seg_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, SEG_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|
128 |
+
else:
|
129 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|
130 |
+
keywords = [stop_str]
|
131 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
132 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
|
133 |
+
|
134 |
+
max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens - num_seg_tokens - num_depth_tokens)
|
135 |
+
|
136 |
+
if max_new_tokens < 1:
|
137 |
+
yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
|
138 |
+
return
|
139 |
+
|
140 |
+
generated_text = model.generate(
|
141 |
+
inputs=input_ids,
|
142 |
+
do_sample=do_sample,
|
143 |
+
temperature=temperature,
|
144 |
+
top_p=top_p,
|
145 |
+
max_new_tokens=max_new_tokens,
|
146 |
+
streamer=streamer,
|
147 |
+
stopping_criteria=[stopping_criteria],
|
148 |
+
use_cache=True,
|
149 |
+
**image_args
|
150 |
+
)
|
151 |
+
# thread.start()
|
152 |
+
|
153 |
+
generated_text = ori_prompt
|
154 |
+
for new_text in streamer:
|
155 |
+
generated_text += new_text
|
156 |
+
if generated_text.endswith(stop_str):
|
157 |
+
generated_text = generated_text[:-len(stop_str)]
|
158 |
+
yield json.dumps({"text": generated_text, "error_code": 0}).encode()
|
159 |
+
|
160 |
+
def generate_stream_gate(self, params):
|
161 |
+
try:
|
162 |
+
for x in self.generate_stream(params):
|
163 |
+
yield x
|
164 |
+
except ValueError as e:
|
165 |
+
print("Caught ValueError:", e)
|
166 |
+
ret = {
|
167 |
+
"text": server_error_msg,
|
168 |
+
"error_code": 1,
|
169 |
+
}
|
170 |
+
yield json.dumps(ret).encode()
|
171 |
+
except torch.cuda.CudaError as e:
|
172 |
+
print("Caught torch.cuda.CudaError:", e)
|
173 |
+
ret = {
|
174 |
+
"text": server_error_msg,
|
175 |
+
"error_code": 1,
|
176 |
+
}
|
177 |
+
yield json.dumps(ret).encode()
|
178 |
+
except Exception as e:
|
179 |
+
print("Caught Unknown Error", e)
|
180 |
+
ret = {
|
181 |
+
"text": server_error_msg,
|
182 |
+
"error_code": 1,
|
183 |
+
}
|
184 |
+
yield json.dumps(ret).encode()
|
185 |
+
|
186 |
+
|
187 |
+
if __name__ == "__main__":
|
188 |
+
parser = argparse.ArgumentParser()
|
189 |
+
parser.add_argument("--host", type=str, default="localhost")
|
190 |
+
parser.add_argument("--port", type=int, default=21002)
|
191 |
+
parser.add_argument("--worker-address", type=str,
|
192 |
+
default="http://localhost:21002")
|
193 |
+
parser.add_argument("--controller-address", type=str,
|
194 |
+
default="http://localhost:21001")
|
195 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
196 |
+
parser.add_argument("--model-base", type=str, default=None)
|
197 |
+
parser.add_argument("--model-name", type=str)
|
198 |
+
parser.add_argument("--device", type=str, default="cuda")
|
199 |
+
parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
|
200 |
+
parser.add_argument("--limit-model-concurrency", type=int, default=5)
|
201 |
+
parser.add_argument("--stream-interval", type=int, default=1)
|
202 |
+
parser.add_argument("--no-register", action="store_true")
|
203 |
+
parser.add_argument("--load-8bit", action="store_true")
|
204 |
+
parser.add_argument("--load-4bit", action="store_true")
|
205 |
+
args = parser.parse_args()
|
examples/3.jpg
ADDED
Git LFS Details
|
examples/3_ins.png
ADDED
Git LFS Details
|
examples/3_pan.png
ADDED
Git LFS Details
|
requirements.txt
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu117
|
2 |
+
torch==2.0.0+cu117
|
3 |
+
packaging
|
4 |
+
Pillow
|
5 |
+
huggingface_hub
|
6 |
+
matplotlib
|
7 |
+
flash-attn
|
8 |
+
gradio
|
9 |
+
fastapi
|
10 |
+
numpy,
|
11 |
+
requests
|
12 |
+
sentencepiece
|
13 |
+
tokenizers>=0.12.1,
|
14 |
+
uvicorn
|
15 |
+
chardet,
|
16 |
+
shortuuid
|
17 |
+
httpx==0.24.0,
|
18 |
+
spacy
|
19 |
+
inflect
|
20 |
+
peft==0.4.0
|
21 |
+
num2words,
|
22 |
+
transformers==4.31.0,
|
23 |
+
accelerate==0.21.0,
|
24 |
+
bitsandbytes==0.41.0,
|
25 |
+
scikit-learn==1.2.2,
|
26 |
+
sentencepiece==0.1.99,
|
27 |
+
einops==0.6.1
|
28 |
+
einops-exts==0.0.4
|
29 |
+
timm==0.6.13,
|
30 |
+
gradio_client==0.2.9
|
vcoder_llava/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
vcoder_llava/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model import LlavaLlamaForCausalLM, VCoderLlavaLlamaForCausalLM, VCoderDSLlavaLlamaForCausalLM
|
vcoder_llava/constants.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
LOGDIR = "."
|
2 |
+
|
3 |
+
# Model Constants
|
4 |
+
IGNORE_INDEX = -100
|
5 |
+
IMAGE_TOKEN_INDEX = -200
|
6 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
7 |
+
|
8 |
+
SEG_TOKEN_INDEX = -300
|
9 |
+
DEFAULT_SEG_TOKEN = "<seg>"
|
10 |
+
|
11 |
+
DEPTH_TOKEN_INDEX = -400
|
12 |
+
DEFAULT_DEPTH_TOKEN = "<depth>"
|
vcoder_llava/data_utils.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import nltk
|
2 |
+
import spacy
|
3 |
+
from word2number import w2n
|
4 |
+
import inflect
|
5 |
+
from num2words import num2words
|
6 |
+
p = inflect.engine()
|
7 |
+
import numpy as np
|
8 |
+
import random
|
9 |
+
|
10 |
+
nltk.download('punkt')
|
11 |
+
nltk.download('averaged_perceptron_tagger')
|
12 |
+
nlp = spacy.load('en_core_web_sm')
|
13 |
+
|
14 |
+
# object names with two words
|
15 |
+
SPECIAL_WORDS = ['baseball bat',
|
16 |
+
'baseball glove',
|
17 |
+
'cell phone',
|
18 |
+
'dining table',
|
19 |
+
'fire hydrant',
|
20 |
+
'french fries',
|
21 |
+
'hair drier',
|
22 |
+
'hot dog',
|
23 |
+
'parking meter',
|
24 |
+
'potted plant',
|
25 |
+
'soccer ball',
|
26 |
+
'soccer player',
|
27 |
+
'sports ball',
|
28 |
+
'stop sign',
|
29 |
+
'teddy bear',
|
30 |
+
'tennis racket',
|
31 |
+
'toy figure',
|
32 |
+
'traffic light',
|
33 |
+
'wine glass']
|
34 |
+
|
35 |
+
def _get_nouns(lines):
|
36 |
+
# function to test if something is a noun
|
37 |
+
present_words = []
|
38 |
+
for s in SPECIAL_WORDS:
|
39 |
+
if s in lines:
|
40 |
+
present_words.append(s)
|
41 |
+
|
42 |
+
for w in present_words:
|
43 |
+
lines = lines.replace(w, "")
|
44 |
+
|
45 |
+
is_noun = lambda pos: pos[:2] == 'NN' or pos[:2] == 'NNP'
|
46 |
+
# do the nlp stuff
|
47 |
+
tokenized = nltk.word_tokenize(lines)
|
48 |
+
nouns = [word for (word, pos) in nltk.pos_tag(tokenized) if is_noun(pos)]
|
49 |
+
noun_dict = {}
|
50 |
+
if "objects" in nouns:
|
51 |
+
nouns.remove("objects")
|
52 |
+
if "image" in nouns:
|
53 |
+
nouns.remove("image")
|
54 |
+
|
55 |
+
for n in nouns:
|
56 |
+
if n not in noun_dict.keys():
|
57 |
+
noun_dict[n] = 1
|
58 |
+
else:
|
59 |
+
noun_dict[n] += 1
|
60 |
+
nouns = {}
|
61 |
+
for k, v in noun_dict.items():
|
62 |
+
if not (k == "bus" or k == "skis"):
|
63 |
+
if v == 1:
|
64 |
+
if p.singular_noun(k):
|
65 |
+
k = p.singular_noun(k)
|
66 |
+
else:
|
67 |
+
if not p.singular_noun(k):
|
68 |
+
k = p.plural(k)
|
69 |
+
try:
|
70 |
+
w2n.word_to_num(k)
|
71 |
+
except:
|
72 |
+
if len(k) >= 3:
|
73 |
+
if k == "ski":
|
74 |
+
k = "skis"
|
75 |
+
elif k == "gras":
|
76 |
+
k = "grass"
|
77 |
+
nouns[k] = v
|
78 |
+
for w in present_words:
|
79 |
+
nouns[w] = 1
|
80 |
+
return nouns
|
81 |
+
|
82 |
+
def _get_num_nouns(lines):
|
83 |
+
lines = lines.replace(":", "").replace(".", "")
|
84 |
+
doc = nlp(lines)
|
85 |
+
num_nouns = [chunk.text for chunk in doc.noun_chunks if any(token.pos_ == 'NUM' for token in chunk)]
|
86 |
+
|
87 |
+
num_noun_dict = {}
|
88 |
+
for n in num_nouns:
|
89 |
+
nums = n.split(", ")
|
90 |
+
for n in nums:
|
91 |
+
try:
|
92 |
+
w = " ".join(n.split(' ')[1:])
|
93 |
+
if w == "ski":
|
94 |
+
w = "skis"
|
95 |
+
num_noun_dict[w] = w2n.word_to_num(n.split(' ')[0])
|
96 |
+
except:
|
97 |
+
pass
|
98 |
+
|
99 |
+
return num_noun_dict
|
100 |
+
|
101 |
+
|
102 |
+
def _obtain_nouns(gt):
|
103 |
+
gt = gt.replace("hair dryer", "hair drier").lower()
|
104 |
+
nouns_gt = _get_nouns(gt)
|
105 |
+
|
106 |
+
num_nouns_gt = _get_num_nouns(gt)
|
107 |
+
|
108 |
+
com_keys = []
|
109 |
+
for k in nouns_gt.keys():
|
110 |
+
if p.plural(k) in num_nouns_gt.keys():
|
111 |
+
com_keys.append(k)
|
112 |
+
for k in com_keys:
|
113 |
+
del nouns_gt[k]
|
114 |
+
|
115 |
+
num_nouns_gt = {**num_nouns_gt, **nouns_gt}
|
116 |
+
|
117 |
+
return num_nouns_gt
|
118 |
+
|
119 |
+
def generate_qa_pairs(text):
|
120 |
+
num_nouns = _obtain_nouns(text)
|
121 |
+
qa_pairs = []
|
122 |
+
|
123 |
+
for obj, count in num_nouns.items():
|
124 |
+
# Count question
|
125 |
+
if count == 1:
|
126 |
+
plural_obj = p.plural(obj)
|
127 |
+
else:
|
128 |
+
plural_obj = obj
|
129 |
+
count_question = f"How many {plural_obj} are there in the image?"
|
130 |
+
count_answer = f"There {'is' if count == 1 else 'are'} {num2words(count)} {obj} in the image."
|
131 |
+
qa_pairs.append((count_question, count_answer))
|
132 |
+
|
133 |
+
prob_positive = np.random.uniform(0,1.)
|
134 |
+
|
135 |
+
if prob_positive > 0.7 or count == 1:
|
136 |
+
numeric_presence_question = f"{'Is' if count == 1 else 'Are'} there {num2words(count)} {obj} in the image?"
|
137 |
+
numeric_presence_answer = "Yes."
|
138 |
+
elif count > 1:
|
139 |
+
numbers = [i for i in range(2, count + 6) if i != count]
|
140 |
+
# Select a random number from the range
|
141 |
+
cnt = random.choice(numbers)
|
142 |
+
numeric_presence_question = f"{'Is' if cnt == 1 else 'Are'} there {num2words(cnt)} {obj} in the image?"
|
143 |
+
numeric_presence_answer = "No."
|
144 |
+
|
145 |
+
qa_pairs.append((numeric_presence_question, numeric_presence_answer))
|
146 |
+
random.shuffle(qa_pairs)
|
147 |
+
|
148 |
+
return random.sample(qa_pairs, min(len(qa_pairs), random.choice([1, 2, 3, 4, 5, 6])))
|
149 |
+
|
150 |
+
if __name__ == "__main__":
|
151 |
+
|
152 |
+
text = "The objects present in the image are: wall, ceiling, shelf, cabinet, counter, dining table, two people, eighteen bottles, two wine glasses, refrigerator, tv, bowl"
|
153 |
+
|
154 |
+
qa = generate_qa_pairs(text)
|
155 |
+
from icecream import ic
|
156 |
+
ic(qa)
|
157 |
+
|
vcoder_llava/mm_utils.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from io import BytesIO
|
3 |
+
import base64
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import StoppingCriteria
|
7 |
+
from vcoder_llava.constants import IMAGE_TOKEN_INDEX, SEG_TOKEN_INDEX, DEPTH_TOKEN_INDEX
|
8 |
+
|
9 |
+
|
10 |
+
def load_image_from_base64(image):
|
11 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
12 |
+
|
13 |
+
|
14 |
+
def expand2square(pil_img, background_color):
|
15 |
+
width, height = pil_img.size
|
16 |
+
if width == height:
|
17 |
+
return pil_img
|
18 |
+
elif width > height:
|
19 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
20 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
21 |
+
return result
|
22 |
+
else:
|
23 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
24 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
25 |
+
return result
|
26 |
+
|
27 |
+
|
28 |
+
def process_images(images, image_processor, model_cfg):
|
29 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
30 |
+
new_images = []
|
31 |
+
if image_aspect_ratio == 'pad':
|
32 |
+
for image in images:
|
33 |
+
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
|
34 |
+
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
35 |
+
new_images.append(image)
|
36 |
+
else:
|
37 |
+
return image_processor(images, return_tensors='pt')['pixel_values']
|
38 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
39 |
+
new_images = torch.stack(new_images, dim=0)
|
40 |
+
return new_images
|
41 |
+
|
42 |
+
|
43 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
44 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
45 |
+
|
46 |
+
def insert_separator(X, sep):
|
47 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
48 |
+
|
49 |
+
input_ids = []
|
50 |
+
offset = 0
|
51 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
52 |
+
offset = 1
|
53 |
+
input_ids.append(prompt_chunks[0][0])
|
54 |
+
|
55 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
56 |
+
input_ids.extend(x[offset:])
|
57 |
+
|
58 |
+
if return_tensors is not None:
|
59 |
+
if return_tensors == 'pt':
|
60 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
61 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
62 |
+
return input_ids
|
63 |
+
|
64 |
+
|
65 |
+
def tokenizer_seg_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, seg_token_index=SEG_TOKEN_INDEX, return_tensors=None):
|
66 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<seg>\n<image>')]
|
67 |
+
|
68 |
+
def insert_separator(X, sep):
|
69 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
70 |
+
|
71 |
+
input_ids = []
|
72 |
+
offset = 0
|
73 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
74 |
+
offset = 1
|
75 |
+
input_ids.append(prompt_chunks[0][0])
|
76 |
+
|
77 |
+
for x in insert_separator(prompt_chunks, [seg_token_index, image_token_index] * (offset + 1)):
|
78 |
+
if seg_token_index in x:
|
79 |
+
input_ids.extend(x[offset:-1])
|
80 |
+
else:
|
81 |
+
input_ids.extend(x[offset:])
|
82 |
+
|
83 |
+
if return_tensors is not None:
|
84 |
+
if return_tensors == 'pt':
|
85 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
86 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
87 |
+
return input_ids
|
88 |
+
|
89 |
+
def _tokenizer_depth_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, seg_token_index=SEG_TOKEN_INDEX, depth_token_index=DEPTH_TOKEN_INDEX, return_tensors=None):
|
90 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<depth>\n<seg>\n<image>')]
|
91 |
+
|
92 |
+
def insert_separator(X, sep):
|
93 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
94 |
+
|
95 |
+
input_ids = []
|
96 |
+
offset = 0
|
97 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
98 |
+
offset = 1
|
99 |
+
input_ids.append(prompt_chunks[0][0])
|
100 |
+
|
101 |
+
for x in insert_separator(prompt_chunks, [image_token_index, depth_token_index, seg_token_index] * (offset + 1)):
|
102 |
+
if depth_token_index in x and seg_token_index in x:
|
103 |
+
input_ids.extend(x[:3])
|
104 |
+
else:
|
105 |
+
input_ids.extend(x[offset:])
|
106 |
+
|
107 |
+
if return_tensors is not None:
|
108 |
+
if return_tensors == 'pt':
|
109 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
110 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
111 |
+
return input_ids
|
112 |
+
|
113 |
+
def tokenizer_depth_seg_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, seg_token_index=SEG_TOKEN_INDEX, depth_token_index=DEPTH_TOKEN_INDEX, return_tensors=None):
|
114 |
+
if "<depth>" in prompt:
|
115 |
+
return _tokenizer_depth_token(prompt, tokenizer, image_token_index, seg_token_index, depth_token_index, return_tensors)
|
116 |
+
else:
|
117 |
+
return tokenizer_seg_token(prompt, tokenizer, image_token_index, seg_token_index, return_tensors)
|
118 |
+
|
119 |
+
|
120 |
+
def get_model_name_from_path(model_path):
|
121 |
+
model_path = model_path.strip("/")
|
122 |
+
model_paths = model_path.split("/")
|
123 |
+
if model_paths[-1].startswith('checkpoint-'):
|
124 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
125 |
+
else:
|
126 |
+
return model_paths[-1]
|
127 |
+
|
128 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
129 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
130 |
+
self.keywords = keywords
|
131 |
+
self.keyword_ids = []
|
132 |
+
for keyword in keywords:
|
133 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
134 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
135 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
136 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
137 |
+
self.tokenizer = tokenizer
|
138 |
+
self.start_len = input_ids.shape[1]
|
139 |
+
|
140 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
141 |
+
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
|
142 |
+
offset = min(output_ids.shape[1] - self.start_len, 3)
|
143 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
144 |
+
for keyword_id in self.keyword_ids:
|
145 |
+
if output_ids[0, -keyword_id.shape[0]:] == keyword_id:
|
146 |
+
return True
|
147 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
148 |
+
for keyword in self.keywords:
|
149 |
+
if keyword in outputs:
|
150 |
+
return True
|
151 |
+
return False
|
vcoder_llava/model/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
vcoder_llava/model/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
|
2 |
+
from .language_model.vcoder_llava_llama import VCoderLlavaLlamaForCausalLM, VCoderLlavaConfig
|
3 |
+
from .language_model.vcoder_ds_llava_llama import VCoderDSLlavaLlamaForCausalLM, VCoderDSLlavaConfig
|
vcoder_llava/model/apply_delta.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
10 |
+
from vcoder_llava import LlavaLlamaForCausalLM
|
11 |
+
|
12 |
+
|
13 |
+
def apply_delta(base_model_path, target_model_path, delta_path):
|
14 |
+
print("Loading base model")
|
15 |
+
base = AutoModelForCausalLM.from_pretrained(
|
16 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
|
18 |
+
print("Loading delta")
|
19 |
+
delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
20 |
+
delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
|
21 |
+
|
22 |
+
print("Applying delta")
|
23 |
+
for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
|
24 |
+
if name not in base.state_dict():
|
25 |
+
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
|
26 |
+
continue
|
27 |
+
if param.data.shape == base.state_dict()[name].shape:
|
28 |
+
param.data += base.state_dict()[name]
|
29 |
+
else:
|
30 |
+
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
|
31 |
+
f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
|
32 |
+
bparam = base.state_dict()[name]
|
33 |
+
param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
|
34 |
+
|
35 |
+
print("Saving target model")
|
36 |
+
delta.save_pretrained(target_model_path)
|
37 |
+
delta_tokenizer.save_pretrained(target_model_path)
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
parser = argparse.ArgumentParser()
|
42 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
43 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
44 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
45 |
+
|
46 |
+
args = parser.parse_args()
|
47 |
+
|
48 |
+
apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
|
vcoder_llava/model/builder.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import os
|
17 |
+
import warnings
|
18 |
+
import shutil
|
19 |
+
|
20 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
21 |
+
import torch
|
22 |
+
from vcoder_llava.model import *
|
23 |
+
|
24 |
+
|
25 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
|
26 |
+
kwargs = {"device_map": device_map}
|
27 |
+
|
28 |
+
if load_8bit:
|
29 |
+
kwargs['load_in_8bit'] = True
|
30 |
+
elif load_4bit:
|
31 |
+
kwargs['load_in_4bit'] = True
|
32 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
33 |
+
load_in_4bit=True,
|
34 |
+
bnb_4bit_compute_dtype=torch.float16,
|
35 |
+
bnb_4bit_use_double_quant=True,
|
36 |
+
bnb_4bit_quant_type='nf4'
|
37 |
+
)
|
38 |
+
else:
|
39 |
+
kwargs['torch_dtype'] = torch.float16
|
40 |
+
if 'llava' in model_name.lower():
|
41 |
+
# Load LLaVA model
|
42 |
+
if 'lora' in model_name.lower() and model_base is None:
|
43 |
+
warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
|
44 |
+
if 'lora' in model_name.lower() and model_base is not None:
|
45 |
+
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
46 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
47 |
+
print('Loading LLaVA from base model...')
|
48 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
|
49 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
50 |
+
if model.lm_head.weight.shape[0] != token_num:
|
51 |
+
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
52 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
53 |
+
|
54 |
+
print('Loading additional LLaVA weights...')
|
55 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
56 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
57 |
+
else:
|
58 |
+
# this is probably from HF Hub
|
59 |
+
from huggingface_hub import hf_hub_download
|
60 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
61 |
+
cache_file = hf_hub_download(
|
62 |
+
repo_id=repo_id,
|
63 |
+
filename=filename,
|
64 |
+
subfolder=subfolder)
|
65 |
+
return torch.load(cache_file, map_location='cpu')
|
66 |
+
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
67 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
68 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
69 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
70 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
71 |
+
|
72 |
+
from peft import PeftModel
|
73 |
+
print('Loading LoRA weights...')
|
74 |
+
model = PeftModel.from_pretrained(model, model_path)
|
75 |
+
print('Merging LoRA weights...')
|
76 |
+
model = model.merge_and_unload()
|
77 |
+
print('Model is loaded...')
|
78 |
+
elif model_base is not None:
|
79 |
+
# this may be mm projector only
|
80 |
+
print('Loading LLaVA from base model...')
|
81 |
+
if 'vcoder_ds_llava' in model_name.lower():
|
82 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
83 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
84 |
+
model = VCoderDSLlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
85 |
+
elif 'vcoder_llava' in model_name.lower():
|
86 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
87 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
88 |
+
model = VCoderLlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
89 |
+
else:
|
90 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
91 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
92 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
93 |
+
|
94 |
+
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
95 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
96 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
97 |
+
else:
|
98 |
+
if 'vcoder_ds_llava' in model_name.lower():
|
99 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
100 |
+
model = VCoderDSLlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
101 |
+
elif 'vcoder_llava' in model_name.lower():
|
102 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
103 |
+
model = VCoderLlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
104 |
+
else:
|
105 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
106 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
107 |
+
else:
|
108 |
+
# Load language model
|
109 |
+
if model_base is not None:
|
110 |
+
# PEFT model
|
111 |
+
from peft import PeftModel
|
112 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
113 |
+
model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
|
114 |
+
print(f"Loading LoRA weights from {model_path}")
|
115 |
+
model = PeftModel.from_pretrained(model, model_path)
|
116 |
+
print(f"Merging weights")
|
117 |
+
model = model.merge_and_unload()
|
118 |
+
print('Convert to FP16...')
|
119 |
+
model.to(torch.float16)
|
120 |
+
else:
|
121 |
+
use_fast = False
|
122 |
+
if 'mpt' in model_name.lower():
|
123 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
124 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
|
125 |
+
else:
|
126 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
127 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
128 |
+
|
129 |
+
image_processor = None
|
130 |
+
|
131 |
+
if hasattr(model.config, "max_sequence_length"):
|
132 |
+
context_len = model.config.max_sequence_length
|
133 |
+
else:
|
134 |
+
context_len = 2048
|
135 |
+
|
136 |
+
if 'llava' in model_name.lower():
|
137 |
+
vision_tower = model.get_vision_tower()
|
138 |
+
if not vision_tower.is_loaded:
|
139 |
+
vision_tower.load_model()
|
140 |
+
vision_tower.to(device=device, dtype=torch.float16)
|
141 |
+
image_processor = vision_tower.image_processor
|
142 |
+
|
143 |
+
seg_image_processor = None
|
144 |
+
if 'vcoder' in model_name.lower():
|
145 |
+
seg_image_processor = image_processor
|
146 |
+
|
147 |
+
depth_image_processor = None
|
148 |
+
if "ds" in model_name.lower():
|
149 |
+
depth_image_processor = image_processor
|
150 |
+
|
151 |
+
model.requires_grad_(False)
|
152 |
+
return tokenizer, model, image_processor, seg_image_processor, depth_image_processor, context_len
|
vcoder_llava/model/consolidate.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m vcoder_llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
9 |
+
from vcoder_llava.model import *
|
10 |
+
from vcoder_llava.model.utils import auto_upgrade
|
11 |
+
|
12 |
+
|
13 |
+
def consolidate_ckpt(src_path, dst_path):
|
14 |
+
print("Loading model")
|
15 |
+
auto_upgrade(src_path)
|
16 |
+
src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
|
18 |
+
src_model.save_pretrained(dst_path)
|
19 |
+
src_tokenizer.save_pretrained(dst_path)
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
parser = argparse.ArgumentParser()
|
24 |
+
parser.add_argument("--src", type=str, required=True)
|
25 |
+
parser.add_argument("--dst", type=str, required=True)
|
26 |
+
|
27 |
+
args = parser.parse_args()
|
28 |
+
|
29 |
+
consolidate_ckpt(args.src, args.dst)
|
vcoder_llava/model/language_model/llava_llama.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.nn import CrossEntropyLoss
|
21 |
+
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
23 |
+
LlamaConfig, LlamaModel, LlamaForCausalLM
|
24 |
+
|
25 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
26 |
+
|
27 |
+
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
28 |
+
|
29 |
+
|
30 |
+
class LlavaConfig(LlamaConfig):
|
31 |
+
model_type = "llava"
|
32 |
+
|
33 |
+
|
34 |
+
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
|
35 |
+
config_class = LlavaConfig
|
36 |
+
|
37 |
+
def __init__(self, config: LlamaConfig):
|
38 |
+
super(LlavaLlamaModel, self).__init__(config)
|
39 |
+
|
40 |
+
|
41 |
+
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
|
42 |
+
config_class = LlavaConfig
|
43 |
+
|
44 |
+
def __init__(self, config):
|
45 |
+
super(LlamaForCausalLM, self).__init__(config)
|
46 |
+
self.model = LlavaLlamaModel(config)
|
47 |
+
|
48 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
49 |
+
|
50 |
+
# Initialize weights and apply final processing
|
51 |
+
self.post_init()
|
52 |
+
|
53 |
+
def get_model(self):
|
54 |
+
return self.model
|
55 |
+
|
56 |
+
def forward(
|
57 |
+
self,
|
58 |
+
input_ids: torch.LongTensor = None,
|
59 |
+
attention_mask: Optional[torch.Tensor] = None,
|
60 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
61 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
62 |
+
labels: Optional[torch.LongTensor] = None,
|
63 |
+
use_cache: Optional[bool] = None,
|
64 |
+
output_attentions: Optional[bool] = None,
|
65 |
+
output_hidden_states: Optional[bool] = None,
|
66 |
+
images: Optional[torch.FloatTensor] = None,
|
67 |
+
images_cd: Optional[torch.FloatTensor] = None,
|
68 |
+
cd_beta: Optional[torch.FloatTensor] = None,
|
69 |
+
cd_alpha: Optional[torch.FloatTensor] = None,
|
70 |
+
return_dict: Optional[bool] = None,
|
71 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
72 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
73 |
+
output_hidden_states = (
|
74 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
75 |
+
)
|
76 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
77 |
+
|
78 |
+
input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
|
79 |
+
|
80 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
81 |
+
outputs = self.model(
|
82 |
+
input_ids=input_ids,
|
83 |
+
attention_mask=attention_mask,
|
84 |
+
past_key_values=past_key_values,
|
85 |
+
inputs_embeds=inputs_embeds,
|
86 |
+
use_cache=use_cache,
|
87 |
+
output_attentions=output_attentions,
|
88 |
+
output_hidden_states=output_hidden_states,
|
89 |
+
return_dict=return_dict
|
90 |
+
)
|
91 |
+
|
92 |
+
hidden_states = outputs[0]
|
93 |
+
logits = self.lm_head(hidden_states)
|
94 |
+
|
95 |
+
loss = None
|
96 |
+
if labels is not None:
|
97 |
+
# Shift so that tokens < n predict n
|
98 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
99 |
+
shift_labels = labels[..., 1:].contiguous()
|
100 |
+
# Flatten the tokens
|
101 |
+
loss_fct = CrossEntropyLoss()
|
102 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
103 |
+
shift_labels = shift_labels.view(-1)
|
104 |
+
# Enable model/pipeline parallelism
|
105 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
106 |
+
loss = loss_fct(shift_logits, shift_labels)
|
107 |
+
|
108 |
+
if not return_dict:
|
109 |
+
output = (logits,) + outputs[1:]
|
110 |
+
return (loss,) + output if loss is not None else output
|
111 |
+
|
112 |
+
return CausalLMOutputWithPast(
|
113 |
+
loss=loss,
|
114 |
+
logits=logits,
|
115 |
+
past_key_values=outputs.past_key_values,
|
116 |
+
hidden_states=outputs.hidden_states,
|
117 |
+
attentions=outputs.attentions,
|
118 |
+
)
|
119 |
+
|
120 |
+
def prepare_inputs_for_generation(
|
121 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
122 |
+
):
|
123 |
+
if past_key_values:
|
124 |
+
input_ids = input_ids[:, -1:]
|
125 |
+
|
126 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
127 |
+
if inputs_embeds is not None and past_key_values is None:
|
128 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
129 |
+
else:
|
130 |
+
model_inputs = {"input_ids": input_ids}
|
131 |
+
|
132 |
+
model_inputs.update(
|
133 |
+
{
|
134 |
+
"past_key_values": past_key_values,
|
135 |
+
"use_cache": kwargs.get("use_cache"),
|
136 |
+
"attention_mask": attention_mask,
|
137 |
+
"images": kwargs.get("images", None),
|
138 |
+
}
|
139 |
+
)
|
140 |
+
return model_inputs
|
141 |
+
|
142 |
+
def prepare_inputs_for_generation_cd(
|
143 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
144 |
+
):
|
145 |
+
if past_key_values:
|
146 |
+
input_ids = input_ids[:, -1:]
|
147 |
+
|
148 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
149 |
+
if inputs_embeds is not None and past_key_values is None:
|
150 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
151 |
+
else:
|
152 |
+
model_inputs = {"input_ids": input_ids}
|
153 |
+
|
154 |
+
model_inputs.update(
|
155 |
+
{
|
156 |
+
"past_key_values": past_key_values,
|
157 |
+
"use_cache": kwargs.get("use_cache"),
|
158 |
+
"attention_mask": attention_mask,
|
159 |
+
"images": kwargs.get("images_cd", None),
|
160 |
+
}
|
161 |
+
)
|
162 |
+
return model_inputs
|
163 |
+
|
164 |
+
AutoConfig.register("llava", LlavaConfig)
|
165 |
+
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
|
vcoder_llava/model/language_model/vcoder_ds_llava_llama.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.nn import CrossEntropyLoss
|
21 |
+
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
23 |
+
LlamaConfig, LlamaModel, LlamaForCausalLM
|
24 |
+
|
25 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
26 |
+
|
27 |
+
from ..vcoder_ds_llava_arch import VCoderDSLlavaMetaModel, VCoderDSLlavaMetaForCausalLM
|
28 |
+
|
29 |
+
|
30 |
+
class VCoderDSLlavaConfig(LlamaConfig):
|
31 |
+
model_type = "vcoder_ds_llava"
|
32 |
+
|
33 |
+
|
34 |
+
class VCoderDSLlavaLlamaModel(VCoderDSLlavaMetaModel, LlamaModel):
|
35 |
+
config_class = VCoderDSLlavaConfig
|
36 |
+
|
37 |
+
def __init__(self, config: LlamaConfig):
|
38 |
+
super(VCoderDSLlavaLlamaModel, self).__init__(config)
|
39 |
+
|
40 |
+
|
41 |
+
class VCoderDSLlavaLlamaForCausalLM(LlamaForCausalLM, VCoderDSLlavaMetaForCausalLM):
|
42 |
+
config_class = VCoderDSLlavaConfig
|
43 |
+
|
44 |
+
def __init__(self, config):
|
45 |
+
super(LlamaForCausalLM, self).__init__(config)
|
46 |
+
self.model = VCoderDSLlavaLlamaModel(config)
|
47 |
+
|
48 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
49 |
+
|
50 |
+
|
51 |
+
# Initialize weights and apply final processing
|
52 |
+
self.post_init()
|
53 |
+
|
54 |
+
def get_model(self):
|
55 |
+
return self.model
|
56 |
+
|
57 |
+
def forward(
|
58 |
+
self,
|
59 |
+
input_ids: torch.LongTensor = None,
|
60 |
+
attention_mask: Optional[torch.Tensor] = None,
|
61 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
62 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
63 |
+
labels: Optional[torch.LongTensor] = None,
|
64 |
+
use_cache: Optional[bool] = None,
|
65 |
+
output_attentions: Optional[bool] = None,
|
66 |
+
output_hidden_states: Optional[bool] = None,
|
67 |
+
images: Optional[torch.FloatTensor] = None,
|
68 |
+
segs: Optional[torch.FloatTensor] = None,
|
69 |
+
depths: Optional[torch.FloatTensor] = None,
|
70 |
+
return_dict: Optional[bool] = None,
|
71 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
72 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
73 |
+
output_hidden_states = (
|
74 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
75 |
+
)
|
76 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
77 |
+
|
78 |
+
input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images, segs, depths)
|
79 |
+
|
80 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
81 |
+
outputs = self.model(
|
82 |
+
input_ids=input_ids,
|
83 |
+
attention_mask=attention_mask,
|
84 |
+
past_key_values=past_key_values,
|
85 |
+
inputs_embeds=inputs_embeds,
|
86 |
+
use_cache=use_cache,
|
87 |
+
output_attentions=output_attentions,
|
88 |
+
output_hidden_states=output_hidden_states,
|
89 |
+
return_dict=return_dict
|
90 |
+
)
|
91 |
+
|
92 |
+
hidden_states = outputs[0]
|
93 |
+
logits = self.lm_head(hidden_states)
|
94 |
+
|
95 |
+
loss = None
|
96 |
+
if labels is not None:
|
97 |
+
# Shift so that tokens < n predict n
|
98 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
99 |
+
shift_labels = labels[..., 1:].contiguous()
|
100 |
+
# Flatten the tokens
|
101 |
+
loss_fct = CrossEntropyLoss()
|
102 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
103 |
+
shift_labels = shift_labels.view(-1)
|
104 |
+
# Enable model/pipeline parallelism
|
105 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
106 |
+
loss = loss_fct(shift_logits, shift_labels)
|
107 |
+
|
108 |
+
if not return_dict:
|
109 |
+
output = (logits,) + outputs[1:]
|
110 |
+
return (loss,) + output if loss is not None else output
|
111 |
+
|
112 |
+
return CausalLMOutputWithPast(
|
113 |
+
loss=loss,
|
114 |
+
logits=logits,
|
115 |
+
past_key_values=outputs.past_key_values,
|
116 |
+
hidden_states=outputs.hidden_states,
|
117 |
+
attentions=outputs.attentions,
|
118 |
+
)
|
119 |
+
|
120 |
+
def prepare_inputs_for_generation(
|
121 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
122 |
+
):
|
123 |
+
if past_key_values:
|
124 |
+
input_ids = input_ids[:, -1:]
|
125 |
+
|
126 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
127 |
+
if inputs_embeds is not None and past_key_values is None:
|
128 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
129 |
+
else:
|
130 |
+
model_inputs = {"input_ids": input_ids}
|
131 |
+
|
132 |
+
model_inputs.update(
|
133 |
+
{
|
134 |
+
"past_key_values": past_key_values,
|
135 |
+
"use_cache": kwargs.get("use_cache"),
|
136 |
+
"attention_mask": attention_mask,
|
137 |
+
"images": kwargs.get("images", None),
|
138 |
+
"segs": kwargs.get("segs", None),
|
139 |
+
"depths": kwargs.get("depths", None),
|
140 |
+
}
|
141 |
+
)
|
142 |
+
return model_inputs
|
143 |
+
|
144 |
+
AutoConfig.register("vcoder_ds_llava", VCoderDSLlavaConfig)
|
145 |
+
AutoModelForCausalLM.register(VCoderDSLlavaConfig, VCoderDSLlavaLlamaForCausalLM)
|
vcoder_llava/model/language_model/vcoder_llava_llama.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.nn import CrossEntropyLoss
|
21 |
+
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
23 |
+
LlamaConfig, LlamaModel, LlamaForCausalLM
|
24 |
+
|
25 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
26 |
+
|
27 |
+
from ..vcoder_llava_arch import VCoderLlavaMetaModel, VCoderLlavaMetaForCausalLM
|
28 |
+
|
29 |
+
|
30 |
+
class VCoderLlavaConfig(LlamaConfig):
|
31 |
+
model_type = "vcoder_llava"
|
32 |
+
|
33 |
+
|
34 |
+
class VCoderLlavaLlamaModel(VCoderLlavaMetaModel, LlamaModel):
|
35 |
+
config_class = VCoderLlavaConfig
|
36 |
+
|
37 |
+
def __init__(self, config: LlamaConfig):
|
38 |
+
super(VCoderLlavaLlamaModel, self).__init__(config)
|
39 |
+
|
40 |
+
|
41 |
+
class VCoderLlavaLlamaForCausalLM(LlamaForCausalLM, VCoderLlavaMetaForCausalLM):
|
42 |
+
config_class = VCoderLlavaConfig
|
43 |
+
|
44 |
+
def __init__(self, config):
|
45 |
+
super(LlamaForCausalLM, self).__init__(config)
|
46 |
+
self.model = VCoderLlavaLlamaModel(config)
|
47 |
+
|
48 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
49 |
+
|
50 |
+
# Initialize weights and apply final processing
|
51 |
+
self.post_init()
|
52 |
+
|
53 |
+
def get_model(self):
|
54 |
+
return self.model
|
55 |
+
|
56 |
+
def forward(
|
57 |
+
self,
|
58 |
+
input_ids: torch.LongTensor = None,
|
59 |
+
attention_mask: Optional[torch.Tensor] = None,
|
60 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
61 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
62 |
+
labels: Optional[torch.LongTensor] = None,
|
63 |
+
use_cache: Optional[bool] = None,
|
64 |
+
output_attentions: Optional[bool] = None,
|
65 |
+
output_hidden_states: Optional[bool] = None,
|
66 |
+
images: Optional[torch.FloatTensor] = None,
|
67 |
+
segs: Optional[torch.FloatTensor] = None,
|
68 |
+
return_dict: Optional[bool] = None,
|
69 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
70 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
71 |
+
output_hidden_states = (
|
72 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
73 |
+
)
|
74 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
75 |
+
|
76 |
+
input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images, segs)
|
77 |
+
|
78 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
79 |
+
outputs = self.model(
|
80 |
+
input_ids=input_ids,
|
81 |
+
attention_mask=attention_mask,
|
82 |
+
past_key_values=past_key_values,
|
83 |
+
inputs_embeds=inputs_embeds,
|
84 |
+
use_cache=use_cache,
|
85 |
+
output_attentions=output_attentions,
|
86 |
+
output_hidden_states=output_hidden_states,
|
87 |
+
return_dict=return_dict
|
88 |
+
)
|
89 |
+
|
90 |
+
hidden_states = outputs[0]
|
91 |
+
logits = self.lm_head(hidden_states)
|
92 |
+
|
93 |
+
loss = None
|
94 |
+
if labels is not None:
|
95 |
+
# Shift so that tokens < n predict n
|
96 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
97 |
+
shift_labels = labels[..., 1:].contiguous()
|
98 |
+
# Flatten the tokens
|
99 |
+
loss_fct = CrossEntropyLoss()
|
100 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
101 |
+
shift_labels = shift_labels.view(-1)
|
102 |
+
# Enable model/pipeline parallelism
|
103 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
104 |
+
loss = loss_fct(shift_logits, shift_labels)
|
105 |
+
|
106 |
+
if not return_dict:
|
107 |
+
output = (logits,) + outputs[1:]
|
108 |
+
return (loss,) + output if loss is not None else output
|
109 |
+
|
110 |
+
return CausalLMOutputWithPast(
|
111 |
+
loss=loss,
|
112 |
+
logits=logits,
|
113 |
+
past_key_values=outputs.past_key_values,
|
114 |
+
hidden_states=outputs.hidden_states,
|
115 |
+
attentions=outputs.attentions,
|
116 |
+
)
|
117 |
+
|
118 |
+
def prepare_inputs_for_generation(
|
119 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
120 |
+
):
|
121 |
+
if past_key_values:
|
122 |
+
input_ids = input_ids[:, -1:]
|
123 |
+
|
124 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
125 |
+
if inputs_embeds is not None and past_key_values is None:
|
126 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
127 |
+
else:
|
128 |
+
model_inputs = {"input_ids": input_ids}
|
129 |
+
|
130 |
+
model_inputs.update(
|
131 |
+
{
|
132 |
+
"past_key_values": past_key_values,
|
133 |
+
"use_cache": kwargs.get("use_cache"),
|
134 |
+
"attention_mask": attention_mask,
|
135 |
+
"images": kwargs.get("images", None),
|
136 |
+
"segs": kwargs.get("segs", None),
|
137 |
+
}
|
138 |
+
)
|
139 |
+
return model_inputs
|
140 |
+
|
141 |
+
AutoConfig.register("vcoder_llava", VCoderLlavaConfig)
|
142 |
+
AutoModelForCausalLM.register(VCoderLlavaConfig, VCoderLlavaLlamaForCausalLM)
|
vcoder_llava/model/llava_arch.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from abc import ABC, abstractmethod
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from .multimodal_encoder.builder import build_vision_tower
|
22 |
+
from .multimodal_projector.builder import build_vision_projector
|
23 |
+
|
24 |
+
from vcoder_llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
|
25 |
+
|
26 |
+
|
27 |
+
class LlavaMetaModel:
|
28 |
+
|
29 |
+
def __init__(self, config):
|
30 |
+
super(LlavaMetaModel, self).__init__(config)
|
31 |
+
|
32 |
+
if hasattr(config, "mm_vision_tower"):
|
33 |
+
self.vision_tower = build_vision_tower(config, delay_load=True)
|
34 |
+
self.mm_projector = build_vision_projector(config)
|
35 |
+
|
36 |
+
def get_vision_tower(self):
|
37 |
+
vision_tower = getattr(self, 'vision_tower', None)
|
38 |
+
if type(vision_tower) is list:
|
39 |
+
vision_tower = vision_tower[0]
|
40 |
+
return vision_tower
|
41 |
+
|
42 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
43 |
+
vision_tower = model_args.vision_tower
|
44 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
45 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
46 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
47 |
+
|
48 |
+
self.config.mm_vision_tower = vision_tower
|
49 |
+
|
50 |
+
if self.get_vision_tower() is None:
|
51 |
+
vision_tower = build_vision_tower(model_args)
|
52 |
+
|
53 |
+
if fsdp is not None and len(fsdp) > 0:
|
54 |
+
self.vision_tower = [vision_tower]
|
55 |
+
else:
|
56 |
+
self.vision_tower = vision_tower
|
57 |
+
else:
|
58 |
+
if fsdp is not None and len(fsdp) > 0:
|
59 |
+
vision_tower = self.vision_tower[0]
|
60 |
+
else:
|
61 |
+
vision_tower = self.vision_tower
|
62 |
+
vision_tower.load_model()
|
63 |
+
|
64 |
+
self.config.use_mm_proj = True
|
65 |
+
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
66 |
+
self.config.mm_hidden_size = vision_tower.hidden_size
|
67 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
68 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
69 |
+
|
70 |
+
if getattr(self, 'mm_projector', None) is None:
|
71 |
+
self.mm_projector = build_vision_projector(self.config)
|
72 |
+
else:
|
73 |
+
# In case it is frozen by LoRA
|
74 |
+
for p in self.mm_projector.parameters():
|
75 |
+
p.requires_grad = True
|
76 |
+
|
77 |
+
if pretrain_mm_mlp_adapter is not None:
|
78 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
79 |
+
def get_w(weights, keyword):
|
80 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
81 |
+
|
82 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
|
83 |
+
|
84 |
+
|
85 |
+
class LlavaMetaForCausalLM(ABC):
|
86 |
+
|
87 |
+
@abstractmethod
|
88 |
+
def get_model(self):
|
89 |
+
pass
|
90 |
+
|
91 |
+
def get_vision_tower(self):
|
92 |
+
return self.get_model().get_vision_tower()
|
93 |
+
|
94 |
+
def encode_images(self, images):
|
95 |
+
image_features = self.get_model().get_vision_tower()(images)
|
96 |
+
image_features = self.get_model().mm_projector(image_features)
|
97 |
+
return image_features
|
98 |
+
|
99 |
+
def prepare_inputs_labels_for_multimodal(
|
100 |
+
self, input_ids, attention_mask, past_key_values, labels, images
|
101 |
+
):
|
102 |
+
vision_tower = self.get_vision_tower()
|
103 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
104 |
+
if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
|
105 |
+
attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
|
106 |
+
return input_ids, attention_mask, past_key_values, None, labels
|
107 |
+
|
108 |
+
if type(images) is list or images.ndim == 5:
|
109 |
+
concat_images = torch.cat([image for image in images], dim=0)
|
110 |
+
image_features = self.encode_images(concat_images)
|
111 |
+
split_sizes = [image.shape[0] for image in images]
|
112 |
+
image_features = torch.split(image_features, split_sizes, dim=0)
|
113 |
+
image_features = [x.flatten(0, 1) for x in image_features]
|
114 |
+
else:
|
115 |
+
image_features = self.encode_images(images)
|
116 |
+
|
117 |
+
new_input_embeds = []
|
118 |
+
new_labels = [] if labels is not None else None
|
119 |
+
cur_image_idx = 0
|
120 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
121 |
+
if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
|
122 |
+
# multimodal LLM, but the current sample is not multimodal
|
123 |
+
# FIXME: this is a hacky fix, for deepspeed zero3 to work
|
124 |
+
half_len = cur_input_ids.shape[0] // 2
|
125 |
+
cur_image_features = image_features[cur_image_idx]
|
126 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
|
127 |
+
cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
|
128 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
|
129 |
+
new_input_embeds.append(cur_input_embeds)
|
130 |
+
if labels is not None:
|
131 |
+
new_labels.append(labels[batch_idx])
|
132 |
+
cur_image_idx += 1
|
133 |
+
continue
|
134 |
+
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
|
135 |
+
cur_new_input_embeds = []
|
136 |
+
if labels is not None:
|
137 |
+
cur_labels = labels[batch_idx]
|
138 |
+
cur_new_labels = []
|
139 |
+
assert cur_labels.shape == cur_input_ids.shape
|
140 |
+
while image_token_indices.numel() > 0:
|
141 |
+
cur_image_features = image_features[cur_image_idx]
|
142 |
+
image_token_start = image_token_indices[0]
|
143 |
+
|
144 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
|
145 |
+
cur_new_input_embeds.append(cur_image_features)
|
146 |
+
if labels is not None:
|
147 |
+
cur_new_labels.append(cur_labels[:image_token_start])
|
148 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
149 |
+
cur_labels = cur_labels[image_token_start+1:]
|
150 |
+
cur_image_idx += 1
|
151 |
+
cur_input_ids = cur_input_ids[image_token_start+1:]
|
152 |
+
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
|
153 |
+
if cur_input_ids.numel() > 0:
|
154 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
|
155 |
+
if labels is not None:
|
156 |
+
cur_new_labels.append(cur_labels)
|
157 |
+
cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
|
158 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
|
159 |
+
new_input_embeds.append(cur_new_input_embeds)
|
160 |
+
if labels is not None:
|
161 |
+
cur_new_labels = torch.cat(cur_new_labels, dim=0)
|
162 |
+
new_labels.append(cur_new_labels)
|
163 |
+
|
164 |
+
if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
|
165 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
166 |
+
|
167 |
+
new_input_embeds_align = []
|
168 |
+
for cur_new_embed in new_input_embeds:
|
169 |
+
cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
|
170 |
+
new_input_embeds_align.append(cur_new_embed)
|
171 |
+
new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
|
172 |
+
|
173 |
+
if labels is not None:
|
174 |
+
new_labels_align = []
|
175 |
+
_new_labels = new_labels
|
176 |
+
for cur_new_label in new_labels:
|
177 |
+
cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
|
178 |
+
new_labels_align.append(cur_new_label)
|
179 |
+
new_labels = torch.stack(new_labels_align, dim=0)
|
180 |
+
|
181 |
+
if attention_mask is not None:
|
182 |
+
new_attention_mask = []
|
183 |
+
for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
|
184 |
+
new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
185 |
+
new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
|
186 |
+
cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
|
187 |
+
new_attention_mask.append(cur_new_attention_mask)
|
188 |
+
attention_mask = torch.stack(new_attention_mask, dim=0)
|
189 |
+
assert attention_mask.shape == new_labels.shape
|
190 |
+
else:
|
191 |
+
new_input_embeds = torch.stack(new_input_embeds, dim=0)
|
192 |
+
if labels is not None:
|
193 |
+
new_labels = torch.stack(new_labels, dim=0)
|
194 |
+
|
195 |
+
if attention_mask is not None:
|
196 |
+
new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
197 |
+
attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
|
198 |
+
assert attention_mask.shape == new_input_embeds.shape[:2]
|
199 |
+
|
200 |
+
return None, attention_mask, past_key_values, new_input_embeds, new_labels
|
vcoder_llava/model/make_delta.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m vcoder_llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
10 |
+
from vcoder_llava.model.utils import auto_upgrade
|
11 |
+
|
12 |
+
|
13 |
+
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
|
14 |
+
print("Loading base model")
|
15 |
+
base = AutoModelForCausalLM.from_pretrained(
|
16 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
|
18 |
+
print("Loading target model")
|
19 |
+
auto_upgrade(target_model_path)
|
20 |
+
target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
21 |
+
|
22 |
+
print("Calculating delta")
|
23 |
+
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
|
24 |
+
if name not in base.state_dict():
|
25 |
+
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
|
26 |
+
continue
|
27 |
+
if param.data.shape == base.state_dict()[name].shape:
|
28 |
+
param.data -= base.state_dict()[name]
|
29 |
+
else:
|
30 |
+
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
|
31 |
+
bparam = base.state_dict()[name]
|
32 |
+
param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
|
33 |
+
|
34 |
+
print("Saving delta")
|
35 |
+
if hub_repo_id:
|
36 |
+
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
|
37 |
+
else:
|
38 |
+
kwargs = {}
|
39 |
+
target.save_pretrained(delta_path, **kwargs)
|
40 |
+
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
|
41 |
+
target_tokenizer.save_pretrained(delta_path, **kwargs)
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
parser = argparse.ArgumentParser()
|
46 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
47 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
48 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
49 |
+
parser.add_argument("--hub-repo-id", type=str, default=None)
|
50 |
+
args = parser.parse_args()
|
51 |
+
|
52 |
+
make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
|
vcoder_llava/model/multimodal_adapter/builder.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import re
|
3 |
+
|
4 |
+
class IdentityMap(nn.Module):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__()
|
7 |
+
|
8 |
+
def forward(self, x, *args, **kwargs):
|
9 |
+
return x
|
10 |
+
|
11 |
+
@property
|
12 |
+
def config(self):
|
13 |
+
return {"seg_mm_projector_type": 'identity'}
|
14 |
+
|
15 |
+
|
16 |
+
class SimpleResBlock(nn.Module):
|
17 |
+
def __init__(self, channels):
|
18 |
+
super().__init__()
|
19 |
+
self.pre_norm = nn.LayerNorm(channels)
|
20 |
+
|
21 |
+
self.proj = nn.Sequential(
|
22 |
+
nn.Linear(channels, channels),
|
23 |
+
nn.GELU(),
|
24 |
+
nn.Linear(channels, channels)
|
25 |
+
)
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.pre_norm(x)
|
28 |
+
return x + self.proj(x)
|
29 |
+
|
30 |
+
|
31 |
+
def build_seg_projector(config, delay_load=False, **kwargs):
|
32 |
+
projector_type = getattr(config, 'seg_mm_projector_type', 'linear')
|
33 |
+
|
34 |
+
if projector_type == 'linear':
|
35 |
+
return nn.Linear(config.seg_mm_hidden_size, config.hidden_size)
|
36 |
+
|
37 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
38 |
+
if mlp_gelu_match:
|
39 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
40 |
+
modules = [nn.Linear(config.seg_mm_hidden_size, config.hidden_size)]
|
41 |
+
for _ in range(1, mlp_depth):
|
42 |
+
modules.append(nn.GELU())
|
43 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
44 |
+
return nn.Sequential(*modules)
|
45 |
+
|
46 |
+
if projector_type == 'identity':
|
47 |
+
return IdentityMap()
|
48 |
+
|
49 |
+
raise ValueError(f'Unknown seg projector type: {projector_type}')
|
vcoder_llava/model/multimodal_depth_adapter/builder.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import re
|
3 |
+
|
4 |
+
class IdentityMap(nn.Module):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__()
|
7 |
+
|
8 |
+
def forward(self, x, *args, **kwargs):
|
9 |
+
return x
|
10 |
+
|
11 |
+
@property
|
12 |
+
def config(self):
|
13 |
+
return {"depth_mm_projector_type": 'identity'}
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class SimpleResBlock(nn.Module):
|
18 |
+
def __init__(self, channels):
|
19 |
+
super().__init__()
|
20 |
+
self.pre_norm = nn.LayerNorm(channels)
|
21 |
+
|
22 |
+
self.proj = nn.Sequential(
|
23 |
+
nn.Linear(channels, channels),
|
24 |
+
nn.GELU(),
|
25 |
+
nn.Linear(channels, channels)
|
26 |
+
)
|
27 |
+
def forward(self, x):
|
28 |
+
x = self.pre_norm(x)
|
29 |
+
return x + self.proj(x)
|
30 |
+
|
31 |
+
|
32 |
+
def build_depth_projector(config, delay_load=False, **kwargs):
|
33 |
+
projector_type = getattr(config, 'depth_mm_projector_type', 'linear')
|
34 |
+
|
35 |
+
if projector_type == 'linear':
|
36 |
+
return nn.Linear(config.depth_mm_hidden_size, config.hidden_size)
|
37 |
+
|
38 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
39 |
+
if mlp_gelu_match:
|
40 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
41 |
+
modules = [nn.Linear(config.depth_mm_hidden_size, config.hidden_size)]
|
42 |
+
for _ in range(1, mlp_depth):
|
43 |
+
modules.append(nn.GELU())
|
44 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
45 |
+
return nn.Sequential(*modules)
|
46 |
+
|
47 |
+
if projector_type == 'identity':
|
48 |
+
return IdentityMap()
|
49 |
+
|
50 |
+
raise ValueError(f'Unknown depth projector type: {projector_type}')
|
vcoder_llava/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from .clip_encoder import CLIPVisionTower
|
3 |
+
|
4 |
+
|
5 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
6 |
+
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
|
7 |
+
is_absolute_path_exists = os.path.exists(vision_tower)
|
8 |
+
if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
|
9 |
+
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
10 |
+
|
11 |
+
raise ValueError(f'Unknown vision tower: {vision_tower}')
|
vcoder_llava/model/multimodal_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
5 |
+
|
6 |
+
|
7 |
+
class CLIPVisionTower(nn.Module):
|
8 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
self.is_loaded = False
|
12 |
+
|
13 |
+
self.vision_tower_name = vision_tower
|
14 |
+
self.select_layer = args.mm_vision_select_layer
|
15 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
16 |
+
|
17 |
+
if not delay_load:
|
18 |
+
self.load_model()
|
19 |
+
else:
|
20 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
21 |
+
|
22 |
+
def load_model(self):
|
23 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
24 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
|
25 |
+
self.vision_tower.requires_grad_(False)
|
26 |
+
|
27 |
+
self.is_loaded = True
|
28 |
+
|
29 |
+
def feature_select(self, image_forward_outs):
|
30 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
31 |
+
if self.select_feature == 'patch':
|
32 |
+
image_features = image_features[:, 1:]
|
33 |
+
elif self.select_feature == 'cls_patch':
|
34 |
+
image_features = image_features
|
35 |
+
else:
|
36 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
37 |
+
return image_features
|
38 |
+
|
39 |
+
@torch.no_grad()
|
40 |
+
def forward(self, images):
|
41 |
+
if type(images) is list:
|
42 |
+
image_features = []
|
43 |
+
for image in images:
|
44 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
45 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
46 |
+
image_features.append(image_feature)
|
47 |
+
else:
|
48 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
49 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
50 |
+
|
51 |
+
return image_features
|
52 |
+
|
53 |
+
@property
|
54 |
+
def dummy_feature(self):
|
55 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
56 |
+
|
57 |
+
@property
|
58 |
+
def dtype(self):
|
59 |
+
return self.vision_tower.dtype
|
60 |
+
|
61 |
+
@property
|
62 |
+
def device(self):
|
63 |
+
return self.vision_tower.device
|
64 |
+
|
65 |
+
@property
|
66 |
+
def config(self):
|
67 |
+
if self.is_loaded:
|
68 |
+
return self.vision_tower.config
|
69 |
+
else:
|
70 |
+
return self.cfg_only
|
71 |
+
|
72 |
+
@property
|
73 |
+
def hidden_size(self):
|
74 |
+
return self.config.hidden_size
|
75 |
+
|
76 |
+
@property
|
77 |
+
def num_patches(self):
|
78 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
vcoder_llava/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import re
|
4 |
+
|
5 |
+
|
6 |
+
class IdentityMap(nn.Module):
|
7 |
+
def __init__(self):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
def forward(self, x, *args, **kwargs):
|
11 |
+
return x
|
12 |
+
|
13 |
+
@property
|
14 |
+
def config(self):
|
15 |
+
return {"mm_projector_type": 'identity'}
|
16 |
+
|
17 |
+
|
18 |
+
class SimpleResBlock(nn.Module):
|
19 |
+
def __init__(self, channels):
|
20 |
+
super().__init__()
|
21 |
+
self.pre_norm = nn.LayerNorm(channels)
|
22 |
+
|
23 |
+
self.proj = nn.Sequential(
|
24 |
+
nn.Linear(channels, channels),
|
25 |
+
nn.GELU(),
|
26 |
+
nn.Linear(channels, channels)
|
27 |
+
)
|
28 |
+
def forward(self, x):
|
29 |
+
x = self.pre_norm(x)
|
30 |
+
return x + self.proj(x)
|
31 |
+
|
32 |
+
|
33 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
34 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
35 |
+
|
36 |
+
if projector_type == 'linear':
|
37 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
38 |
+
|
39 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
40 |
+
if mlp_gelu_match:
|
41 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
42 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
43 |
+
for _ in range(1, mlp_depth):
|
44 |
+
modules.append(nn.GELU())
|
45 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
46 |
+
return nn.Sequential(*modules)
|
47 |
+
|
48 |
+
if projector_type == 'identity':
|
49 |
+
return IdentityMap()
|
50 |
+
|
51 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
vcoder_llava/model/utils.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoConfig
|
2 |
+
|
3 |
+
|
4 |
+
def auto_upgrade(config):
|
5 |
+
cfg = AutoConfig.from_pretrained(config)
|
6 |
+
if 'llava' in config and 'llava' not in cfg.model_type:
|
7 |
+
assert cfg.model_type == 'llama'
|
8 |
+
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
|
9 |
+
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
|
10 |
+
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
|
11 |
+
if confirm.lower() in ["y", "yes"]:
|
12 |
+
print("Upgrading checkpoint...")
|
13 |
+
assert len(cfg.architectures) == 1
|
14 |
+
setattr(cfg.__class__, "model_type", "llava")
|
15 |
+
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
|
16 |
+
cfg.save_pretrained(config)
|
17 |
+
print("Checkpoint upgraded.")
|
18 |
+
else:
|
19 |
+
print("Checkpoint upgrade aborted.")
|
20 |
+
exit(1)
|
vcoder_llava/model/vcd/vcd_add_noise.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def add_diffusion_noise(image_tensor, noise_step):
|
4 |
+
num_steps = 1000 # Number of diffusion steps
|
5 |
+
|
6 |
+
# decide beta in each step
|
7 |
+
betas = torch.linspace(-6,6,num_steps)
|
8 |
+
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
|
9 |
+
|
10 |
+
# decide alphas in each step
|
11 |
+
alphas = 1 - betas
|
12 |
+
alphas_prod = torch.cumprod(alphas, dim=0)
|
13 |
+
alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]],0) # p for previous
|
14 |
+
alphas_bar_sqrt = torch.sqrt(alphas_prod)
|
15 |
+
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
|
16 |
+
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)
|
17 |
+
|
18 |
+
def q_x(x_0,t):
|
19 |
+
noise = torch.randn_like(x_0)
|
20 |
+
alphas_t = alphas_bar_sqrt[t]
|
21 |
+
alphas_1_m_t = one_minus_alphas_bar_sqrt[t]
|
22 |
+
return (alphas_t*x_0 + alphas_1_m_t*noise)
|
23 |
+
|
24 |
+
noise_delta = int(noise_step) # from 0-999
|
25 |
+
noisy_image = image_tensor.clone()
|
26 |
+
image_tensor_cd = q_x(noisy_image,noise_step)
|
27 |
+
|
28 |
+
return image_tensor_cd
|
vcoder_llava/model/vcd/vcd_sample.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import inspect
|
3 |
+
import warnings
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.distributed as dist
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from transformers.generation.logits_process import (
|
12 |
+
LogitsProcessorList,
|
13 |
+
)
|
14 |
+
from transformers.generation.stopping_criteria import (
|
15 |
+
StoppingCriteria,
|
16 |
+
StoppingCriteriaList,
|
17 |
+
validate_stopping_criteria,
|
18 |
+
)
|
19 |
+
import transformers
|
20 |
+
from transformers.generation.utils import SampleOutput
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
def sample(
|
25 |
+
self,
|
26 |
+
input_ids: torch.LongTensor,
|
27 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
28 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
29 |
+
logits_warper: Optional[LogitsProcessorList] = None,
|
30 |
+
max_length: Optional[int] = None,
|
31 |
+
pad_token_id: Optional[int] = None,
|
32 |
+
eos_token_id: Optional[Union[int, List[int]]] = None,
|
33 |
+
output_attentions: Optional[bool] = None,
|
34 |
+
output_hidden_states: Optional[bool] = None,
|
35 |
+
output_scores: Optional[bool] = None,
|
36 |
+
return_dict_in_generate: Optional[bool] = None,
|
37 |
+
synced_gpus: bool = False,
|
38 |
+
streamer: Optional["BaseStreamer"] = None,
|
39 |
+
**model_kwargs,
|
40 |
+
) -> Union[SampleOutput, torch.LongTensor]:
|
41 |
+
# init values
|
42 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
43 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
44 |
+
if max_length is not None:
|
45 |
+
warnings.warn(
|
46 |
+
"`max_length` is deprecated in this function, use"
|
47 |
+
" `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
48 |
+
UserWarning,
|
49 |
+
)
|
50 |
+
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
51 |
+
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
52 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
53 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
54 |
+
|
55 |
+
|
56 |
+
if isinstance(eos_token_id, int):
|
57 |
+
eos_token_id = [eos_token_id]
|
58 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
59 |
+
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
60 |
+
output_attentions = (
|
61 |
+
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
62 |
+
)
|
63 |
+
output_hidden_states = (
|
64 |
+
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
|
65 |
+
)
|
66 |
+
|
67 |
+
return_dict_in_generate = (
|
68 |
+
return_dict_in_generate
|
69 |
+
if return_dict_in_generate is not None
|
70 |
+
else self.generation_config.return_dict_in_generate
|
71 |
+
)
|
72 |
+
|
73 |
+
# init attention / hidden states / scores tuples
|
74 |
+
scores = () if (return_dict_in_generate and output_scores) else None
|
75 |
+
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
76 |
+
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
77 |
+
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
78 |
+
|
79 |
+
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
80 |
+
if return_dict_in_generate and self.config.is_encoder_decoder:
|
81 |
+
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
82 |
+
encoder_hidden_states = (
|
83 |
+
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
84 |
+
)
|
85 |
+
|
86 |
+
# keep track of which sequences are already finished
|
87 |
+
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
88 |
+
|
89 |
+
this_peer_finished = False # used by synced_gpus only
|
90 |
+
|
91 |
+
# auto-regressive generation
|
92 |
+
while True:
|
93 |
+
if synced_gpus:
|
94 |
+
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
95 |
+
# The following logic allows an early break if all peers finished generating their sequence
|
96 |
+
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
97 |
+
# send 0.0 if we finished, 1.0 otherwise
|
98 |
+
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
99 |
+
# did all peers finish? the reduced sum will be 0.0 then
|
100 |
+
if this_peer_finished_flag.item() == 0.0:
|
101 |
+
break
|
102 |
+
|
103 |
+
# prepare model inputs
|
104 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
105 |
+
|
106 |
+
# forward pass to get next token
|
107 |
+
outputs = self(
|
108 |
+
**model_inputs,
|
109 |
+
return_dict=True,
|
110 |
+
output_attentions=output_attentions,
|
111 |
+
output_hidden_states=output_hidden_states,
|
112 |
+
)
|
113 |
+
|
114 |
+
if synced_gpus and this_peer_finished:
|
115 |
+
continue # don't waste resources running the code we don't need
|
116 |
+
|
117 |
+
next_token_logits = outputs.logits[:, -1, :]
|
118 |
+
|
119 |
+
|
120 |
+
## For contrastive decoding initial
|
121 |
+
use_cd = model_kwargs.get("images_cd") != None
|
122 |
+
output_attentions_wo_img = (
|
123 |
+
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
124 |
+
)
|
125 |
+
output_hidden_states_wo_img = (
|
126 |
+
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
|
127 |
+
)
|
128 |
+
model_kwargs_cd = model_kwargs.copy()
|
129 |
+
|
130 |
+
if use_cd:
|
131 |
+
## cd_comments: forward pass of the model with distorted image input
|
132 |
+
model_inputs_cd = self.prepare_inputs_for_generation_cd(input_ids, **model_kwargs_cd)
|
133 |
+
outputs_cd = self(
|
134 |
+
**model_inputs_cd,
|
135 |
+
return_dict=True,
|
136 |
+
output_attentions=output_attentions_wo_img,
|
137 |
+
output_hidden_states=output_hidden_states_wo_img,
|
138 |
+
)
|
139 |
+
next_token_logits_cd = outputs_cd.logits[:, -1, :]
|
140 |
+
|
141 |
+
## cd_comments: pre-process logits from contrastive inputs
|
142 |
+
cd_alpha = model_kwargs.get("cd_alpha") if model_kwargs.get("cd_alpha") is not None else 0.5
|
143 |
+
cd_beta = model_kwargs.get("cd_beta") if model_kwargs.get("cd_beta") is not None else 0.1
|
144 |
+
|
145 |
+
# version 1 set cutoff for Adaptive Plausibility Constraints
|
146 |
+
# probs = nn.functional.softmax(next_token_logits, dim=-1)
|
147 |
+
# cutoff = cd_beta * probs.max(dim=-1, keepdim=True).values
|
148 |
+
|
149 |
+
# version 2 set cutoff for Adaptive Plausibility Constraints
|
150 |
+
cutoff = torch.log(torch.tensor(cd_beta)) + next_token_logits.max(dim=-1, keepdim=True).values
|
151 |
+
|
152 |
+
diffs = (1+cd_alpha)*next_token_logits - cd_alpha*next_token_logits_cd
|
153 |
+
cd_logits = diffs.masked_fill(next_token_logits < cutoff, -float("inf"))
|
154 |
+
|
155 |
+
## cd_comments: apply temperature warping and top-k filtering in contrastive decoding
|
156 |
+
cd_logits = logits_processor(input_ids, cd_logits)
|
157 |
+
cd_logits = logits_warper(input_ids, cd_logits)
|
158 |
+
|
159 |
+
next_token_scores = cd_logits
|
160 |
+
cd_probs = nn.functional.softmax(cd_logits, dim=-1)
|
161 |
+
next_tokens = torch.multinomial(cd_probs, num_samples=1).squeeze(1)
|
162 |
+
else:
|
163 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
164 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
165 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
166 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
# Store scores, attentions and hidden_states when required
|
171 |
+
if return_dict_in_generate:
|
172 |
+
if output_scores:
|
173 |
+
scores += (next_token_scores,)
|
174 |
+
if output_attentions:
|
175 |
+
decoder_attentions += (
|
176 |
+
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
177 |
+
)
|
178 |
+
if self.config.is_encoder_decoder:
|
179 |
+
cross_attentions += (outputs.cross_attentions,)
|
180 |
+
|
181 |
+
if output_hidden_states:
|
182 |
+
decoder_hidden_states += (
|
183 |
+
(outputs.decoder_hidden_states,)
|
184 |
+
if self.config.is_encoder_decoder
|
185 |
+
else (outputs.hidden_states,)
|
186 |
+
)
|
187 |
+
|
188 |
+
|
189 |
+
# finished sentences should have their next token be a padding token
|
190 |
+
if eos_token_id is not None:
|
191 |
+
if pad_token_id is None:
|
192 |
+
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
193 |
+
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
194 |
+
|
195 |
+
# update generated ids, model inputs, and length for next step
|
196 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
197 |
+
if streamer is not None:
|
198 |
+
streamer.put(next_tokens.cpu())
|
199 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
200 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
201 |
+
)
|
202 |
+
## cd_comments: update model_kwargs_cd for contrastive decoding
|
203 |
+
if use_cd:
|
204 |
+
model_kwargs_cd = self._update_model_kwargs_for_generation(
|
205 |
+
outputs_cd, model_kwargs_cd, is_encoder_decoder=self.config.is_encoder_decoder
|
206 |
+
)
|
207 |
+
|
208 |
+
# if eos_token was found in one sentence, set sentence to finished
|
209 |
+
if eos_token_id_tensor is not None:
|
210 |
+
unfinished_sequences = unfinished_sequences.mul(
|
211 |
+
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
212 |
+
)
|
213 |
+
|
214 |
+
# stop when each sentence is finished
|
215 |
+
if unfinished_sequences.max() == 0:
|
216 |
+
this_peer_finished = True
|
217 |
+
|
218 |
+
# stop if we exceed the maximum length
|
219 |
+
if stopping_criteria(input_ids, scores):
|
220 |
+
this_peer_finished = True
|
221 |
+
|
222 |
+
if this_peer_finished and not synced_gpus:
|
223 |
+
break
|
224 |
+
|
225 |
+
if streamer is not None:
|
226 |
+
streamer.end()
|
227 |
+
|
228 |
+
if return_dict_in_generate:
|
229 |
+
if self.config.is_encoder_decoder:
|
230 |
+
return SampleEncoderDecoderOutput(
|
231 |
+
sequences=input_ids,
|
232 |
+
scores=scores,
|
233 |
+
encoder_attentions=encoder_attentions,
|
234 |
+
encoder_hidden_states=encoder_hidden_states,
|
235 |
+
decoder_attentions=decoder_attentions,
|
236 |
+
cross_attentions=cross_attentions,
|
237 |
+
decoder_hidden_states=decoder_hidden_states,
|
238 |
+
)
|
239 |
+
else:
|
240 |
+
return SampleDecoderOnlyOutput(
|
241 |
+
sequences=input_ids,
|
242 |
+
scores=scores,
|
243 |
+
attentions=decoder_attentions,
|
244 |
+
hidden_states=decoder_hidden_states,
|
245 |
+
)
|
246 |
+
else:
|
247 |
+
return input_ids
|
248 |
+
|
249 |
+
def evolve_vcd_sampling():
|
250 |
+
transformers.generation.utils.GenerationMixin.sample = sample
|
vcoder_llava/model/vcoder_ds_llava_arch.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from abc import ABC, abstractmethod
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from .multimodal_encoder.builder import build_vision_tower
|
22 |
+
from .multimodal_projector.builder import build_vision_projector
|
23 |
+
from .multimodal_adapter.builder import build_seg_projector
|
24 |
+
from .multimodal_depth_adapter.builder import build_depth_projector
|
25 |
+
|
26 |
+
from vcoder_llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, SEG_TOKEN_INDEX, DEPTH_TOKEN_INDEX
|
27 |
+
|
28 |
+
class VCoderDSLlavaMetaModel:
|
29 |
+
|
30 |
+
def __init__(self, config):
|
31 |
+
super(VCoderDSLlavaMetaModel, self).__init__(config)
|
32 |
+
self.config = config
|
33 |
+
|
34 |
+
if hasattr(config, "mm_vision_tower"):
|
35 |
+
self.vision_tower = build_vision_tower(config, delay_load=True)
|
36 |
+
self.mm_projector = build_vision_projector(config)
|
37 |
+
|
38 |
+
if hasattr(config, "seg_mm_projector_type"):
|
39 |
+
self.seg_mm_projector = build_seg_projector(config)
|
40 |
+
|
41 |
+
if hasattr(config, "use_mm2_proj"):
|
42 |
+
if config.use_mm2_proj:
|
43 |
+
self.mm2_projector = build_vision_projector(config)
|
44 |
+
|
45 |
+
if hasattr(config, "depth_mm_projector_type"):
|
46 |
+
self.depth_mm_projector = build_depth_projector(config)
|
47 |
+
|
48 |
+
if hasattr(config, "mm_vcoder_lm_emb"):
|
49 |
+
self.vcoder_lm_emb = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
|
50 |
+
|
51 |
+
def get_vision_tower(self):
|
52 |
+
vision_tower = getattr(self, 'vision_tower', None)
|
53 |
+
if type(vision_tower) is list:
|
54 |
+
vision_tower = vision_tower[0]
|
55 |
+
return vision_tower
|
56 |
+
|
57 |
+
def initialize_seg_modules(self, model_args, fsdp=None):
|
58 |
+
mm_seg_select_layer = model_args.mm_seg_select_layer
|
59 |
+
mm_seg_select_feature = model_args.mm_seg_select_feature
|
60 |
+
|
61 |
+
self.config.seg_mm_hidden_size = self.vision_tower.hidden_size
|
62 |
+
|
63 |
+
self.config.seg_use_mm_proj = True
|
64 |
+
self.config.seg_mm_projector_type = getattr(model_args, 'seg_mm_projector_type', 'linear')
|
65 |
+
self.config.mm_seg_select_layer = mm_seg_select_layer
|
66 |
+
self.config.mm_seg_select_feature = mm_seg_select_feature
|
67 |
+
|
68 |
+
self.seg_mm_projector = build_seg_projector(self.config)
|
69 |
+
self.vcoder_lm_emb = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.config.pad_token_id)
|
70 |
+
|
71 |
+
# use MLP from pretraining stage
|
72 |
+
pretrain_mm2_mlp_adapter = model_args.pretrain_mm2_mlp_adapter
|
73 |
+
if getattr(model_args, "use_mm2_proj"):
|
74 |
+
self.config.use_mm2_proj = model_args.use_mm2_proj
|
75 |
+
self.mm2_projector = build_vision_projector(self.config)
|
76 |
+
|
77 |
+
if pretrain_mm2_mlp_adapter is not None:
|
78 |
+
mm2_projector_weights = torch.load(pretrain_mm2_mlp_adapter, map_location='cpu')
|
79 |
+
def get_w(weights, keyword):
|
80 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
81 |
+
|
82 |
+
self.mm2_projector.load_state_dict(get_w(mm2_projector_weights, 'mm_projector'))
|
83 |
+
|
84 |
+
def initialize_depth_modules(self, model_args, fsdp=None):
|
85 |
+
mm_depth_select_layer = model_args.mm_depth_select_layer
|
86 |
+
mm_depth_select_feature = model_args.mm_depth_select_feature
|
87 |
+
|
88 |
+
self.config.depth_mm_hidden_size = self.vision_tower.hidden_size
|
89 |
+
|
90 |
+
self.config.depth_use_mm_proj = True
|
91 |
+
self.config.depth_mm_projector_type = getattr(model_args, 'depth_mm_projector_type', 'linear')
|
92 |
+
self.config.mm_depth_select_layer = mm_depth_select_layer
|
93 |
+
self.config.mm_depth_select_feature = mm_depth_select_feature
|
94 |
+
|
95 |
+
self.depth_mm_projector = build_depth_projector(self.config)
|
96 |
+
|
97 |
+
class VCoderDSLlavaMetaForCausalLM(ABC):
|
98 |
+
|
99 |
+
@abstractmethod
|
100 |
+
def get_model(self):
|
101 |
+
pass
|
102 |
+
|
103 |
+
def get_vision_tower(self):
|
104 |
+
return self.get_model().get_vision_tower()
|
105 |
+
|
106 |
+
def encode_seg_images(self, seg_images):
|
107 |
+
seg_features = self.get_model().get_vision_tower()(seg_images)
|
108 |
+
seg_features = self.get_model().seg_mm_projector(seg_features)
|
109 |
+
return seg_features
|
110 |
+
|
111 |
+
def encode_depth_images(self, depth_images):
|
112 |
+
depth_features = self.get_model().get_vision_tower()(depth_images)
|
113 |
+
depth_features = self.get_model().seg_mm_projector(depth_features)
|
114 |
+
return depth_features
|
115 |
+
|
116 |
+
def encode_images(self, images):
|
117 |
+
image_features = self.get_model().get_vision_tower()(images)
|
118 |
+
image_features = self.get_model().mm_projector(image_features)
|
119 |
+
return image_features
|
120 |
+
|
121 |
+
def encode_images_w_seg(self, images):
|
122 |
+
image_features = self.get_model().get_vision_tower()(images)
|
123 |
+
image_features = self.get_model().mm2_projector(image_features)
|
124 |
+
return image_features
|
125 |
+
|
126 |
+
def prepare_inputs_labels_for_multimodal(
|
127 |
+
self, input_ids, attention_mask, past_key_values, labels, images, seg_images, depth_images
|
128 |
+
):
|
129 |
+
vision_tower = self.get_vision_tower()
|
130 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
131 |
+
if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
|
132 |
+
attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
|
133 |
+
return input_ids, attention_mask, past_key_values, None, labels
|
134 |
+
|
135 |
+
if type(images) is list or images.ndim == 5:
|
136 |
+
concat_images = torch.cat([image for image in images], dim=0)
|
137 |
+
if seg_images is not None and hasattr(self, 'mm2_projector'):
|
138 |
+
image_features = self.encode_images_w_seg(concat_images)
|
139 |
+
else:
|
140 |
+
image_features = self.encode_images(concat_images)
|
141 |
+
split_sizes = [image.shape[0] for image in images]
|
142 |
+
image_features = torch.split(image_features, split_sizes, dim=0)
|
143 |
+
image_features = [x.flatten(0, 1) for x in image_features]
|
144 |
+
else:
|
145 |
+
if seg_images is not None and hasattr(self, 'mm2_projector'):
|
146 |
+
image_features = self.encode_images_w_seg(images)
|
147 |
+
else:
|
148 |
+
image_features = self.encode_images(images)
|
149 |
+
|
150 |
+
if seg_images is not None:
|
151 |
+
if type(seg_images) is list or seg_images.ndim == 5:
|
152 |
+
concat_seg_images = torch.cat([image for image in seg_images], dim=0)
|
153 |
+
seg_features = self.encode_seg_images(concat_seg_images)
|
154 |
+
split_sizes = [image.shape[0] for image in seg_images]
|
155 |
+
seg_features = torch.split(seg_features, split_sizes, dim=0)
|
156 |
+
seg_features = [x.flatten(0, 1) for x in seg_features]
|
157 |
+
else:
|
158 |
+
seg_features = self.encode_seg_images(seg_images)
|
159 |
+
|
160 |
+
if depth_images is not None:
|
161 |
+
try:
|
162 |
+
for p in self.get_model().depth_mm_projector.parameters():
|
163 |
+
p.requires_grad = True
|
164 |
+
if type(depth_images) is list or depth_images.ndim == 5:
|
165 |
+
concat_depth_images = torch.cat([image for image in depth_images], dim=0)
|
166 |
+
depth_features = self.encode_depth_images(concat_depth_images)
|
167 |
+
split_sizes = [image.shape[0] for image in depth_images]
|
168 |
+
depth_features = torch.split(depth_features, split_sizes, dim=0)
|
169 |
+
depth_features = [x.flatten(0, 1) for x in depth_features]
|
170 |
+
else:
|
171 |
+
depth_features = self.encode_depth_images(depth_images)
|
172 |
+
except:
|
173 |
+
depth_images = None
|
174 |
+
mask = input_ids != DEPTH_TOKEN_INDEX # drop depth indices
|
175 |
+
input_ids = input_ids[mask]
|
176 |
+
for p in self.get_model().depth_mm_projector.parameters():
|
177 |
+
p.requires_grad = False
|
178 |
+
else:
|
179 |
+
for p in self.get_model().depth_mm_projector.parameters():
|
180 |
+
p.requires_grad = False
|
181 |
+
|
182 |
+
self.get_model().vcoder_lm_emb.weight.data = self.get_model().get_input_embeddings().weight.data.clone()
|
183 |
+
|
184 |
+
new_input_embeds = []
|
185 |
+
new_labels = [] if labels is not None else None
|
186 |
+
cur_image_idx = 0
|
187 |
+
cur_seg_idx = 0
|
188 |
+
cur_depth_idx = 0
|
189 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
190 |
+
if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0 and (cur_input_ids == SEG_TOKEN_INDEX).sum() == 0:
|
191 |
+
# FIXME: this is a hacky fix, for deepspeed zero3 to work
|
192 |
+
cur_image_features = image_features[cur_image_idx]
|
193 |
+
half_len = cur_input_ids.shape[0] // 2
|
194 |
+
if seg_images is not None:
|
195 |
+
cur_seg_features = seg_features[cur_seg_idx]
|
196 |
+
if depth_images is not None:
|
197 |
+
cur_depth_features = depth_features[cur_depth_idx]
|
198 |
+
cur_input_embeds_1 = self.get_model().vcoder_lm_emb(cur_input_ids[:half_len])
|
199 |
+
cur_input_embeds_2 = self.get_model().vcoder_lm_emb(cur_input_ids[half_len:])
|
200 |
+
else:
|
201 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
|
202 |
+
cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
|
203 |
+
if seg_images is not None:
|
204 |
+
if depth_images is not None:
|
205 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_depth_features[0:0], cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
|
206 |
+
else:
|
207 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
|
208 |
+
else:
|
209 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
|
210 |
+
new_input_embeds.append(cur_input_embeds)
|
211 |
+
if labels is not None:
|
212 |
+
new_labels.append(labels[batch_idx])
|
213 |
+
cur_image_idx += 1
|
214 |
+
cur_seg_idx += 1
|
215 |
+
cur_depth_idx += 1
|
216 |
+
continue
|
217 |
+
|
218 |
+
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
|
219 |
+
|
220 |
+
cur_new_input_embeds = []
|
221 |
+
if labels is not None:
|
222 |
+
cur_labels = labels[batch_idx]
|
223 |
+
cur_new_labels = []
|
224 |
+
assert cur_labels.shape == cur_input_ids.shape
|
225 |
+
while image_token_indices.numel() > 0:
|
226 |
+
cur_image_features = image_features[cur_image_idx]
|
227 |
+
image_token_start = image_token_indices[0]
|
228 |
+
if seg_images is None:
|
229 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
|
230 |
+
else:
|
231 |
+
cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:image_token_start]))
|
232 |
+
cur_new_input_embeds.append(cur_image_features)
|
233 |
+
if labels is not None:
|
234 |
+
cur_new_labels.append(cur_labels[:image_token_start])
|
235 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
236 |
+
cur_labels = cur_labels[image_token_start+1:]
|
237 |
+
cur_image_idx += 1
|
238 |
+
cur_input_ids = cur_input_ids[image_token_start+1:]
|
239 |
+
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
|
240 |
+
|
241 |
+
if seg_images is not None:
|
242 |
+
seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]
|
243 |
+
while seg_token_indices.numel() > 0:
|
244 |
+
cur_seg_features = seg_features[cur_seg_idx]
|
245 |
+
seg_token_start = seg_token_indices[0]
|
246 |
+
if depth_images is None:
|
247 |
+
cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:seg_token_start]))
|
248 |
+
cur_new_input_embeds.append(cur_seg_features)
|
249 |
+
if labels is not None:
|
250 |
+
if depth_images is None:
|
251 |
+
cur_new_labels.append(cur_labels[:seg_token_start])
|
252 |
+
cur_new_labels.append(torch.full((cur_seg_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
253 |
+
cur_labels = cur_labels[seg_token_start+1:]
|
254 |
+
cur_seg_idx += 1
|
255 |
+
cur_input_ids = cur_input_ids[seg_token_start+1:]
|
256 |
+
seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]
|
257 |
+
|
258 |
+
if depth_images is not None:
|
259 |
+
depth_token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
|
260 |
+
while depth_token_indices.numel() > 0:
|
261 |
+
cur_depth_features = depth_features[cur_depth_idx]
|
262 |
+
depth_token_start = depth_token_indices[0]
|
263 |
+
cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:depth_token_start]))
|
264 |
+
cur_new_input_embeds.append(cur_depth_features)
|
265 |
+
if labels is not None:
|
266 |
+
cur_new_labels.append(cur_labels[:depth_token_start])
|
267 |
+
cur_new_labels.append(torch.full((cur_depth_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
268 |
+
cur_labels = cur_labels[depth_token_start+1:]
|
269 |
+
cur_depth_idx += 1
|
270 |
+
cur_input_ids = cur_input_ids[depth_token_start+1:]
|
271 |
+
depth_token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
|
272 |
+
|
273 |
+
if cur_input_ids.numel() > 0:
|
274 |
+
if seg_images is None:
|
275 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
|
276 |
+
else:
|
277 |
+
cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids))
|
278 |
+
if labels is not None:
|
279 |
+
cur_new_labels.append(cur_labels)
|
280 |
+
cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
|
281 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
|
282 |
+
new_input_embeds.append(cur_new_input_embeds)
|
283 |
+
if labels is not None:
|
284 |
+
cur_new_labels = torch.cat(cur_new_labels, dim=0)
|
285 |
+
new_labels.append(cur_new_labels)
|
286 |
+
|
287 |
+
if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
|
288 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
289 |
+
|
290 |
+
new_input_embeds_align = []
|
291 |
+
for cur_new_embed in new_input_embeds:
|
292 |
+
cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
|
293 |
+
new_input_embeds_align.append(cur_new_embed)
|
294 |
+
new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
|
295 |
+
|
296 |
+
if labels is not None:
|
297 |
+
new_labels_align = []
|
298 |
+
_new_labels = new_labels
|
299 |
+
for cur_new_label in new_labels:
|
300 |
+
cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
|
301 |
+
new_labels_align.append(cur_new_label)
|
302 |
+
new_labels = torch.stack(new_labels_align, dim=0)
|
303 |
+
|
304 |
+
if attention_mask is not None:
|
305 |
+
new_attention_mask = []
|
306 |
+
for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
|
307 |
+
new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
308 |
+
new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
|
309 |
+
cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
|
310 |
+
new_attention_mask.append(cur_new_attention_mask)
|
311 |
+
attention_mask = torch.stack(new_attention_mask, dim=0)
|
312 |
+
assert attention_mask.shape == new_labels.shape
|
313 |
+
else:
|
314 |
+
new_input_embeds = torch.stack(new_input_embeds, dim=0)
|
315 |
+
if labels is not None:
|
316 |
+
new_labels = torch.stack(new_labels, dim=0)
|
317 |
+
|
318 |
+
if attention_mask is not None:
|
319 |
+
new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
320 |
+
attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
|
321 |
+
assert attention_mask.shape == new_input_embeds.shape[:2]
|
322 |
+
|
323 |
+
return None, attention_mask, past_key_values, new_input_embeds, new_labels
|
vcoder_llava/model/vcoder_llava_arch.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from abc import ABC, abstractmethod
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from .multimodal_encoder.builder import build_vision_tower
|
22 |
+
from .multimodal_projector.builder import build_vision_projector
|
23 |
+
from .multimodal_adapter.builder import build_seg_projector
|
24 |
+
|
25 |
+
from vcoder_llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, SEG_TOKEN_INDEX
|
26 |
+
|
27 |
+
class VCoderLlavaMetaModel:
|
28 |
+
|
29 |
+
def __init__(self, config):
|
30 |
+
super(VCoderLlavaMetaModel, self).__init__(config)
|
31 |
+
self.config = config
|
32 |
+
|
33 |
+
if hasattr(config, "mm_vision_tower"):
|
34 |
+
self.vision_tower = build_vision_tower(config, delay_load=True)
|
35 |
+
self.mm_projector = build_vision_projector(config)
|
36 |
+
|
37 |
+
if hasattr(config, "seg_mm_projector_type"):
|
38 |
+
self.seg_mm_projector = build_seg_projector(config)
|
39 |
+
|
40 |
+
if hasattr(config, "use_mm2_proj"):
|
41 |
+
if config.use_mm2_proj:
|
42 |
+
self.mm2_projector = build_vision_projector(config)
|
43 |
+
|
44 |
+
if hasattr(config, "mm_vcoder_lm_emb"):
|
45 |
+
self.vcoder_lm_emb = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
|
46 |
+
|
47 |
+
def get_vision_tower(self):
|
48 |
+
vision_tower = getattr(self, 'vision_tower', None)
|
49 |
+
if type(vision_tower) is list:
|
50 |
+
vision_tower = vision_tower[0]
|
51 |
+
return vision_tower
|
52 |
+
|
53 |
+
def initialize_seg_modules(self, model_args, fsdp=None):
|
54 |
+
mm_seg_select_layer = model_args.mm_seg_select_layer
|
55 |
+
mm_seg_select_feature = model_args.mm_seg_select_feature
|
56 |
+
|
57 |
+
self.config.seg_mm_hidden_size = self.vision_tower.hidden_size
|
58 |
+
|
59 |
+
pretrain_mm2_mlp_adapter = model_args.pretrain_mm2_mlp_adapter
|
60 |
+
|
61 |
+
self.config.seg_use_mm_proj = True
|
62 |
+
self.config.seg_mm_projector_type = getattr(model_args, 'seg_mm_projector_type', 'linear')
|
63 |
+
self.config.mm_seg_select_layer = mm_seg_select_layer
|
64 |
+
self.config.mm_seg_select_feature = mm_seg_select_feature
|
65 |
+
|
66 |
+
self.seg_mm_projector = build_seg_projector(self.config)
|
67 |
+
self.vcoder_lm_emb = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.config.pad_token_id)
|
68 |
+
|
69 |
+
if getattr(model_args, "use_mm2_proj"):
|
70 |
+
self.config.use_mm2_proj = model_args.use_mm2_proj
|
71 |
+
self.mm2_projector = build_vision_projector(self.config)
|
72 |
+
|
73 |
+
if pretrain_mm2_mlp_adapter is not None:
|
74 |
+
mm2_projector_weights = torch.load(pretrain_mm2_mlp_adapter, map_location='cpu')
|
75 |
+
def get_w(weights, keyword):
|
76 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
77 |
+
|
78 |
+
self.mm2_projector.load_state_dict(get_w(mm2_projector_weights, 'mm_projector'))
|
79 |
+
|
80 |
+
class VCoderLlavaMetaForCausalLM(ABC):
|
81 |
+
|
82 |
+
@abstractmethod
|
83 |
+
def get_model(self):
|
84 |
+
pass
|
85 |
+
|
86 |
+
def get_vision_tower(self):
|
87 |
+
return self.get_model().get_vision_tower()
|
88 |
+
|
89 |
+
def encode_seg_images(self, seg_images):
|
90 |
+
seg_features = self.get_model().get_vision_tower()(seg_images)
|
91 |
+
seg_features = self.get_model().seg_mm_projector(seg_features)
|
92 |
+
return seg_features
|
93 |
+
|
94 |
+
def encode_images(self, images):
|
95 |
+
image_features = self.get_model().get_vision_tower()(images)
|
96 |
+
image_features = self.get_model().mm_projector(image_features)
|
97 |
+
return image_features
|
98 |
+
|
99 |
+
def encode_images_w_seg(self, images):
|
100 |
+
image_features = self.get_model().get_vision_tower()(images)
|
101 |
+
image_features = self.get_model().mm2_projector(image_features)
|
102 |
+
return image_features
|
103 |
+
|
104 |
+
def prepare_inputs_labels_for_multimodal(
|
105 |
+
self, input_ids, attention_mask, past_key_values, labels, images, seg_images,
|
106 |
+
):
|
107 |
+
vision_tower = self.get_vision_tower()
|
108 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
109 |
+
if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
|
110 |
+
attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
|
111 |
+
return input_ids, attention_mask, past_key_values, None, labels
|
112 |
+
|
113 |
+
if type(images) is list or images.ndim == 5:
|
114 |
+
concat_images = torch.cat([image for image in images], dim=0)
|
115 |
+
if seg_images is not None and hasattr(self, 'mm2_projector'):
|
116 |
+
image_features = self.encode_images_w_seg(concat_images)
|
117 |
+
else:
|
118 |
+
image_features = self.encode_images(concat_images)
|
119 |
+
split_sizes = [image.shape[0] for image in images]
|
120 |
+
image_features = torch.split(image_features, split_sizes, dim=0)
|
121 |
+
image_features = [x.flatten(0, 1) for x in image_features]
|
122 |
+
else:
|
123 |
+
if seg_images is not None and hasattr(self, 'mm2_projector'):
|
124 |
+
image_features = self.encode_images_w_seg(images)
|
125 |
+
else:
|
126 |
+
image_features = self.encode_images(images)
|
127 |
+
|
128 |
+
if seg_images is not None:
|
129 |
+
if type(seg_images) is list or seg_images.ndim == 5:
|
130 |
+
concat_seg_images = torch.cat([image for image in seg_images], dim=0)
|
131 |
+
seg_features = self.encode_seg_images(concat_seg_images)
|
132 |
+
split_sizes = [image.shape[0] for image in seg_images]
|
133 |
+
seg_features = torch.split(seg_features, split_sizes, dim=0)
|
134 |
+
seg_features = [x.flatten(0, 1) for x in seg_features]
|
135 |
+
else:
|
136 |
+
seg_features = self.encode_seg_images(seg_images)
|
137 |
+
|
138 |
+
self.get_model().vcoder_lm_emb.weight.data = self.get_model().get_input_embeddings().weight.data.clone()
|
139 |
+
|
140 |
+
new_input_embeds = []
|
141 |
+
new_labels = [] if labels is not None else None
|
142 |
+
cur_image_idx = 0
|
143 |
+
cur_seg_idx = 0
|
144 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
145 |
+
if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0 or (cur_input_ids == SEG_TOKEN_INDEX).sum() == 0:
|
146 |
+
# FIXME: this is a hacky fix, for deepspeed zero3 to work
|
147 |
+
cur_image_features = image_features[cur_image_idx]
|
148 |
+
if seg_images is not None:
|
149 |
+
cur_seg_features = seg_features[cur_seg_idx]
|
150 |
+
half_len = cur_input_ids.shape[0] // 2
|
151 |
+
if seg_images is not None:
|
152 |
+
cur_input_embeds_1 = self.get_model().vcoder_lm_emb(cur_input_ids[:half_len])
|
153 |
+
cur_input_embeds_2 = self.get_model().vcoder_lm_emb(cur_input_ids[half_len:])
|
154 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
|
155 |
+
else:
|
156 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
|
157 |
+
cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
|
158 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
|
159 |
+
new_input_embeds.append(cur_input_embeds)
|
160 |
+
if labels is not None:
|
161 |
+
new_labels.append(labels[batch_idx])
|
162 |
+
cur_image_idx += 1
|
163 |
+
cur_seg_idx += 1
|
164 |
+
continue
|
165 |
+
|
166 |
+
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
|
167 |
+
|
168 |
+
cur_new_input_embeds = []
|
169 |
+
if labels is not None:
|
170 |
+
cur_labels = labels[batch_idx]
|
171 |
+
cur_new_labels = []
|
172 |
+
assert cur_labels.shape == cur_input_ids.shape
|
173 |
+
while image_token_indices.numel() > 0:
|
174 |
+
cur_image_features = image_features[cur_image_idx]
|
175 |
+
image_token_start = image_token_indices[0]
|
176 |
+
if seg_images is None:
|
177 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
|
178 |
+
else:
|
179 |
+
cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:image_token_start]))
|
180 |
+
cur_new_input_embeds.append(cur_image_features)
|
181 |
+
if labels is not None:
|
182 |
+
cur_new_labels.append(cur_labels[:image_token_start])
|
183 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
184 |
+
cur_labels = cur_labels[image_token_start+1:]
|
185 |
+
cur_image_idx += 1
|
186 |
+
cur_input_ids = cur_input_ids[image_token_start+1:]
|
187 |
+
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
|
188 |
+
|
189 |
+
if seg_images is not None:
|
190 |
+
seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]
|
191 |
+
while seg_token_indices.numel() > 0:
|
192 |
+
cur_seg_features = seg_features[cur_seg_idx]
|
193 |
+
seg_token_start = seg_token_indices[0]
|
194 |
+
cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:seg_token_start]))
|
195 |
+
cur_new_input_embeds.append(cur_seg_features)
|
196 |
+
if labels is not None:
|
197 |
+
cur_new_labels.append(cur_labels[:seg_token_start])
|
198 |
+
cur_new_labels.append(torch.full((cur_seg_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
199 |
+
cur_labels = cur_labels[seg_token_start+1:]
|
200 |
+
cur_seg_idx += 1
|
201 |
+
cur_input_ids = cur_input_ids[seg_token_start+1:]
|
202 |
+
seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]
|
203 |
+
|
204 |
+
if cur_input_ids.numel() > 0:
|
205 |
+
if seg_images is None:
|
206 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
|
207 |
+
else:
|
208 |
+
cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids))
|
209 |
+
if labels is not None:
|
210 |
+
cur_new_labels.append(cur_labels)
|
211 |
+
cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
|
212 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
|
213 |
+
new_input_embeds.append(cur_new_input_embeds)
|
214 |
+
if labels is not None:
|
215 |
+
cur_new_labels = torch.cat(cur_new_labels, dim=0)
|
216 |
+
new_labels.append(cur_new_labels)
|
217 |
+
|
218 |
+
if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
|
219 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
220 |
+
|
221 |
+
new_input_embeds_align = []
|
222 |
+
for cur_new_embed in new_input_embeds:
|
223 |
+
cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
|
224 |
+
new_input_embeds_align.append(cur_new_embed)
|
225 |
+
new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
|
226 |
+
|
227 |
+
if labels is not None:
|
228 |
+
new_labels_align = []
|
229 |
+
_new_labels = new_labels
|
230 |
+
for cur_new_label in new_labels:
|
231 |
+
cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
|
232 |
+
new_labels_align.append(cur_new_label)
|
233 |
+
new_labels = torch.stack(new_labels_align, dim=0)
|
234 |
+
|
235 |
+
if attention_mask is not None:
|
236 |
+
new_attention_mask = []
|
237 |
+
for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
|
238 |
+
new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
239 |
+
new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
|
240 |
+
cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
|
241 |
+
new_attention_mask.append(cur_new_attention_mask)
|
242 |
+
attention_mask = torch.stack(new_attention_mask, dim=0)
|
243 |
+
assert attention_mask.shape == new_labels.shape
|
244 |
+
else:
|
245 |
+
new_input_embeds = torch.stack(new_input_embeds, dim=0)
|
246 |
+
if labels is not None:
|
247 |
+
new_labels = torch.stack(new_labels, dim=0)
|
248 |
+
|
249 |
+
if attention_mask is not None:
|
250 |
+
new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
251 |
+
attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
|
252 |
+
assert attention_mask.shape == new_input_embeds.shape[:2]
|
253 |
+
|
254 |
+
return None, attention_mask, past_key_values, new_input_embeds, new_labels
|
vcoder_llava/questions.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
SEMANTIC_QUESTIONS = [
|
3 |
+
"What objects can be seen in the image? Perceive as done for semantic segmentation.",
|
4 |
+
"What items are depicted in the picture? Consider in terms of semantic segmentation.",
|
5 |
+
"Which elements are present in the visual? Analyze as you would for semantic segmentation.",
|
6 |
+
"Can you identify the objects in the image? Think from a semantic segmentation perspective.",
|
7 |
+
"What are the components visible in the graphic? Examine as if segmenting semantically.",
|
8 |
+
"Which entities can be spotted in the photo? View through the lens of semantic segmentation.",
|
9 |
+
"What are the discernible objects in the snapshot? Envision in relation to semantic segmentation.",
|
10 |
+
"What elements stand out in the illustration? Reflect upon it as for semantic segmentation.",
|
11 |
+
"Can you spot any items within the visual representation? Contemplate in a semantic segmentation context.",
|
12 |
+
"What features are evident in this visual content? Analyze with semantic segmentation in mind.",
|
13 |
+
"Which objects are noticeable in the image? Think of it in terms of semantic layers.",
|
14 |
+
"How would you categorize the objects in this picture? As if you're doing semantic segmentation.",
|
15 |
+
"What constituents can you recognize in the image? Ponder considering semantic segmentation.",
|
16 |
+
"Which components can be distinguished in the photo? Evaluate as per semantic segmentation guidelines.",
|
17 |
+
"What items in the image can you point out? Interpret with a semantic segmentation approach.",
|
18 |
+
"Can you enumerate the objects present in this visual? Think semantically.",
|
19 |
+
"What do you observe in the graphic? Consider its semantic segments.",
|
20 |
+
"How many distinct objects can you identify in the visual? Keeping semantic segmentation in perspective.",
|
21 |
+
"Which items are apparent in this depiction? Assess as one would for semantic segmentation.",
|
22 |
+
"What are the visible entities within this image? Delve into it semantically.",
|
23 |
+
"Can you discern specific objects in the portrayal? Approach it from a semantic segmentation standpoint.",
|
24 |
+
]
|
25 |
+
|
26 |
+
INSTANCE_QUESTIONS = [
|
27 |
+
"What objects can be seen in the image? Perceive as done for instance segmentation",
|
28 |
+
"What items are visible in the picture? Analyze as you would for instance segmentation.",
|
29 |
+
"Which elements are present in the visual? Consider from an instance segmentation perspective.",
|
30 |
+
"What are the distinguishable objects in the image? Think in terms of instance segmentation.",
|
31 |
+
"Can you identify the entities in the graphic? Approach it with instance segmentation in mind.",
|
32 |
+
"What components are apparent in the photo? Examine as if performing instance segmentation.",
|
33 |
+
"Which items can be detected in the snapshot? View it through the lens of instance segmentation.",
|
34 |
+
"What features stand out in the illustration? Reflect upon it as for instance segmentation.",
|
35 |
+
"How would you describe the objects in this image? Keeping instance segmentation as a reference.",
|
36 |
+
"What constituents are evident in the visual content? Think from an instance segmentation standpoint.",
|
37 |
+
"Which objects can you spot in the depiction? Evaluate as per instance segmentation guidelines.",
|
38 |
+
"What do you observe in the graphic? Contemplate with instance segmentation considerations.",
|
39 |
+
"Can you discern specific entities in the visual? Approach it in the context of instance segmentation.",
|
40 |
+
"Which components in the image catch your eye? Think of it in relation to instance layers.",
|
41 |
+
"How many distinct items can you pinpoint in the photo? With an instance segmentation approach.",
|
42 |
+
"What elements are noticeable in this portrayal? Analyze while considering instance segmentation.",
|
43 |
+
"Can you list the objects present in the visual representation? Reflecting on instance segmentation.",
|
44 |
+
"What items in the snapshot can you recognize? Interpret with an instance segmentation perspective.",
|
45 |
+
"Which entities are discernible in this depiction? Delve into it from an instance segmentation angle.",
|
46 |
+
"What are the components you can spot within the image? Think instance-wise.",
|
47 |
+
"Can you detail the objects in the visual? Assess as one would for instance segmentation.",
|
48 |
+
]
|
49 |
+
|
50 |
+
PANOPTIC_QUESTIONS = [
|
51 |
+
"What objects can be seen in the image? Perceive as done for panoptic segmentation",
|
52 |
+
"What items are evident in the picture? Analyze with a panoptic segmentation perspective.",
|
53 |
+
"Which elements emerge in the visual? Think in terms of panoptic segmentation.",
|
54 |
+
"What are the discernible objects in the graphic? Approach it from a panoptic segmentation viewpoint.",
|
55 |
+
"Can you identify the entities within the image? Consider it as you would for panoptic segmentation.",
|
56 |
+
"What components stand out in the photo? Examine with panoptic segmentation in mind.",
|
57 |
+
"Which items are detectable in the snapshot? Reflect upon it with panoptic segmentation considerations.",
|
58 |
+
"What features can be observed in the illustration? View through the lens of panoptic segmentation.",
|
59 |
+
"How would you describe the objects in this depiction? Keeping panoptic segmentation as a reference.",
|
60 |
+
"What constituents are visible in the visual content? Think from a panoptic segmentation standpoint.",
|
61 |
+
"Which objects can you pinpoint in the image? Evaluate as per panoptic segmentation guidelines.",
|
62 |
+
"What do you perceive in the graphic? Delve into it with panoptic segmentation insights.",
|
63 |
+
"Can you spot specific components in the visual? Contextualize with panoptic segmentation.",
|
64 |
+
"What items in the portrayal catch your attention? Think in relation to panoptic layers.",
|
65 |
+
"How many distinct entities can you recognize in the photo? With a panoptic segmentation approach.",
|
66 |
+
"What elements are present in this visual? Analyze while keeping panoptic segmentation in mind.",
|
67 |
+
"Can you list the objects depicted in the visual representation? Reflecting on panoptic segmentation.",
|
68 |
+
"Which features in the image can you discern? Interpret considering panoptic segmentation.",
|
69 |
+
"What are the components evident in this depiction? Approach it using a panoptic segmentation angle.",
|
70 |
+
"What items can you detect in the visual content? Think panoptically.",
|
71 |
+
"Can you detail the entities present in the image? Assess as one would when considering panoptic segmentation.",
|
72 |
+
]
|
73 |
+
|
74 |
+
DEPTH_QUESTIONS = [
|
75 |
+
"what is depth order of objects in the image?",
|
76 |
+
"Can you describe the depth order of the objects in this image, from closest to farthest?",
|
77 |
+
"Which objects in the image appear nearest to the viewer and which seem furthest away?",
|
78 |
+
"Could you list the objects in the image in order of their perceived distance from the foreground to the background?",
|
79 |
+
"In what order do the objects in this image appear based on their depth, starting from the closest?",
|
80 |
+
"How would you rank the objects in this picture from the most proximal to the most distal?",
|
81 |
+
"Can you arrange the objects seen here from those appearing closest to those appearing farthest?",
|
82 |
+
"What is the sequence of objects in this image based on their distance from the front to the back?",
|
83 |
+
"Please identify the order of objects in terms of depth perspective in this image.",
|
84 |
+
"Which objects in the picture seem to be in the front, and which ones appear to be in the back?",
|
85 |
+
"How are the objects in this image layered in depth, from the one nearest to the camera to the one farthest?",
|
86 |
+
"Could you sort the objects in this photo from foreground to background?",
|
87 |
+
"In this image, what is the spatial arrangement of objects from closest to furthest?",
|
88 |
+
"Can you pinpoint the depth hierarchy of these objects, starting from the closest?",
|
89 |
+
"What's the depth sequence of the objects displayed in this picture?",
|
90 |
+
"From nearest to furthest, how would you order the objects in this image?",
|
91 |
+
"How would you describe the spatial positioning of these objects in terms of their depth?",
|
92 |
+
"Can you determine the depth placement of each object in this photo, starting with the nearest?",
|
93 |
+
"What is the arrangement of objects in this scene by depth?",
|
94 |
+
"Could you outline the depth profile of the objects in this image?",
|
95 |
+
"In what depth order do the objects in this image align, from the frontmost to the rearmost?",
|
96 |
+
"How are the objects in this image ordered in terms of their relative distance from the observer?",
|
97 |
+
]
|
98 |
+
|
99 |
+
QUESTIONS = {
|
100 |
+
'semantic': SEMANTIC_QUESTIONS,
|
101 |
+
'instance': INSTANCE_QUESTIONS,
|
102 |
+
'panoptic': PANOPTIC_QUESTIONS,
|
103 |
+
'depth': DEPTH_QUESTIONS,
|
104 |
+
}
|
105 |
+
|
106 |
+
### Depth Prompts
|
107 |
+
# Can you describe the depth order of the objects in this image, from closest to farthest? Return answer in the paragraph format: `The depth order for the objects present in the image is: ...' and then list the objects with their order number (if greater than 1) separated by a hyphen like `person-2'. For example, an acceptable response is "The depth order for objects present in the image is: bicycle, bicycle-2, bicycle-3, pavement, road, bus, tree, sky, building."
|
108 |
+
|
109 |
+
### Seg Prompts
|
110 |
+
# What objects can be seen in the image? Return the answer in the paragraph format: The objects present in the image are: ...' and then list the objects with their count in word format (if greater than 1) in front of them, like two people'.
|
vcoder_llava/utils.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import logging
|
3 |
+
import logging.handlers
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from vcoder_llava.constants import LOGDIR
|
10 |
+
|
11 |
+
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
12 |
+
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
13 |
+
|
14 |
+
handler = None
|
15 |
+
|
16 |
+
|
17 |
+
def build_logger(logger_name, logger_filename):
|
18 |
+
global handler
|
19 |
+
|
20 |
+
formatter = logging.Formatter(
|
21 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
22 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
23 |
+
)
|
24 |
+
|
25 |
+
# Set the format of root handlers
|
26 |
+
if not logging.getLogger().handlers:
|
27 |
+
logging.basicConfig(level=logging.INFO)
|
28 |
+
logging.getLogger().handlers[0].setFormatter(formatter)
|
29 |
+
|
30 |
+
# Redirect stdout and stderr to loggers
|
31 |
+
stdout_logger = logging.getLogger("stdout")
|
32 |
+
stdout_logger.setLevel(logging.INFO)
|
33 |
+
sl = StreamToLogger(stdout_logger, logging.INFO)
|
34 |
+
sys.stdout = sl
|
35 |
+
|
36 |
+
stderr_logger = logging.getLogger("stderr")
|
37 |
+
stderr_logger.setLevel(logging.ERROR)
|
38 |
+
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
39 |
+
sys.stderr = sl
|
40 |
+
|
41 |
+
# Get logger
|
42 |
+
logger = logging.getLogger(logger_name)
|
43 |
+
logger.setLevel(logging.INFO)
|
44 |
+
|
45 |
+
# Add a file handler for all loggers
|
46 |
+
if handler is None:
|
47 |
+
os.makedirs(LOGDIR, exist_ok=True)
|
48 |
+
filename = os.path.join(LOGDIR, logger_filename)
|
49 |
+
handler = logging.handlers.TimedRotatingFileHandler(
|
50 |
+
filename, when='D', utc=True)
|
51 |
+
handler.setFormatter(formatter)
|
52 |
+
|
53 |
+
for name, item in logging.root.manager.loggerDict.items():
|
54 |
+
if isinstance(item, logging.Logger):
|
55 |
+
item.addHandler(handler)
|
56 |
+
|
57 |
+
return logger
|
58 |
+
|
59 |
+
|
60 |
+
class StreamToLogger(object):
|
61 |
+
"""
|
62 |
+
Fake file-like stream object that redirects writes to a logger instance.
|
63 |
+
"""
|
64 |
+
def __init__(self, logger, log_level=logging.INFO):
|
65 |
+
self.terminal = sys.stdout
|
66 |
+
self.logger = logger
|
67 |
+
self.log_level = log_level
|
68 |
+
self.linebuf = ''
|
69 |
+
|
70 |
+
def __getattr__(self, attr):
|
71 |
+
return getattr(self.terminal, attr)
|
72 |
+
|
73 |
+
def write(self, buf):
|
74 |
+
temp_linebuf = self.linebuf + buf
|
75 |
+
self.linebuf = ''
|
76 |
+
for line in temp_linebuf.splitlines(True):
|
77 |
+
# From the io.TextIOWrapper docs:
|
78 |
+
# On output, if newline is None, any '\n' characters written
|
79 |
+
# are translated to the system default line separator.
|
80 |
+
# By default sys.stdout.write() expects '\n' newlines and then
|
81 |
+
# translates them so this is still cross platform.
|
82 |
+
if line[-1] == '\n':
|
83 |
+
self.logger.log(self.log_level, line.rstrip())
|
84 |
+
else:
|
85 |
+
self.linebuf += line
|
86 |
+
|
87 |
+
def flush(self):
|
88 |
+
if self.linebuf != '':
|
89 |
+
self.logger.log(self.log_level, self.linebuf.rstrip())
|
90 |
+
self.linebuf = ''
|
91 |
+
|
92 |
+
|
93 |
+
def disable_torch_init():
|
94 |
+
"""
|
95 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
96 |
+
"""
|
97 |
+
import torch
|
98 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
99 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
100 |
+
|
101 |
+
|
102 |
+
def violates_moderation(text):
|
103 |
+
"""
|
104 |
+
Check whether the text violates OpenAI moderation API.
|
105 |
+
"""
|
106 |
+
url = "https://api.openai.com/v1/moderations"
|
107 |
+
headers = {"Content-Type": "application/json",
|
108 |
+
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
|
109 |
+
text = text.replace("\n", "")
|
110 |
+
data = "{" + '"input": ' + f'"{text}"' + "}"
|
111 |
+
data = data.encode("utf-8")
|
112 |
+
try:
|
113 |
+
ret = requests.post(url, headers=headers, data=data, timeout=5)
|
114 |
+
flagged = ret.json()["results"][0]["flagged"]
|
115 |
+
except requests.exceptions.RequestException as e:
|
116 |
+
flagged = False
|
117 |
+
except KeyError as e:
|
118 |
+
flagged = False
|
119 |
+
|
120 |
+
return flagged
|
121 |
+
|
122 |
+
|
123 |
+
def pretty_print_semaphore(semaphore):
|
124 |
+
if semaphore is None:
|
125 |
+
return "None"
|
126 |
+
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
|
vcoder_llava/vcoder_conversation.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
|
6 |
+
class SeparatorStyle(Enum):
|
7 |
+
"""Different separator style."""
|
8 |
+
SINGLE = auto()
|
9 |
+
TWO = auto()
|
10 |
+
MPT = auto()
|
11 |
+
PLAIN = auto()
|
12 |
+
LLAMA_2 = auto()
|
13 |
+
|
14 |
+
|
15 |
+
@dataclasses.dataclass
|
16 |
+
class VCoderConversation:
|
17 |
+
"""A class that keeps all conversation history."""
|
18 |
+
system: str
|
19 |
+
roles: List[str]
|
20 |
+
messages: List[List[str]]
|
21 |
+
offset: int
|
22 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
23 |
+
sep: str = "###"
|
24 |
+
sep2: str = None
|
25 |
+
version: str = "Unknown"
|
26 |
+
|
27 |
+
skip_next: bool = False
|
28 |
+
|
29 |
+
def get_prompt(self):
|
30 |
+
messages = self.messages
|
31 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
32 |
+
ret = self.system + self.sep
|
33 |
+
for role, message in messages:
|
34 |
+
if message:
|
35 |
+
if type(message) is tuple:
|
36 |
+
message, _, _, _, _, _, _ = message
|
37 |
+
ret += role + ": " + message + self.sep
|
38 |
+
else:
|
39 |
+
ret += role + ":"
|
40 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
41 |
+
seps = [self.sep, self.sep2]
|
42 |
+
ret = self.system + seps[0]
|
43 |
+
for i, (role, message) in enumerate(messages):
|
44 |
+
if message:
|
45 |
+
if type(message) is tuple:
|
46 |
+
message, _, _, _, _, _, _ = message
|
47 |
+
ret += role + ": " + message + seps[i % 2]
|
48 |
+
else:
|
49 |
+
ret += role + ":"
|
50 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
51 |
+
ret = self.system + self.sep
|
52 |
+
for role, message in messages:
|
53 |
+
if message:
|
54 |
+
if type(message) is tuple:
|
55 |
+
message, _, _, _, _, _, _ = message
|
56 |
+
ret += role + message + self.sep
|
57 |
+
else:
|
58 |
+
ret += role
|
59 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
60 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
61 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
62 |
+
ret = ""
|
63 |
+
|
64 |
+
for i, (role, message) in enumerate(messages):
|
65 |
+
if i == 0:
|
66 |
+
assert message, "first message should not be none"
|
67 |
+
assert role == self.roles[0], "first message should come from user"
|
68 |
+
if message:
|
69 |
+
if type(message) is tuple:
|
70 |
+
message, _, _, _, _, _, _ = message
|
71 |
+
if i == 0: message = wrap_sys(self.system) + message
|
72 |
+
if i % 2 == 0:
|
73 |
+
message = wrap_inst(message)
|
74 |
+
ret += self.sep + message
|
75 |
+
else:
|
76 |
+
ret += " " + message + " " + self.sep2
|
77 |
+
else:
|
78 |
+
ret += ""
|
79 |
+
ret = ret.lstrip(self.sep)
|
80 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
81 |
+
seps = [self.sep, self.sep2]
|
82 |
+
ret = self.system
|
83 |
+
for i, (role, message) in enumerate(messages):
|
84 |
+
if message:
|
85 |
+
if type(message) is tuple:
|
86 |
+
message, _, _, _, _, _, _ = message
|
87 |
+
ret += message + seps[i % 2]
|
88 |
+
else:
|
89 |
+
ret += ""
|
90 |
+
else:
|
91 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
92 |
+
|
93 |
+
return ret
|
94 |
+
|
95 |
+
def append_message(self, role, message):
|
96 |
+
self.messages.append([role, message])
|
97 |
+
|
98 |
+
def get_images(self, return_pil=False):
|
99 |
+
images = []
|
100 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
101 |
+
if i % 2 == 0:
|
102 |
+
if type(msg) is tuple:
|
103 |
+
import base64
|
104 |
+
from io import BytesIO
|
105 |
+
from PIL import Image
|
106 |
+
msg, image, image_process_mode, _, _, _, _ = msg
|
107 |
+
if image is not None:
|
108 |
+
if image_process_mode == "Pad":
|
109 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
110 |
+
width, height = pil_img.size
|
111 |
+
if width == height:
|
112 |
+
return pil_img
|
113 |
+
elif width > height:
|
114 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
115 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
116 |
+
return result
|
117 |
+
else:
|
118 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
119 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
120 |
+
return result
|
121 |
+
image = expand2square(image)
|
122 |
+
elif image_process_mode in ["Default", "Crop"]:
|
123 |
+
pass
|
124 |
+
elif image_process_mode == "Resize":
|
125 |
+
image = image.resize((336, 336))
|
126 |
+
else:
|
127 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
128 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
129 |
+
aspect_ratio = max_hw / min_hw
|
130 |
+
max_len, min_len = 800, 400
|
131 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
132 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
133 |
+
W, H = image.size
|
134 |
+
if longest_edge != max(image.size):
|
135 |
+
if H > W:
|
136 |
+
H, W = longest_edge, shortest_edge
|
137 |
+
else:
|
138 |
+
H, W = shortest_edge, longest_edge
|
139 |
+
image = image.resize((W, H))
|
140 |
+
if return_pil:
|
141 |
+
images.append(image)
|
142 |
+
else:
|
143 |
+
buffered = BytesIO()
|
144 |
+
image.save(buffered, format="PNG")
|
145 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
146 |
+
images.append(img_b64_str)
|
147 |
+
return images
|
148 |
+
|
149 |
+
def get_segs(self, return_pil=False):
|
150 |
+
segs = []
|
151 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
152 |
+
if i % 2 == 0:
|
153 |
+
if type(msg) is tuple:
|
154 |
+
import base64
|
155 |
+
from io import BytesIO
|
156 |
+
from PIL import Image
|
157 |
+
msg, _, _, seg, seg_process_mode, _, _ = msg
|
158 |
+
if seg is not None:
|
159 |
+
if seg_process_mode == "Pad":
|
160 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
161 |
+
width, height = pil_img.size
|
162 |
+
if width == height:
|
163 |
+
return pil_img
|
164 |
+
elif width > height:
|
165 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
166 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
167 |
+
return result
|
168 |
+
else:
|
169 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
170 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
171 |
+
return result
|
172 |
+
seg = expand2square(seg)
|
173 |
+
elif seg_process_mode in ["Default", "Crop"]:
|
174 |
+
pass
|
175 |
+
elif seg_process_mode == "Resize":
|
176 |
+
seg = seg.resize((336, 336))
|
177 |
+
else:
|
178 |
+
raise ValueError(f"Invalid image_process_mode: {seg_process_mode}")
|
179 |
+
max_hw, min_hw = max(seg.size), min(seg.size)
|
180 |
+
aspect_ratio = max_hw / min_hw
|
181 |
+
max_len, min_len = 800, 400
|
182 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
183 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
184 |
+
W, H = seg.size
|
185 |
+
if longest_edge != max(seg.size):
|
186 |
+
if H > W:
|
187 |
+
H, W = longest_edge, shortest_edge
|
188 |
+
else:
|
189 |
+
H, W = shortest_edge, longest_edge
|
190 |
+
seg = seg.resize((W, H))
|
191 |
+
if return_pil:
|
192 |
+
segs.append(seg)
|
193 |
+
else:
|
194 |
+
buffered = BytesIO()
|
195 |
+
seg.save(buffered, format="PNG")
|
196 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
197 |
+
segs.append(img_b64_str)
|
198 |
+
return segs
|
199 |
+
|
200 |
+
def get_depths(self, return_pil=False):
|
201 |
+
depths = []
|
202 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
203 |
+
if i % 2 == 0:
|
204 |
+
if type(msg) is tuple:
|
205 |
+
import base64
|
206 |
+
from io import BytesIO
|
207 |
+
from PIL import Image
|
208 |
+
msg, _, _, _, _, depth, depth_process_mode = msg
|
209 |
+
if depth is not None:
|
210 |
+
if depth_process_mode == "Pad":
|
211 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
212 |
+
width, height = pil_img.size
|
213 |
+
if width == height:
|
214 |
+
return pil_img
|
215 |
+
elif width > height:
|
216 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
217 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
218 |
+
return result
|
219 |
+
else:
|
220 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
221 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
222 |
+
return result
|
223 |
+
depth = expand2square(depth)
|
224 |
+
elif depth_process_mode in ["Default", "Crop"]:
|
225 |
+
pass
|
226 |
+
elif depth_process_mode == "Resize":
|
227 |
+
depth = depth.resize((336, 336))
|
228 |
+
else:
|
229 |
+
raise ValueError(f"Invalid image_process_mode: {depth_process_mode}")
|
230 |
+
max_hw, min_hw = max(depth.size), min(depth.size)
|
231 |
+
aspect_ratio = max_hw / min_hw
|
232 |
+
max_len, min_len = 800, 400
|
233 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
234 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
235 |
+
W, H = depth.size
|
236 |
+
if longest_edge != max(depth.size):
|
237 |
+
if H > W:
|
238 |
+
H, W = longest_edge, shortest_edge
|
239 |
+
else:
|
240 |
+
H, W = shortest_edge, longest_edge
|
241 |
+
depth = depth.resize((W, H))
|
242 |
+
if return_pil:
|
243 |
+
depths.append(depth)
|
244 |
+
else:
|
245 |
+
buffered = BytesIO()
|
246 |
+
depth.save(buffered, format="PNG")
|
247 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
248 |
+
depths.append(img_b64_str)
|
249 |
+
return depths
|
250 |
+
|
251 |
+
def to_gradio_chatbot(self):
|
252 |
+
ret = []
|
253 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
254 |
+
if i % 2 == 0:
|
255 |
+
if type(msg) is tuple:
|
256 |
+
import base64
|
257 |
+
from io import BytesIO
|
258 |
+
msg, image, image_process_mode, seg, seg_process_mode, depth, depth_process_mode = msg
|
259 |
+
if image is not None:
|
260 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
261 |
+
aspect_ratio = max_hw / min_hw
|
262 |
+
max_len, min_len = 800, 400
|
263 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
264 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
265 |
+
W, H = image.size
|
266 |
+
if H > W:
|
267 |
+
H, W = longest_edge, shortest_edge
|
268 |
+
else:
|
269 |
+
H, W = shortest_edge, longest_edge
|
270 |
+
image = image.resize((W, H))
|
271 |
+
buffered = BytesIO()
|
272 |
+
image.save(buffered, format="JPEG")
|
273 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
274 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
275 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
276 |
+
|
277 |
+
if seg is not None:
|
278 |
+
W, H = seg.size
|
279 |
+
if H > W:
|
280 |
+
H, W = longest_edge, shortest_edge
|
281 |
+
else:
|
282 |
+
H, W = shortest_edge, longest_edge
|
283 |
+
seg = seg.resize((W, H))
|
284 |
+
seg_buffered = BytesIO()
|
285 |
+
seg.save(seg_buffered, format="JPEG")
|
286 |
+
seg_b64_str = base64.b64encode(seg_buffered.getvalue()).decode()
|
287 |
+
seg_str = f'<img src="data:image/png;base64,{seg_b64_str}" alt="user upload seg" />'
|
288 |
+
msg = seg_str + msg.replace('<seg>', '').strip()
|
289 |
+
|
290 |
+
if depth is not None:
|
291 |
+
W, H = depth.size
|
292 |
+
if H > W:
|
293 |
+
H, W = longest_edge, shortest_edge
|
294 |
+
else:
|
295 |
+
H, W = shortest_edge, longest_edge
|
296 |
+
depth = depth.resize((W, H))
|
297 |
+
depth_buffered = BytesIO()
|
298 |
+
depth.save(depth_buffered, format="JPEG")
|
299 |
+
depth_b64_str = base64.b64encode(depth_buffered.getvalue()).decode()
|
300 |
+
depth_str = f'<img src="data:image/png;base64,{depth_b64_str}" alt="user upload depth" />'
|
301 |
+
msg = depth_str + msg.replace('<depth>', '').strip()
|
302 |
+
ret.append([msg, None])
|
303 |
+
else:
|
304 |
+
ret.append([msg, None])
|
305 |
+
else:
|
306 |
+
ret[-1][-1] = msg
|
307 |
+
return ret
|
308 |
+
|
309 |
+
def copy(self):
|
310 |
+
return VCoderConversation(
|
311 |
+
system=self.system,
|
312 |
+
roles=self.roles,
|
313 |
+
messages=[[x, y] for x, y in self.messages],
|
314 |
+
offset=self.offset,
|
315 |
+
sep_style=self.sep_style,
|
316 |
+
sep=self.sep,
|
317 |
+
sep2=self.sep2,
|
318 |
+
version=self.version)
|
319 |
+
|
320 |
+
def dict(self):
|
321 |
+
if len(self.get_images()) > 0:
|
322 |
+
return {
|
323 |
+
"system": self.system,
|
324 |
+
"roles": self.roles,
|
325 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
326 |
+
"offset": self.offset,
|
327 |
+
"sep": self.sep,
|
328 |
+
"sep2": self.sep2,
|
329 |
+
}
|
330 |
+
return {
|
331 |
+
"system": self.system,
|
332 |
+
"roles": self.roles,
|
333 |
+
"messages": self.messages,
|
334 |
+
"offset": self.offset,
|
335 |
+
"sep": self.sep,
|
336 |
+
"sep2": self.sep2,
|
337 |
+
}
|
338 |
+
|
339 |
+
|
340 |
+
conv_vicuna_v1 = VCoderConversation(
|
341 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
342 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
343 |
+
roles=("USER", "ASSISTANT"),
|
344 |
+
version="v1",
|
345 |
+
messages=(),
|
346 |
+
offset=0,
|
347 |
+
sep_style=SeparatorStyle.TWO,
|
348 |
+
sep=" ",
|
349 |
+
sep2="</s>",
|
350 |
+
)
|
351 |
+
|
352 |
+
conv_llava_v1 = VCoderConversation(
|
353 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
354 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
355 |
+
roles=("USER", "ASSISTANT"),
|
356 |
+
version="v1",
|
357 |
+
messages=(),
|
358 |
+
offset=0,
|
359 |
+
sep_style=SeparatorStyle.TWO,
|
360 |
+
sep=" ",
|
361 |
+
sep2="</s>",
|
362 |
+
)
|
363 |
+
|
364 |
+
|
365 |
+
default_conversation = conv_vicuna_v1
|
366 |
+
conv_templates = {
|
367 |
+
"v1": conv_vicuna_v1,
|
368 |
+
"vicuna_v1": conv_vicuna_v1,
|
369 |
+
"llava_v1": conv_llava_v1,
|
370 |
+
}
|
371 |
+
|
372 |
+
|
373 |
+
if __name__ == "__main__":
|
374 |
+
print(default_conversation.get_prompt())
|