ga89tiy
commited on
Commit
•
db6ee6a
1
Parent(s):
b56b523
Initial model commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LLAVA_Biovil/.dockerignore +21 -0
- LLAVA_Biovil/.editorconfig +18 -0
- LLAVA_Biovil/.gitattributes +29 -0
- LLAVA_Biovil/.gitignore +35 -0
- LLAVA_Biovil/LICENSE +201 -0
- LLAVA_Biovil/README.md +410 -0
- LLAVA_Biovil/__init__.py +0 -0
- LLAVA_Biovil/biovil_t/__init__.py +0 -0
- LLAVA_Biovil/biovil_t/encoder.py +180 -0
- LLAVA_Biovil/biovil_t/model.py +130 -0
- LLAVA_Biovil/biovil_t/modules.py +85 -0
- LLAVA_Biovil/biovil_t/pretrained.py +85 -0
- LLAVA_Biovil/biovil_t/resnet.py +80 -0
- LLAVA_Biovil/biovil_t/transformer.py +266 -0
- LLAVA_Biovil/biovil_t/types.py +37 -0
- LLAVA_Biovil/cog.yaml +37 -0
- LLAVA_Biovil/install.md +6 -0
- LLAVA_Biovil/llava/__init__.py +1 -0
- LLAVA_Biovil/llava/constants.py +13 -0
- LLAVA_Biovil/llava/conversation.py +414 -0
- LLAVA_Biovil/llava/eval/__init__.py +0 -0
- LLAVA_Biovil/llava/eval/eval_gpt_review.py +113 -0
- LLAVA_Biovil/llava/eval/eval_gpt_review_bench.py +121 -0
- LLAVA_Biovil/llava/eval/eval_gpt_review_visual.py +118 -0
- LLAVA_Biovil/llava/eval/eval_pope.py +81 -0
- LLAVA_Biovil/llava/eval/eval_science_qa.py +114 -0
- LLAVA_Biovil/llava/eval/eval_science_qa_gpt4.py +104 -0
- LLAVA_Biovil/llava/eval/eval_science_qa_gpt4_requery.py +149 -0
- LLAVA_Biovil/llava/eval/eval_textvqa.py +65 -0
- LLAVA_Biovil/llava/eval/generate_webpage_data_from_table.py +111 -0
- LLAVA_Biovil/llava/eval/m4c_evaluator.py +334 -0
- LLAVA_Biovil/llava/eval/model_qa.py +85 -0
- LLAVA_Biovil/llava/eval/model_vqa.py +112 -0
- LLAVA_Biovil/llava/eval/model_vqa_loader.py +141 -0
- LLAVA_Biovil/llava/eval/model_vqa_mmbench.py +169 -0
- LLAVA_Biovil/llava/eval/model_vqa_qbench.py +120 -0
- LLAVA_Biovil/llava/eval/model_vqa_science.py +147 -0
- LLAVA_Biovil/llava/eval/qa_baseline_gpt35.py +74 -0
- LLAVA_Biovil/llava/eval/run_llava.py +155 -0
- LLAVA_Biovil/llava/eval/summarize_gpt_review.py +60 -0
- LLAVA_Biovil/llava/eval/webpage/figures/alpaca.png +0 -0
- LLAVA_Biovil/llava/eval/webpage/figures/bard.jpg +0 -0
- LLAVA_Biovil/llava/eval/webpage/figures/chatgpt.svg +1 -0
- LLAVA_Biovil/llava/eval/webpage/figures/llama.jpg +0 -0
- LLAVA_Biovil/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg +1 -0
- LLAVA_Biovil/llava/eval/webpage/figures/vicuna.jpeg +0 -0
- LLAVA_Biovil/llava/eval/webpage/index.html +162 -0
- LLAVA_Biovil/llava/eval/webpage/script.js +245 -0
- LLAVA_Biovil/llava/eval/webpage/styles.css +105 -0
- LLAVA_Biovil/llava/mm_utils.py +148 -0
LLAVA_Biovil/.dockerignore
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The .dockerignore file excludes files from the container build process.
|
2 |
+
#
|
3 |
+
# https://docs.docker.com/engine/reference/builder/#dockerignore-file
|
4 |
+
|
5 |
+
# Exclude Git files
|
6 |
+
.git
|
7 |
+
.github
|
8 |
+
.gitignore
|
9 |
+
|
10 |
+
# Exclude Python cache files
|
11 |
+
__pycache__
|
12 |
+
.mypy_cache
|
13 |
+
.pytest_cache
|
14 |
+
.ruff_cache
|
15 |
+
|
16 |
+
# Exclude Python virtual environment
|
17 |
+
/venv
|
18 |
+
|
19 |
+
# Exclude some weights
|
20 |
+
/openai
|
21 |
+
/liuhaotian
|
LLAVA_Biovil/.editorconfig
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
root = true
|
2 |
+
|
3 |
+
# Unix-style newlines with a newline ending every file
|
4 |
+
[*]
|
5 |
+
end_of_line = lf
|
6 |
+
insert_final_newline = true
|
7 |
+
trim_trailing_whitespace = true
|
8 |
+
charset = utf-8
|
9 |
+
|
10 |
+
# 4 space indentation
|
11 |
+
[*.{py,json}]
|
12 |
+
indent_style = space
|
13 |
+
indent_size = 4
|
14 |
+
|
15 |
+
# 2 space indentation
|
16 |
+
[*.{md,sh,yaml,yml}]
|
17 |
+
indent_style = space
|
18 |
+
indent_size = 2
|
LLAVA_Biovil/.gitattributes
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://git-scm.com/docs/gitattributes
|
2 |
+
|
3 |
+
# Set the default behavior, in case people don't have core.autocrlf set.
|
4 |
+
# https://git-scm.com/docs/gitattributes#_end_of_line_conversion
|
5 |
+
* text=auto
|
6 |
+
|
7 |
+
# common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes
|
8 |
+
# Source files
|
9 |
+
# ============
|
10 |
+
*.pxd text diff=python
|
11 |
+
*.py text diff=python
|
12 |
+
*.py3 text diff=python
|
13 |
+
*.pyw text diff=python
|
14 |
+
*.pyx text diff=python
|
15 |
+
*.pyz text diff=python
|
16 |
+
*.pyi text diff=python
|
17 |
+
|
18 |
+
# Binary files
|
19 |
+
# ============
|
20 |
+
*.db binary
|
21 |
+
*.p binary
|
22 |
+
*.pkl binary
|
23 |
+
*.pickle binary
|
24 |
+
*.pyc binary export-ignore
|
25 |
+
*.pyo binary export-ignore
|
26 |
+
*.pyd binary
|
27 |
+
|
28 |
+
# Jupyter notebook
|
29 |
+
*.ipynb text eol=lf
|
LLAVA_Biovil/.gitignore
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python
|
2 |
+
__pycache__
|
3 |
+
*.pyc
|
4 |
+
*.egg-info
|
5 |
+
dist
|
6 |
+
|
7 |
+
# Log
|
8 |
+
*.log
|
9 |
+
*.log.*
|
10 |
+
*.json
|
11 |
+
*.jsonl
|
12 |
+
|
13 |
+
# Data
|
14 |
+
!**/alpaca-data-conversation.json
|
15 |
+
|
16 |
+
# Editor
|
17 |
+
../.idea
|
18 |
+
*.swp
|
19 |
+
|
20 |
+
# Other
|
21 |
+
.DS_Store
|
22 |
+
wandb
|
23 |
+
output
|
24 |
+
|
25 |
+
checkpoints
|
26 |
+
ckpts*
|
27 |
+
|
28 |
+
.ipynb_checkpoints
|
29 |
+
*.ipynb
|
30 |
+
|
31 |
+
# DevContainer
|
32 |
+
!.devcontainer/*
|
33 |
+
|
34 |
+
# Demo
|
35 |
+
serve_images/
|
LLAVA_Biovil/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
LLAVA_Biovil/README.md
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 🌋 LLaVA: Large Language and Vision Assistant
|
2 |
+
|
3 |
+
*Visual instruction tuning towards large language and vision models with GPT-4 level capabilities.*
|
4 |
+
|
5 |
+
[[Project Page](https://llava-vl.github.io/)] [[Demo](https://llava.hliu.cc/)] [[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)] [[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)]
|
6 |
+
|
7 |
+
🤝Community Contributions: [[llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436)] [[Colab](https://github.com/camenduru/LLaVA-colab)] [[🤗Space](https://huggingface.co/spaces/badayvedat/LLaVA)] [[Replicate](https://replicate.com/yorickvp/llava-13b)] [[AutoGen](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_llava.ipynb)] [[BakLLaVA (LLaVA with Mistral-7B)](https://github.com/SkunkworksAI/BakLLaVA)]
|
8 |
+
|
9 |
+
**Improved Baselines with Visual Instruction Tuning** [[Paper](https://arxiv.org/abs/2310.03744)] <br>
|
10 |
+
[Haotian Liu](https://hliu.cc), [Chunyuan Li](https://chunyuan.li/), [Yuheng Li](https://yuheng-li.github.io/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/)
|
11 |
+
|
12 |
+
**Visual Instruction Tuning** (NeurIPS 2023, **Oral**) [[Paper](https://arxiv.org/abs/2304.08485)]<br>
|
13 |
+
[Haotian Liu*](https://hliu.cc), [Chunyuan Li*](https://chunyuan.li/), [Qingyang Wu](https://scholar.google.ca/citations?user=HDiw-TsAAAAJ&hl=en/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/) (*Equal Contribution)
|
14 |
+
|
15 |
+
<!--p align="center">
|
16 |
+
<a href="https://llava.hliu.cc/"><img src="images/llava_logo.png" width="50%"></a> <br>
|
17 |
+
Generated by <a href="https://gligen.github.io/">GLIGEN</a> via "a cute lava llama with glasses" and box prompt
|
18 |
+
</p-->
|
19 |
+
|
20 |
+
|
21 |
+
## Release
|
22 |
+
- [11/10] [LLaVA-Plus](https://llava-vl.github.io/llava-plus/) is released: Learning to Use Tools for Creating Multimodal Agents, with LLaVA-Plus (LLaVA that Plug and Learn to Use Skills). [[Project Page](https://llava-vl.github.io/llava-plus/)] [[Demo](https://llavaplus.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Plus-Codebase)] [[Paper](https://arxiv.org/abs/2311.05437)]
|
23 |
+
- [11/6] Support **Intel** dGPU and CPU platforms. [More details here.](https://github.com/haotian-liu/LLaVA/tree/intel/docs/intel)
|
24 |
+
- [11/2] [LLaVA-Interactive](https://llava-vl.github.io/llava-interactive/) is released: Experience the future of human-AI multimodal interaction with an all-in-one demo for Image Chat, Segmentation, Generation and Editing. [[Project Page](https://llava-vl.github.io/llava-interactive/)] [[Demo](https://llavainteractive.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Interactive-Demo)] [[Paper](https://arxiv.org/abs/2311.00571)]
|
25 |
+
- [10/26] 🔥 LLaVA-1.5 with LoRA achieves comparable performance as full-model finetuning, with a reduced GPU RAM requirement ([ckpts](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md#llava-v15), [script](https://github.com/haotian-liu/LLaVA#train)). We also provide a [doc](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md) on how to finetune LLaVA-1.5 on your own dataset with LoRA.
|
26 |
+
- [10/12] Check out the Korean LLaVA (Ko-LLaVA), created by ETRI, who has generously supported our research! [[🤗 Demo](https://huggingface.co/spaces/etri-vilab/Ko-LLaVA)]
|
27 |
+
- [10/12] LLaVA is now supported in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436) with 4-bit / 5-bit quantization support!
|
28 |
+
- [10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
|
29 |
+
- [10/10] [Roboflow Deep Dive](https://blog.roboflow.com/first-impressions-with-llava-1-5/): First Impressions with LLaVA-1.5.
|
30 |
+
- [10/5] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md).
|
31 |
+
- [9/26] LLaVA is improved with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [[LLavA-RLHF]](https://llava-rlhf.github.io/)
|
32 |
+
- [9/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**.
|
33 |
+
- [9/20] We summarize our empirical study of training 33B and 65B LLaVA models in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [``Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020)
|
34 |
+
<p align="center">
|
35 |
+
<img src="https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings/blob/main/images/mfm_evolution.jpeg?raw=true" width=50%/>
|
36 |
+
</p>
|
37 |
+
|
38 |
+
- [7/19] 🔥 We release a major upgrade, including support for LLaMA-2, LoRA training, 4-/8-bit inference, higher resolution (336x336), and a lot more. We release [LLaVA Bench](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_Bench.md) for benchmarking open-ended visual chat with results from Bard and Bing-Chat. We also support and verify training with RTX 3090 and RTX A6000. Check out [LLaVA-from-LLaMA-2](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_from_LLaMA2.md), and our [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)!
|
39 |
+
- [6/26] [CVPR 2023 Tutorial](https://vlp-tutorial.github.io/) on **Large Multimodal Models: Towards Building and Surpassing Multimodal GPT-4**! Please check out [[Slides](https://datarelease.blob.core.windows.net/tutorial/vision_foundation_models_2023/slides/Chunyuan_cvpr2023_tutorial_lmm.pdf)] [[Notes](https://arxiv.org/abs/2306.14895)] [[YouTube](https://youtu.be/mkI7EPD1vp8)] [[Bilibli](https://www.bilibili.com/video/BV1Ng4y1T7v3/)].
|
40 |
+
- [6/11] We released the preview for the most requested feature: DeepSpeed and LoRA support! Please see documentations [here](./docs/LoRA.md).
|
41 |
+
- [6/1] We released **LLaVA-Med: Large Language and Vision Assistant for Biomedicine**, a step towards building biomedical domain large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2306.00890) and [page](https://github.com/microsoft/LLaVA-Med).
|
42 |
+
- [5/6] We are releasing [LLaVA-Lighting-MPT-7B-preview](https://huggingface.co/liuhaotian/LLaVA-Lightning-MPT-7B-preview), based on MPT-7B-Chat! See [here](#LLaVA-MPT-7b) for more details.
|
43 |
+
- [5/2] 🔥 We are releasing LLaVA-Lighting! Train a lite, multimodal GPT-4 with just $40 in 3 hours! See [here](#train-llava-lightning) for more details.
|
44 |
+
- [4/27] Thanks to the community effort, LLaVA-13B with 4-bit quantization allows you to run on a GPU with as few as 12GB VRAM! Try it out [here](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava).
|
45 |
+
- [4/17] 🔥 We released **LLaVA: Large Language and Vision Assistant**. We propose visual instruction tuning, towards building large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2304.08485) and [demo](https://llava.hliu.cc/).
|
46 |
+
|
47 |
+
<!-- <a href="https://llava.hliu.cc/"><img src="assets/demo.gif" width="70%"></a> -->
|
48 |
+
|
49 |
+
[![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)
|
50 |
+
[![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE)
|
51 |
+
**Usage and License Notices**: The data and checkpoint is intended and licensed for research use only. They are also restricted to uses that follow the license agreement of LLaMA, Vicuna and GPT-4. The dataset is CC BY NC 4.0 (allowing only non-commercial use) and models trained using the dataset should not be used outside of research purposes.
|
52 |
+
|
53 |
+
|
54 |
+
## Contents
|
55 |
+
- [Install](#install)
|
56 |
+
- [LLaVA Weights](#llava-weights)
|
57 |
+
- [Demo](#Demo)
|
58 |
+
- [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)
|
59 |
+
- [Dataset](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)
|
60 |
+
- [Train](#train)
|
61 |
+
- [Evaluation](#evaluation)
|
62 |
+
|
63 |
+
## Install
|
64 |
+
|
65 |
+
If you are not using Linux, do *NOT* proceed, see instructions for [macOS](https://github.com/haotian-liu/LLaVA/blob/main/docs/macOS.md) and [Windows](https://github.com/haotian-liu/LLaVA/blob/main/docs/Windows.md).
|
66 |
+
|
67 |
+
1. Clone this repository and navigate to LLaVA folder
|
68 |
+
```bash
|
69 |
+
git clone https://github.com/haotian-liu/LLaVA.git
|
70 |
+
cd LLaVA
|
71 |
+
```
|
72 |
+
|
73 |
+
2. Install Package
|
74 |
+
```Shell
|
75 |
+
conda create -n llava python=3.10 -y
|
76 |
+
conda activate llava
|
77 |
+
pip install --upgrade pip # enable PEP 660 support
|
78 |
+
pip install -e .
|
79 |
+
```
|
80 |
+
|
81 |
+
3. Install additional packages for training cases
|
82 |
+
```
|
83 |
+
pip install -e ".[train]"
|
84 |
+
pip install flash-attn --no-build-isolation
|
85 |
+
```
|
86 |
+
|
87 |
+
### Upgrade to latest code base
|
88 |
+
|
89 |
+
```Shell
|
90 |
+
git pull
|
91 |
+
pip install -e .
|
92 |
+
```
|
93 |
+
|
94 |
+
### Quick Start With HuggingFace
|
95 |
+
|
96 |
+
<details>
|
97 |
+
<summary>Example Code</summary>
|
98 |
+
|
99 |
+
```Python
|
100 |
+
from LLAV.llava import load_pretrained_model
|
101 |
+
from LLAV.llava import get_model_name_from_path
|
102 |
+
from LLAV.llava import eval_model
|
103 |
+
|
104 |
+
model_path = "liuhaotian/llava-v1.5-7b"
|
105 |
+
|
106 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
107 |
+
model_path=model_path,
|
108 |
+
model_base=None,
|
109 |
+
model_name=get_model_name_from_path(model_path)
|
110 |
+
)
|
111 |
+
```
|
112 |
+
|
113 |
+
Check out the details wth the `load_pretrained_model` function in `llava/model/builder.py`.
|
114 |
+
|
115 |
+
You can also use the `eval_model` function in `llava/eval/run_llava.py` to get the output easily. By doing so, you can use this code on Colab directly after downloading this repository.
|
116 |
+
|
117 |
+
``` python
|
118 |
+
model_path = "liuhaotian/llava-v1.5-7b"
|
119 |
+
prompt = "What are the things I should be cautious about when I visit here?"
|
120 |
+
image_file = "https://llava-vl.github.io/static/images/view.jpg"
|
121 |
+
|
122 |
+
args = type('Args', (), {
|
123 |
+
"model_path": model_path,
|
124 |
+
"model_base": None,
|
125 |
+
"model_name": get_model_name_from_path(model_path),
|
126 |
+
"query": prompt,
|
127 |
+
"conv_mode": None,
|
128 |
+
"image_file": image_file,
|
129 |
+
"sep": ",",
|
130 |
+
})()
|
131 |
+
|
132 |
+
eval_model(args)
|
133 |
+
```
|
134 |
+
</details>
|
135 |
+
|
136 |
+
## LLaVA Weights
|
137 |
+
Please check out our [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) for all public LLaVA checkpoints, and the instructions of how to use the weights.
|
138 |
+
|
139 |
+
## Demo
|
140 |
+
|
141 |
+
To run our demo, you need to prepare LLaVA checkpoints locally. Please follow the instructions [here](#llava-weights) to download the checkpoints.
|
142 |
+
|
143 |
+
### Gradio Web UI
|
144 |
+
|
145 |
+
To launch a Gradio demo locally, please run the following commands one by one. If you plan to launch multiple model workers to compare between different checkpoints, you only need to launch the controller and the web server *ONCE*.
|
146 |
+
|
147 |
+
```mermaid
|
148 |
+
flowchart BT
|
149 |
+
%% Declare Nodes
|
150 |
+
gws("Gradio (UI Server)")
|
151 |
+
c("Controller (API Server):<br/>PORT: 10000")
|
152 |
+
mw7b("Model Worker:<br/>llava-v1.5-7b<br/>PORT: 40000")
|
153 |
+
mw13b("Model Worker:<br/>llava-v1.5-13b<br/>PORT: 40001")
|
154 |
+
|
155 |
+
%% Declare Styles
|
156 |
+
classDef data fill:#3af,stroke:#48a,stroke-width:2px,color:#444
|
157 |
+
classDef success fill:#8f8,stroke:#0a0,stroke-width:2px,color:#444
|
158 |
+
classDef failure fill:#f88,stroke:#f00,stroke-width:2px,color:#444
|
159 |
+
|
160 |
+
%% Assign Styles
|
161 |
+
class id,od data;
|
162 |
+
class cimg,cs_s,scsim_s success;
|
163 |
+
class ncimg,cs_f,scsim_f failure;
|
164 |
+
|
165 |
+
subgraph Demo Connections
|
166 |
+
direction BT
|
167 |
+
c<-->gws
|
168 |
+
|
169 |
+
mw7b<-->c
|
170 |
+
mw13b<-->c
|
171 |
+
end
|
172 |
+
```
|
173 |
+
|
174 |
+
#### Launch a controller
|
175 |
+
```Shell
|
176 |
+
python -m llava.serve.controller --host 0.0.0.0 --port 10000
|
177 |
+
```
|
178 |
+
|
179 |
+
#### Launch a gradio web server.
|
180 |
+
```Shell
|
181 |
+
python -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload
|
182 |
+
```
|
183 |
+
You just launched the Gradio web interface. Now, you can open the web interface with the URL printed on the screen. You may notice that there is no model in the model list. Do not worry, as we have not launched any model worker yet. It will be automatically updated when you launch a model worker.
|
184 |
+
|
185 |
+
#### Launch a model worker
|
186 |
+
|
187 |
+
This is the actual *worker* that performs the inference on the GPU. Each worker is responsible for a single model specified in `--model-path`.
|
188 |
+
|
189 |
+
```Shell
|
190 |
+
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b
|
191 |
+
```
|
192 |
+
Wait until the process finishes loading the model and you see "Uvicorn running on ...". Now, refresh your Gradio web UI, and you will see the model you just launched in the model list.
|
193 |
+
|
194 |
+
You can launch as many workers as you want, and compare between different model checkpoints in the same Gradio interface. Please keep the `--controller` the same, and modify the `--port` and `--worker` to a different port number for each worker.
|
195 |
+
```Shell
|
196 |
+
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port <different from 40000, say 40001> --worker http://localhost:<change accordingly, i.e. 40001> --model-path <ckpt2>
|
197 |
+
```
|
198 |
+
|
199 |
+
If you are using an Apple device with an M1 or M2 chip, you can specify the mps device by using the `--device` flag: `--device mps`.
|
200 |
+
|
201 |
+
#### Launch a model worker (Multiple GPUs, when GPU VRAM <= 24GB)
|
202 |
+
|
203 |
+
If the VRAM of your GPU is less than 24GB (e.g., RTX 3090, RTX 4090, etc.), you may try running it with multiple GPUs. Our latest code base will automatically try to use multiple GPUs if you have more than one GPU. You can specify which GPUs to use with `CUDA_VISIBLE_DEVICES`. Below is an example of running with the first two GPUs.
|
204 |
+
|
205 |
+
```Shell
|
206 |
+
CUDA_VISIBLE_DEVICES=0,1 python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b
|
207 |
+
```
|
208 |
+
|
209 |
+
#### Launch a model worker (4-bit, 8-bit inference, quantized)
|
210 |
+
|
211 |
+
You can launch the model worker with quantized bits (4-bit, 8-bit), which allows you to run the inference with reduced GPU memory footprint, potentially allowing you to run on a GPU with as few as 12GB VRAM. Note that inference with quantized bits may not be as accurate as the full-precision model. Simply append `--load-4bit` or `--load-8bit` to the **model worker** command that you are executing. Below is an example of running with 4-bit quantization.
|
212 |
+
|
213 |
+
```Shell
|
214 |
+
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b --load-4bit
|
215 |
+
```
|
216 |
+
|
217 |
+
#### Launch a model worker (LoRA weights, unmerged)
|
218 |
+
|
219 |
+
You can launch the model worker with LoRA weights, without merging them with the base checkpoint, to save disk space. There will be additional loading time, while the inference speed is the same as the merged checkpoints. Unmerged LoRA checkpoints do not have `lora-merge` in the model name, and are usually much smaller (less than 1GB) than the merged checkpoints (13G for 7B, and 25G for 13B).
|
220 |
+
|
221 |
+
To load unmerged LoRA weights, you simply need to pass an additional argument `--model-base`, which is the base LLM that is used to train the LoRA weights. You can check the base LLM of each LoRA weights in the [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md).
|
222 |
+
|
223 |
+
```Shell
|
224 |
+
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1-0719-336px-lora-vicuna-13b-v1.3 --model-base lmsys/vicuna-13b-v1.3
|
225 |
+
```
|
226 |
+
|
227 |
+
### CLI Inference
|
228 |
+
|
229 |
+
Chat about images using LLaVA without the need of Gradio interface. It also supports multiple GPUs, 4-bit and 8-bit quantized inference. With 4-bit quantization, for our LLaVA-1.5-7B, it uses less than 8GB VRAM on a single GPU.
|
230 |
+
|
231 |
+
```Shell
|
232 |
+
python -m llava.serve.cli \
|
233 |
+
--model-path liuhaotian/llava-v1.5-7b \
|
234 |
+
--image-file "https://llava-vl.github.io/static/images/view.jpg" \
|
235 |
+
--load-4bit
|
236 |
+
```
|
237 |
+
|
238 |
+
<img src="images/demo_cli.gif" width="70%">
|
239 |
+
|
240 |
+
## Train
|
241 |
+
|
242 |
+
*Below is the latest training configuration for LLaVA v1.5. For legacy models, please refer to README of [this](https://github.com/haotian-liu/LLaVA/tree/v1.0.1) version for now. We'll add them in a separate doc later.*
|
243 |
+
|
244 |
+
LLaVA training consists of two stages: (1) feature alignment stage: use our 558K subset of the LAION-CC-SBU dataset to connect a *frozen pretrained* vision encoder to a *frozen LLM*; (2) visual instruction tuning stage: use 150K GPT-generated multimodal instruction-following data, plus around 515K VQA data from academic-oriented tasks, to teach the model to follow multimodal instructions.
|
245 |
+
|
246 |
+
LLaVA is trained on 8 A100 GPUs with 80GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`.
|
247 |
+
|
248 |
+
### Hyperparameters
|
249 |
+
We use a similar set of hyperparameters as Vicuna in finetuning. Both hyperparameters used in pretraining and finetuning are provided below.
|
250 |
+
|
251 |
+
1. Pretraining
|
252 |
+
|
253 |
+
| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
|
254 |
+
| --- | ---: | ---: | ---: | ---: | ---: |
|
255 |
+
| LLaVA-v1.5-13B | 256 | 1e-3 | 1 | 2048 | 0 |
|
256 |
+
|
257 |
+
2. Finetuning
|
258 |
+
|
259 |
+
| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
|
260 |
+
| --- | ---: | ---: | ---: | ---: | ---: |
|
261 |
+
| LLaVA-v1.5-13B | 128 | 2e-5 | 1 | 2048 | 0 |
|
262 |
+
|
263 |
+
### Download Vicuna checkpoints (automatically)
|
264 |
+
|
265 |
+
Our base model Vicuna v1.5, which is an instruction-tuned chatbot, will be downloaded automatically when you run our provided training scripts. No action is needed.
|
266 |
+
|
267 |
+
### Pretrain (feature alignment)
|
268 |
+
|
269 |
+
Please download the 558K subset of the LAION-CC-SBU dataset with BLIP captions we use in the paper [here](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain).
|
270 |
+
|
271 |
+
Pretrain takes around 5.5 hours for LLaVA-v1.5-13B on 8x A100 (80G), due to the increased resolution to 336px. It takes around 3.5 hours for LLaVA-v1.5-7B.
|
272 |
+
|
273 |
+
Training script with DeepSpeed ZeRO-2: [`pretrain.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/pretrain.sh).
|
274 |
+
|
275 |
+
- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
|
276 |
+
- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
|
277 |
+
|
278 |
+
<details>
|
279 |
+
<summary>Pretrain takes around 20 hours for LLaVA-7B on 8x V100 (32G)</summary>
|
280 |
+
|
281 |
+
We provide training script with DeepSpeed [here](https://github.com/haotian-liu/LLaVA/blob/main/scripts/pretrain_xformers.sh).
|
282 |
+
Tips:
|
283 |
+
- If you are using V100 which is not supported by FlashAttention, you can use the [memory-efficient attention](https://arxiv.org/abs/2112.05682) implemented in [xFormers](https://github.com/facebookresearch/xformers). Install xformers and replace `llava/train/train_mem.py` above with [llava/train/train_xformers.py](LLAV/llava/train/train_xformers.py).
|
284 |
+
</details>
|
285 |
+
|
286 |
+
### Visual Instruction Tuning
|
287 |
+
|
288 |
+
1. Prepare data
|
289 |
+
|
290 |
+
Please download the annotation of the final mixture our instruction tuning data [llava_v1_5_mix665k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json), and download the images from constituting datasets:
|
291 |
+
|
292 |
+
- COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip)
|
293 |
+
- GQA: [images](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip)
|
294 |
+
- OCR-VQA: [download script](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing), **we save all files as `.jpg`**
|
295 |
+
- TextVQA: [train_val_images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip)
|
296 |
+
- VisualGenome: [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip)
|
297 |
+
|
298 |
+
After downloading all of them, organize the data as follows in `./playground/data`,
|
299 |
+
|
300 |
+
```
|
301 |
+
├── coco
|
302 |
+
│ └── train2017
|
303 |
+
├── gqa
|
304 |
+
│ └── images
|
305 |
+
├── ocr_vqa
|
306 |
+
│ └── images
|
307 |
+
├── textvqa
|
308 |
+
│ └── train_images
|
309 |
+
└── vg
|
310 |
+
├── VG_100K
|
311 |
+
└── VG_100K_2
|
312 |
+
```
|
313 |
+
|
314 |
+
2. Start training!
|
315 |
+
|
316 |
+
You may download our pretrained projectors in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). It is not recommended to use legacy projectors, as they may be trained with a different version of the codebase, and if any option is off, the model will not function/train as we expected.
|
317 |
+
|
318 |
+
Visual instruction tuning takes around 20 hours for LLaVA-v1.5-13B on 8x A100 (80G), due to the increased resolution to 336px. It takes around 10 hours for LLaVA-v1.5-7B on 8x A100 (40G).
|
319 |
+
|
320 |
+
Training script with DeepSpeed ZeRO-3: [`finetune.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune.sh).
|
321 |
+
|
322 |
+
If you are do not have enough GPU memory:
|
323 |
+
|
324 |
+
- Use LoRA: [`finetune_lora.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune_lora.sh). We are able to fit 13B training in 8-A100-40G/8-A6000, and 7B training in 8-RTX3090. Make sure `per_device_train_batch_size*gradient_accumulation_steps` is the same as the provided script for best reproducibility.
|
325 |
+
- Replace `zero3.json` with `zero3_offload.json` which offloads some parameters to CPU RAM. This slows down the training speed.
|
326 |
+
|
327 |
+
If you are interested in finetuning LLaVA model to your own task/data, please check out [`Finetune_Custom_Data.md`](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md)。
|
328 |
+
|
329 |
+
New options to note:
|
330 |
+
|
331 |
+
- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
|
332 |
+
- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
|
333 |
+
- `--image_aspect_ratio pad`: this pads the non-square images to square, instead of cropping them; it slightly reduces hallucination.
|
334 |
+
- `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language (e.g. ShareGPT) and multimodal (e.g. LLaVA-Instruct). It makes the training sampler only sample a single modality (either image or language) during training, which we observe to speed up training by ~25%, and does not affect the final outcome.
|
335 |
+
|
336 |
+
## Evaluation
|
337 |
+
|
338 |
+
In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure the reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search to make the inference process consistent with the chat demo of real-time outputs.
|
339 |
+
|
340 |
+
See [Evaluation.md](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md).
|
341 |
+
|
342 |
+
### GPT-assisted Evaluation
|
343 |
+
|
344 |
+
Our GPT-assisted evaluation pipeline for multimodal modeling is provided for a comprehensive understanding of the capabilities of vision-language models. Please see our paper for more details.
|
345 |
+
|
346 |
+
1. Generate LLaVA responses
|
347 |
+
|
348 |
+
```Shell
|
349 |
+
python model_vqa.py \
|
350 |
+
--model-path ./checkpoints/LLaVA-13B-v0 \
|
351 |
+
--question-file \
|
352 |
+
playground/data/coco2014_val_qa_eval/qa90_questions.jsonl \
|
353 |
+
--image-folder \
|
354 |
+
/path/to/coco2014_val \
|
355 |
+
--answers-file \
|
356 |
+
/path/to/answer-file-our.jsonl
|
357 |
+
```
|
358 |
+
|
359 |
+
2. Evaluate the generated responses. In our case, [`answer-file-ref.jsonl`](./playground/data/coco2014_val_qa_eval/qa90_gpt4_answer.jsonl) is the response generated by text-only GPT-4 (0314), with the context captions/boxes provided.
|
360 |
+
|
361 |
+
```Shell
|
362 |
+
OPENAI_API_KEY="sk-***********************************" python llava/eval/eval_gpt_review_visual.py \
|
363 |
+
--question playground/data/coco2014_val_qa_eval/qa90_questions.jsonl \
|
364 |
+
--context llava/eval/table/caps_boxes_coco2014_val_80.jsonl \
|
365 |
+
--answer-list \
|
366 |
+
/path/to/answer-file-ref.jsonl \
|
367 |
+
/path/to/answer-file-our.jsonl \
|
368 |
+
--rule llava/eval/table/rule.json \
|
369 |
+
--output /path/to/review.json
|
370 |
+
```
|
371 |
+
|
372 |
+
3. Summarize the evaluation results
|
373 |
+
|
374 |
+
```Shell
|
375 |
+
python summarize_gpt_review.py
|
376 |
+
```
|
377 |
+
|
378 |
+
## Citation
|
379 |
+
|
380 |
+
If you find LLaVA useful for your research and applications, please cite using this BibTeX:
|
381 |
+
```bibtex
|
382 |
+
|
383 |
+
@misc{liu2023improvedllava,
|
384 |
+
title={Improved Baselines with Visual Instruction Tuning},
|
385 |
+
author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Lee, Yong Jae},
|
386 |
+
publisher={arXiv:2310.03744},
|
387 |
+
year={2023},
|
388 |
+
}
|
389 |
+
|
390 |
+
@misc{liu2023llava,
|
391 |
+
title={Visual Instruction Tuning},
|
392 |
+
author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
|
393 |
+
publisher={arXiv:2304.08485},
|
394 |
+
year={2023},
|
395 |
+
}
|
396 |
+
```
|
397 |
+
|
398 |
+
## Acknowledgement
|
399 |
+
|
400 |
+
- [Vicuna](https://github.com/lm-sys/FastChat): the codebase we built upon, and our base model Vicuna-13B that has the amazing language capabilities!
|
401 |
+
|
402 |
+
## Related Projects
|
403 |
+
|
404 |
+
- [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
405 |
+
- [LLaVA-Med: Training a Large Language-and-Vision Assistant for Biomedicine in One Day](https://github.com/microsoft/LLaVA-Med)
|
406 |
+
- [Otter: In-Context Multi-Modal Instruction Tuning](https://github.com/Luodian/Otter)
|
407 |
+
|
408 |
+
For future project ideas, please check out:
|
409 |
+
- [SEEM: Segment Everything Everywhere All at Once](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once)
|
410 |
+
- [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) to detect, segment, and generate anything by marrying [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) and [Segment-Anything](https://github.com/facebookresearch/segment-anything).
|
LLAVA_Biovil/__init__.py
ADDED
File without changes
|
LLAVA_Biovil/biovil_t/__init__.py
ADDED
File without changes
|
LLAVA_Biovil/biovil_t/encoder.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -------------------------------------------------------------------------------------------
|
2 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
3 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
4 |
+
# -------------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from typing import Any, Generator, Optional, Sequence, Tuple, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from health_multimodal.common.device import get_module_device
|
14 |
+
from timm.models.layers import trunc_normal_
|
15 |
+
|
16 |
+
from .resnet import resnet18, resnet50
|
17 |
+
from .transformer import VisionTransformerPooler
|
18 |
+
from .types import ImageEncoderType
|
19 |
+
|
20 |
+
DEFAULT_DILATION_VALUES_FOR_RESNET = (False, False, True)
|
21 |
+
ImageEncoderOutputType = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
|
22 |
+
|
23 |
+
|
24 |
+
class ImageEncoder(nn.Module):
|
25 |
+
"""Image encoder trunk module for the ``ImageModel`` class.
|
26 |
+
|
27 |
+
:param img_encoder_type : Type of image encoder model to use, either ``"resnet18_multi_image"`` or
|
28 |
+
``"resnet50_multi_image"``.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, img_encoder_type: str):
|
32 |
+
super().__init__()
|
33 |
+
self.img_encoder_type = img_encoder_type
|
34 |
+
self.encoder = self._create_encoder()
|
35 |
+
|
36 |
+
def _create_encoder(self, **kwargs: Any) -> nn.Module:
|
37 |
+
if self.img_encoder_type in [ImageEncoderType.RESNET18, ImageEncoderType.RESNET18_MULTI_IMAGE]:
|
38 |
+
encoder_class = resnet18
|
39 |
+
elif self.img_encoder_type in [ImageEncoderType.RESNET50, ImageEncoderType.RESNET50_MULTI_IMAGE]:
|
40 |
+
encoder_class = resnet50
|
41 |
+
else:
|
42 |
+
supported = ImageEncoderType.get_members(multi_image_encoders_only=False)
|
43 |
+
raise NotImplementedError(f"Image encoder type \"{self.img_encoder_type}\" must be in {supported}")
|
44 |
+
|
45 |
+
encoder = encoder_class(pretrained=True, **kwargs)
|
46 |
+
|
47 |
+
return encoder
|
48 |
+
|
49 |
+
def forward(self,
|
50 |
+
current_image: torch.Tensor,
|
51 |
+
return_patch_embeddings: bool = False) -> ImageEncoderOutputType:
|
52 |
+
"""Get image global and patch embeddings"""
|
53 |
+
|
54 |
+
patch_emb = self.encoder(current_image)
|
55 |
+
avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_emb, (1, 1)), 1)
|
56 |
+
if return_patch_embeddings:
|
57 |
+
return patch_emb, avg_pooled_emb
|
58 |
+
|
59 |
+
return avg_pooled_emb
|
60 |
+
|
61 |
+
def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
|
62 |
+
"""Workaround for enabling dilated convolutions after model initialization.
|
63 |
+
|
64 |
+
:param replace_stride_with_dilation: Replace the 2x2 standard convolution stride with a dilated convolution
|
65 |
+
in each layer in the last three blocks of ResNet architecture.
|
66 |
+
"""
|
67 |
+
if self.img_encoder_type == ImageEncoderType.RESNET18:
|
68 |
+
# resnet18 uses BasicBlock implementation, which does not support dilated convolutions.
|
69 |
+
raise NotImplementedError("resnet18 does not support dilated convolutions")
|
70 |
+
|
71 |
+
if replace_stride_with_dilation is None:
|
72 |
+
replace_stride_with_dilation = DEFAULT_DILATION_VALUES_FOR_RESNET
|
73 |
+
|
74 |
+
device = next(self.encoder.parameters()).device
|
75 |
+
new_encoder = self._create_encoder(replace_stride_with_dilation=replace_stride_with_dilation).to(device)
|
76 |
+
|
77 |
+
if self.encoder.training:
|
78 |
+
new_encoder.train()
|
79 |
+
else:
|
80 |
+
new_encoder.eval()
|
81 |
+
|
82 |
+
new_encoder.load_state_dict(self.encoder.state_dict())
|
83 |
+
self.encoder = new_encoder
|
84 |
+
|
85 |
+
|
86 |
+
class MultiImageEncoder(ImageEncoder):
|
87 |
+
"""Multi-image encoder trunk module for the ``ImageModel`` class.
|
88 |
+
It can be used to encode multiple images into combined latent representation.
|
89 |
+
Currently it only supports two input images but can be extended to support more in future.
|
90 |
+
|
91 |
+
:param img_encoder_type: Type of image encoder model to use: either ``"resnet18"`` or ``"resnet50"``.
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(self, img_encoder_type: str):
|
95 |
+
super().__init__(img_encoder_type)
|
96 |
+
|
97 |
+
output_dim = 256 # The aggregate feature dim of the encoder is `2 * output_dim` i.e. [f_static, f_diff]
|
98 |
+
grid_shape = (14, 14) # Spatial dimensions of patch grid.
|
99 |
+
|
100 |
+
backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=get_module_device(self))
|
101 |
+
|
102 |
+
self.backbone_to_vit = nn.Conv2d(in_channels=backbone_output_feature_dim, out_channels=output_dim,
|
103 |
+
kernel_size=1, stride=1, padding=0, bias=False)
|
104 |
+
self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=grid_shape)
|
105 |
+
|
106 |
+
# Missing image embedding
|
107 |
+
self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
|
108 |
+
trunc_normal_(self.missing_previous_emb, std=.02)
|
109 |
+
|
110 |
+
def forward(self, # type: ignore[override]
|
111 |
+
current_image: torch.Tensor,
|
112 |
+
previous_image: Optional[torch.Tensor] = None,
|
113 |
+
return_patch_embeddings: bool = False) -> ImageEncoderOutputType:
|
114 |
+
|
115 |
+
batch_size = current_image.shape[0]
|
116 |
+
|
117 |
+
if previous_image is not None:
|
118 |
+
assert current_image.shape == previous_image.shape
|
119 |
+
x = torch.cat([current_image, previous_image], dim=0)
|
120 |
+
x = super().forward(x, return_patch_embeddings=True)[0]
|
121 |
+
x = self.backbone_to_vit(x)
|
122 |
+
patch_x, patch_x_previous = x[:batch_size], x[batch_size:]
|
123 |
+
diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_x_previous)
|
124 |
+
else:
|
125 |
+
x = super().forward(current_image, return_patch_embeddings=True)[0]
|
126 |
+
patch_x = self.backbone_to_vit(x)
|
127 |
+
B, _, W, H = patch_x.shape
|
128 |
+
diff_x = self.missing_previous_emb.repeat(B, 1, W, H)
|
129 |
+
|
130 |
+
patch_fused = torch.cat([patch_x, diff_x], dim=1)
|
131 |
+
avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_fused, (1, 1)), 1)
|
132 |
+
|
133 |
+
if return_patch_embeddings:
|
134 |
+
return patch_fused, avg_pooled_emb
|
135 |
+
|
136 |
+
return avg_pooled_emb
|
137 |
+
|
138 |
+
def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
|
139 |
+
raise NotImplementedError
|
140 |
+
|
141 |
+
|
142 |
+
@torch.no_grad()
|
143 |
+
def get_encoder_output_dim(module: torch.nn.Module, device: torch.device) -> int:
|
144 |
+
"""Calculate the output dimension of an encoder by making a single forward pass.
|
145 |
+
|
146 |
+
:param module: Encoder module.
|
147 |
+
:param device: Compute device to use.
|
148 |
+
"""
|
149 |
+
# Target device
|
150 |
+
assert isinstance(device, torch.device)
|
151 |
+
|
152 |
+
x = torch.rand((1, 3, 448, 448)).to(device)
|
153 |
+
|
154 |
+
# Extract the number of output feature dimensions
|
155 |
+
with restore_training_mode(module):
|
156 |
+
module.eval()
|
157 |
+
representations = module(x)
|
158 |
+
return representations.shape[1]
|
159 |
+
|
160 |
+
|
161 |
+
@contextmanager
|
162 |
+
def restore_training_mode(module: nn.Module) -> Generator[None, None, None]:
|
163 |
+
"""Restore the training mode of a module after some operation.
|
164 |
+
|
165 |
+
:param module: PyTorch module.
|
166 |
+
"""
|
167 |
+
training_mode = module.training
|
168 |
+
yield
|
169 |
+
module.train(mode=training_mode)
|
170 |
+
|
171 |
+
|
172 |
+
def get_encoder_from_type(img_encoder_type: str) -> ImageEncoder:
|
173 |
+
"""Returns the encoder class for the given encoder type.
|
174 |
+
|
175 |
+
:param img_encoder_type: Encoder type. {RESNET18, RESNET50, RESNET18_MULTI_IMAGE, RESNET50_MULTI_IMAGE}
|
176 |
+
"""
|
177 |
+
if img_encoder_type in ImageEncoderType.get_members(multi_image_encoders_only=True):
|
178 |
+
return MultiImageEncoder(img_encoder_type=img_encoder_type)
|
179 |
+
else:
|
180 |
+
return ImageEncoder(img_encoder_type=img_encoder_type)
|
LLAVA_Biovil/biovil_t/model.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -------------------------------------------------------------------------------------------
|
2 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
3 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
4 |
+
# -------------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
|
8 |
+
from abc import ABC, abstractmethod
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Any, Optional, Union
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from health_multimodal.common.device import get_module_device
|
16 |
+
|
17 |
+
from .encoder import get_encoder_from_type, get_encoder_output_dim, MultiImageEncoder
|
18 |
+
from .modules import MLP, MultiTaskModel
|
19 |
+
from .types import ImageModelOutput
|
20 |
+
|
21 |
+
|
22 |
+
class BaseImageModel(nn.Module, ABC):
|
23 |
+
"""Abstract class for image models."""
|
24 |
+
@abstractmethod
|
25 |
+
def forward(self, *args: Any, **kwargs: Any) -> ImageModelOutput:
|
26 |
+
raise NotImplementedError
|
27 |
+
|
28 |
+
@abstractmethod
|
29 |
+
def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
|
30 |
+
raise NotImplementedError
|
31 |
+
|
32 |
+
|
33 |
+
class ImageModel(BaseImageModel):
|
34 |
+
"""Image encoder module"""
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
img_encoder_type: str,
|
38 |
+
joint_feature_size: int,
|
39 |
+
freeze_encoder: bool = False,
|
40 |
+
pretrained_model_path: Optional[Union[str, Path]] = None,
|
41 |
+
**downstream_classifier_kwargs: Any):
|
42 |
+
super().__init__()
|
43 |
+
|
44 |
+
# Initiate encoder, projector, and classifier
|
45 |
+
self.encoder = get_encoder_from_type(img_encoder_type)
|
46 |
+
self.feature_size = get_encoder_output_dim(self.encoder, device=get_module_device(self.encoder))
|
47 |
+
self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size,
|
48 |
+
hidden_dim=joint_feature_size, use_1x1_convs=True)
|
49 |
+
self.downstream_classifier_kwargs = downstream_classifier_kwargs
|
50 |
+
self.classifier = self.create_downstream_classifier() if downstream_classifier_kwargs else None
|
51 |
+
|
52 |
+
# Initialise the mode of modules
|
53 |
+
self.freeze_encoder = freeze_encoder
|
54 |
+
self.train()
|
55 |
+
|
56 |
+
self.image_processor = None #TODO
|
57 |
+
|
58 |
+
if pretrained_model_path is not None:
|
59 |
+
if not isinstance(pretrained_model_path, (str, Path)):
|
60 |
+
raise TypeError(f"Expected a string or Path, got {type(pretrained_model_path)}")
|
61 |
+
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
62 |
+
# drop projector
|
63 |
+
# for k in list(state_dict.keys()):
|
64 |
+
# if k.startswith("projector"):
|
65 |
+
# state_dict.pop(k)
|
66 |
+
|
67 |
+
self.load_state_dict(state_dict, strict=False)
|
68 |
+
|
69 |
+
|
70 |
+
def train(self, mode: bool = True) -> Any:
|
71 |
+
"""Switch the model between training and evaluation modes."""
|
72 |
+
super().train(mode=mode)
|
73 |
+
if self.freeze_encoder:
|
74 |
+
self.encoder.train(mode=False)
|
75 |
+
self.projector.train(mode=False)
|
76 |
+
return self
|
77 |
+
|
78 |
+
def forward(self, x: torch.Tensor) -> ImageModelOutput: # type: ignore[override]
|
79 |
+
with torch.set_grad_enabled(not self.freeze_encoder):
|
80 |
+
patch_x, pooled_x = self.encoder(x, return_patch_embeddings=True)
|
81 |
+
return self.forward_post_encoder(patch_x, pooled_x)
|
82 |
+
|
83 |
+
def forward_post_encoder(self, patch_x: torch.Tensor, pooled_x: torch.Tensor) -> ImageModelOutput:
|
84 |
+
with torch.set_grad_enabled(not self.freeze_encoder):
|
85 |
+
projected_patch_embeddings = self.projector(patch_x)
|
86 |
+
projected_global_embedding = torch.mean(projected_patch_embeddings, dim=(2, 3))
|
87 |
+
|
88 |
+
logits = self.classifier(pooled_x) if self.classifier else None
|
89 |
+
return ImageModelOutput(img_embedding=pooled_x,
|
90 |
+
patch_embeddings=patch_x,
|
91 |
+
class_logits=logits,
|
92 |
+
projected_patch_embeddings=projected_patch_embeddings,
|
93 |
+
projected_global_embedding=projected_global_embedding)
|
94 |
+
|
95 |
+
def create_downstream_classifier(self, **kwargs: Any) -> MultiTaskModel:
|
96 |
+
"""Create the classification module for the downstream task."""
|
97 |
+
downstream_classifier_kwargs = kwargs if kwargs else self.downstream_classifier_kwargs
|
98 |
+
return MultiTaskModel(self.feature_size, **downstream_classifier_kwargs)
|
99 |
+
|
100 |
+
@torch.no_grad()
|
101 |
+
def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
|
102 |
+
"""Get patch-wise projected embeddings from the CNN model.
|
103 |
+
|
104 |
+
:param input_img: input tensor image [B, C, H, W].
|
105 |
+
:param normalize: If ``True``, the embeddings are L2-normalized.
|
106 |
+
:returns projected_embeddings: tensor of embeddings in shape [batch, n_patches_h, n_patches_w, feature_size].
|
107 |
+
"""
|
108 |
+
assert not self.training, "This function is only implemented for evaluation mode"
|
109 |
+
outputs = self.forward(input_img)
|
110 |
+
projected_embeddings = outputs.projected_patch_embeddings.detach() # type: ignore
|
111 |
+
if normalize:
|
112 |
+
projected_embeddings = F.normalize(projected_embeddings, dim=1)
|
113 |
+
projected_embeddings = projected_embeddings.permute([0, 2, 3, 1]) # B D H W -> B H W D (D: Features)
|
114 |
+
return projected_embeddings
|
115 |
+
|
116 |
+
|
117 |
+
class MultiImageModel(ImageModel):
|
118 |
+
def __init__(self, **kwargs: Any) -> None:
|
119 |
+
super().__init__(**kwargs)
|
120 |
+
assert isinstance(self.encoder, MultiImageEncoder), "MultiImageModel only supports MultiImageEncoder"
|
121 |
+
|
122 |
+
def forward(self, # type: ignore[override]
|
123 |
+
current_image: torch.Tensor,
|
124 |
+
previous_image: Optional[torch.Tensor] = None) -> ImageModelOutput:
|
125 |
+
|
126 |
+
with torch.set_grad_enabled(not self.freeze_encoder):
|
127 |
+
patch_x, pooled_x = self.encoder(current_image=current_image,
|
128 |
+
previous_image=previous_image,
|
129 |
+
return_patch_embeddings=True)
|
130 |
+
return self.forward_post_encoder(patch_x, pooled_x)
|
LLAVA_Biovil/biovil_t/modules.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -------------------------------------------------------------------------------------------
|
2 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
3 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
4 |
+
# -------------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
from typing import Callable, Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
|
12 |
+
class MLP(nn.Module):
|
13 |
+
"""
|
14 |
+
Fully connected layers to map between image embeddings and projection space where pairs of images are compared.
|
15 |
+
|
16 |
+
:param input_dim: Input embedding feature size
|
17 |
+
:param hidden_dim: Hidden layer size in MLP
|
18 |
+
:param output_dim: Output projection size
|
19 |
+
:param use_1x1_convs: Use 1x1 conv kernels instead of 2D linear transformations for speed and memory efficiency.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self,
|
23 |
+
input_dim: int,
|
24 |
+
output_dim: int,
|
25 |
+
hidden_dim: Optional[int] = None,
|
26 |
+
use_1x1_convs: bool = False) -> None:
|
27 |
+
super().__init__()
|
28 |
+
|
29 |
+
if use_1x1_convs:
|
30 |
+
linear_proj_1_args = {'in_channels': input_dim, 'out_channels': hidden_dim, 'kernel_size': 1, 'bias': False}
|
31 |
+
linear_proj_2_args = {'in_channels': hidden_dim, 'out_channels': output_dim, 'kernel_size': 1, 'bias': True}
|
32 |
+
normalisation_layer: Callable = nn.BatchNorm2d
|
33 |
+
projection_layer: Callable = nn.Conv2d
|
34 |
+
else:
|
35 |
+
linear_proj_1_args = {'in_features': input_dim, 'out_features': hidden_dim, 'bias': False}
|
36 |
+
linear_proj_2_args = {'in_features': hidden_dim, 'out_features': output_dim, 'bias': True}
|
37 |
+
normalisation_layer = nn.BatchNorm1d
|
38 |
+
projection_layer = nn.Linear
|
39 |
+
|
40 |
+
self.output_dim = output_dim
|
41 |
+
self.input_dim = input_dim
|
42 |
+
if hidden_dim is not None:
|
43 |
+
self.model = nn.Sequential(
|
44 |
+
projection_layer(**linear_proj_1_args),
|
45 |
+
normalisation_layer(hidden_dim),
|
46 |
+
nn.ReLU(inplace=True),
|
47 |
+
projection_layer(**linear_proj_2_args))
|
48 |
+
else:
|
49 |
+
self.model = nn.Linear(input_dim, output_dim) # type: ignore
|
50 |
+
|
51 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
52 |
+
"""forward pass of the multi-layer perceptron"""
|
53 |
+
x = self.model(x)
|
54 |
+
return x
|
55 |
+
|
56 |
+
|
57 |
+
class MultiTaskModel(nn.Module):
|
58 |
+
"""Torch module for multi-task classification heads. We create a separate classification head
|
59 |
+
for each task and perform a forward pass on each head independently in forward(). Classification
|
60 |
+
heads are instances of `MLP`.
|
61 |
+
|
62 |
+
:param input_dim: Number of dimensions of the input feature map.
|
63 |
+
:param classifier_hidden_dim: Number of dimensions of hidden features in the MLP.
|
64 |
+
:param num_classes: Number of output classes per task.
|
65 |
+
:param num_tasks: Number of classification tasks or heads required.
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(self, input_dim: int, classifier_hidden_dim: Optional[int], num_classes: int, num_tasks: int):
|
69 |
+
|
70 |
+
super().__init__()
|
71 |
+
|
72 |
+
self.num_classes = num_classes
|
73 |
+
self.num_tasks = num_tasks
|
74 |
+
|
75 |
+
for task in range(num_tasks):
|
76 |
+
setattr(self, "fc_" + str(task), MLP(input_dim, output_dim=num_classes, hidden_dim=classifier_hidden_dim))
|
77 |
+
|
78 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
79 |
+
"""Returns [batch_size, num_tasks, num_classes] tensor of logits."""
|
80 |
+
batch_size = x.shape[0]
|
81 |
+
out = torch.zeros((batch_size, self.num_classes, self.num_tasks), dtype=x.dtype, device=x.device)
|
82 |
+
for task in range(self.num_tasks):
|
83 |
+
classifier = getattr(self, "fc_" + str(task))
|
84 |
+
out[:, :, task] = classifier(x)
|
85 |
+
return out
|
LLAVA_Biovil/biovil_t/pretrained.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -------------------------------------------------------------------------------------------
|
2 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
3 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
4 |
+
# -------------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
|
8 |
+
import tempfile
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
from torchvision.datasets.utils import download_url
|
12 |
+
|
13 |
+
from .model import ImageModel
|
14 |
+
from .types import ImageEncoderType
|
15 |
+
|
16 |
+
|
17 |
+
JOINT_FEATURE_SIZE = 128
|
18 |
+
|
19 |
+
BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
|
20 |
+
BIOMED_VLP_BIOVIL_T = "microsoft/BiomedVLP-BioViL-T"
|
21 |
+
HF_URL = "https://huggingface.co"
|
22 |
+
|
23 |
+
CXR_BERT_COMMIT_TAG = "v1.1"
|
24 |
+
BIOVIL_T_COMMIT_TAG = "v1.0"
|
25 |
+
|
26 |
+
BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt"
|
27 |
+
BIOVIL_IMAGE_WEIGHTS_URL = f"{HF_URL}/{BIOMED_VLP_CXR_BERT_SPECIALIZED}/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}" # noqa: E501
|
28 |
+
BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453"
|
29 |
+
|
30 |
+
BIOVIL_T_IMAGE_WEIGHTS_NAME = "biovil_t_image_model_proj_size_128.pt"
|
31 |
+
BIOVIL_T_IMAGE_WEIGHTS_URL = f"{HF_URL}/{BIOMED_VLP_BIOVIL_T}/resolve/{BIOVIL_T_COMMIT_TAG}/{BIOVIL_T_IMAGE_WEIGHTS_NAME}" # noqa: E501
|
32 |
+
BIOVIL_T_IMAGE_WEIGHTS_MD5 = "a83080e2f23aa584a4f2b24c39b1bb64"
|
33 |
+
|
34 |
+
|
35 |
+
def _download_biovil_image_model_weights() -> Path:
|
36 |
+
"""Download image model weights from Hugging Face.
|
37 |
+
|
38 |
+
More information available at https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized.
|
39 |
+
"""
|
40 |
+
root_dir = tempfile.gettempdir()
|
41 |
+
download_url(
|
42 |
+
BIOVIL_IMAGE_WEIGHTS_URL,
|
43 |
+
root=root_dir,
|
44 |
+
filename=BIOVIL_IMAGE_WEIGHTS_NAME,
|
45 |
+
md5=BIOVIL_IMAGE_WEIGHTS_MD5,
|
46 |
+
)
|
47 |
+
return Path(root_dir, BIOVIL_IMAGE_WEIGHTS_NAME)
|
48 |
+
|
49 |
+
|
50 |
+
def _download_biovil_t_image_model_weights() -> Path:
|
51 |
+
"""Download image model weights from Hugging Face.
|
52 |
+
|
53 |
+
More information available at https://huggingface.co/microsoft/microsoft/BiomedVLP-BioViL-T.
|
54 |
+
"""
|
55 |
+
root_dir = tempfile.gettempdir()
|
56 |
+
download_url(
|
57 |
+
BIOVIL_T_IMAGE_WEIGHTS_URL,
|
58 |
+
root=root_dir,
|
59 |
+
filename=BIOVIL_T_IMAGE_WEIGHTS_NAME,
|
60 |
+
md5=BIOVIL_T_IMAGE_WEIGHTS_MD5
|
61 |
+
)
|
62 |
+
return Path(root_dir, BIOVIL_T_IMAGE_WEIGHTS_NAME)
|
63 |
+
|
64 |
+
|
65 |
+
def get_biovil_image_encoder(pretrained: bool = True) -> ImageModel:
|
66 |
+
"""Download weights from Hugging Face and instantiate the image model."""
|
67 |
+
resnet_checkpoint_path = _download_biovil_image_model_weights() if pretrained else None
|
68 |
+
|
69 |
+
image_model = ImageModel(
|
70 |
+
img_encoder_type=ImageEncoderType.RESNET50,
|
71 |
+
joint_feature_size=JOINT_FEATURE_SIZE,
|
72 |
+
pretrained_model_path=resnet_checkpoint_path,
|
73 |
+
)
|
74 |
+
return image_model
|
75 |
+
|
76 |
+
|
77 |
+
def get_biovil_t_image_encoder() -> ImageModel:
|
78 |
+
"""Download weights from Hugging Face and instantiate the image model."""
|
79 |
+
|
80 |
+
biovilt_checkpoint_path = _download_biovil_t_image_model_weights()
|
81 |
+
model_type = ImageEncoderType.RESNET50_MULTI_IMAGE
|
82 |
+
image_model = ImageModel(img_encoder_type=model_type,
|
83 |
+
joint_feature_size=JOINT_FEATURE_SIZE,
|
84 |
+
pretrained_model_path=biovilt_checkpoint_path)
|
85 |
+
return image_model
|
LLAVA_Biovil/biovil_t/resnet.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -------------------------------------------------------------------------------------------
|
2 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
3 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
4 |
+
# -------------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
from typing import Any, List, Tuple, Type, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch.hub import load_state_dict_from_url
|
10 |
+
from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
|
11 |
+
|
12 |
+
TypeSkipConnections = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
|
13 |
+
|
14 |
+
|
15 |
+
class ResNetHIML(ResNet):
|
16 |
+
"""Wrapper class of the original torchvision ResNet model.
|
17 |
+
|
18 |
+
The forward function is updated to return the penultimate layer
|
19 |
+
activations, which are required to obtain image patch embeddings.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, **kwargs: Any) -> None:
|
23 |
+
super().__init__(**kwargs)
|
24 |
+
|
25 |
+
def forward(self, x: torch.Tensor,
|
26 |
+
return_intermediate_layers: bool = False) -> Union[torch.Tensor, TypeSkipConnections]:
|
27 |
+
"""ResNetHIML forward pass. Optionally returns intermediate layers using the
|
28 |
+
``return_intermediate_layers`` argument.
|
29 |
+
|
30 |
+
:param return_intermediate_layers: If ``True``, return layers x0-x4 as a tuple,
|
31 |
+
otherwise return x4 only.
|
32 |
+
"""
|
33 |
+
|
34 |
+
x0 = self.conv1(x)
|
35 |
+
x0 = self.bn1(x0)
|
36 |
+
x0 = self.relu(x0)
|
37 |
+
x0 = self.maxpool(x0)
|
38 |
+
|
39 |
+
x1 = self.layer1(x0)
|
40 |
+
x2 = self.layer2(x1)
|
41 |
+
x3 = self.layer3(x2)
|
42 |
+
x4 = self.layer4(x3)
|
43 |
+
|
44 |
+
if return_intermediate_layers:
|
45 |
+
return x0, x1, x2, x3, x4
|
46 |
+
else:
|
47 |
+
return x4
|
48 |
+
|
49 |
+
|
50 |
+
def _resnet(arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int],
|
51 |
+
pretrained: bool, progress: bool, **kwargs: Any) -> ResNetHIML:
|
52 |
+
"""Instantiate a custom :class:`ResNet` model.
|
53 |
+
|
54 |
+
Adapted from :mod:`torchvision.models.resnet`.
|
55 |
+
"""
|
56 |
+
model = ResNetHIML(block=block, layers=layers, **kwargs)
|
57 |
+
if pretrained:
|
58 |
+
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=progress)
|
59 |
+
model.load_state_dict(state_dict)
|
60 |
+
return model
|
61 |
+
|
62 |
+
|
63 |
+
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
|
64 |
+
r"""ResNet-18 model from
|
65 |
+
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
66 |
+
|
67 |
+
:param pretrained: If ``True``, returns a model pre-trained on ImageNet.
|
68 |
+
:param progress: If ``True``, displays a progress bar of the download to ``stderr``.
|
69 |
+
"""
|
70 |
+
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
71 |
+
|
72 |
+
|
73 |
+
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
|
74 |
+
r"""ResNet-50 model from
|
75 |
+
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
76 |
+
|
77 |
+
:param pretrained: If ``True``, returns a model pre-trained on ImageNet
|
78 |
+
:param progress: If ``True``, displays a progress bar of the download to ``stderr``.
|
79 |
+
"""
|
80 |
+
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
LLAVA_Biovil/biovil_t/transformer.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -------------------------------------------------------------------------------------------
|
2 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
3 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
4 |
+
# -------------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import math
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from functools import partial
|
9 |
+
from typing import Any, Callable, Optional, Set, Tuple
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from timm.models.layers import DropPath, Mlp, trunc_normal_
|
14 |
+
|
15 |
+
|
16 |
+
def torch_int_div(tensor1, tensor2):
|
17 |
+
"""
|
18 |
+
A function that performs integer division across different versions of PyTorch.
|
19 |
+
"""
|
20 |
+
return torch.div(tensor1, tensor2, rounding_mode="floor")
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class MultiHeadAttentionOutput:
|
24 |
+
mha_output: torch.Tensor
|
25 |
+
attention: Optional[torch.Tensor] = None
|
26 |
+
|
27 |
+
|
28 |
+
class VisionTransformerPooler(nn.Module):
|
29 |
+
"""
|
30 |
+
:param input_dim: Input feature dimension (i.e., channels in old CNN terminology)
|
31 |
+
:param grid_shape: Shape of the grid of patches per image
|
32 |
+
:param num_heads: Number of self-attention heads within the MHA block
|
33 |
+
:param num_blocks: Number of blocks per attention layer
|
34 |
+
:param norm_layer: Normalisation layer
|
35 |
+
|
36 |
+
`self.type_embed`: Is used to characterise prior and current scans, and
|
37 |
+
create permutation variance across modalities/series.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self,
|
41 |
+
input_dim: int,
|
42 |
+
grid_shape: Tuple[int, int],
|
43 |
+
num_heads: int = 8,
|
44 |
+
num_blocks: int = 3,
|
45 |
+
norm_layer: Any = partial(nn.LayerNorm, eps=1e-6)):
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
block_kwargs = dict(dim=input_dim, num_heads=num_heads, mlp_ratio=1., drop=0.10, attn_drop=0.10,
|
49 |
+
drop_path=0.25, act_layer=nn.GELU, norm_layer=norm_layer)
|
50 |
+
self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_blocks)])
|
51 |
+
self.norm_post = norm_layer(input_dim)
|
52 |
+
self.grid_shape = grid_shape
|
53 |
+
self.num_patches = grid_shape[0] * grid_shape[1]
|
54 |
+
self.num_blocks = num_blocks
|
55 |
+
|
56 |
+
# Temporal positional embeddings
|
57 |
+
num_series: int = 2
|
58 |
+
self.type_embed = nn.Parameter(torch.zeros(num_series, 1, input_dim))
|
59 |
+
trunc_normal_(self.type_embed, std=.02)
|
60 |
+
|
61 |
+
# Positional embeddings 1 x L x C (L: Sequence length, C: Feature dimension)
|
62 |
+
self.pos_drop = nn.Dropout(p=0.10)
|
63 |
+
pos_embed_class = SinePositionEmbedding(embedding_dim=input_dim // 2, normalize=True)
|
64 |
+
pos_embed = pos_embed_class(mask=torch.ones([1, grid_shape[0], grid_shape[1]])) # 1 x L x C
|
65 |
+
self.register_buffer("pos_embed", pos_embed, persistent=False)
|
66 |
+
|
67 |
+
# Initialisation
|
68 |
+
self.apply(self._init_weights)
|
69 |
+
|
70 |
+
def no_weight_decay(self) -> Set[str]:
|
71 |
+
return {'type_embed'}
|
72 |
+
|
73 |
+
def forward(self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None) -> torch.Tensor:
|
74 |
+
B, C, H, W = current_image.shape
|
75 |
+
assert H == self.grid_shape[0] and W == self.grid_shape[1], "Input and grid shapes do not match"
|
76 |
+
|
77 |
+
# Flatten patch embeddings to have shape (B x L x C), L = H * W
|
78 |
+
if previous_image is not None:
|
79 |
+
assert previous_image.shape == current_image.shape, "current_image and previous_image shapes do not match"
|
80 |
+
previous_image = previous_image.view(B, C, H * W).transpose(1, 2)
|
81 |
+
current_image = current_image.view(B, C, H * W).transpose(1, 2)
|
82 |
+
pos_embed = self.pos_embed.repeat(B, 1, 1) # type: ignore
|
83 |
+
|
84 |
+
# Final token activations (B x 2L x C)
|
85 |
+
token_features = self.forward_after_reshape(x=current_image, pos_embed=pos_embed, x_previous=previous_image)
|
86 |
+
|
87 |
+
# Extract the patch features of current image
|
88 |
+
cur_img_token_id = 0
|
89 |
+
current_token_features = token_features[:, cur_img_token_id:self.num_patches+cur_img_token_id]
|
90 |
+
current_patch_features = current_token_features.transpose(1, 2).view(B, C, H, W)
|
91 |
+
|
92 |
+
return current_patch_features
|
93 |
+
|
94 |
+
def forward_after_reshape(self,
|
95 |
+
x: torch.Tensor,
|
96 |
+
pos_embed: torch.Tensor,
|
97 |
+
x_previous: Optional[torch.Tensor] = None) -> torch.Tensor:
|
98 |
+
B, L, _ = x.shape # Batch, Sequence length, Feature dimension
|
99 |
+
|
100 |
+
# Positional and type embeddings
|
101 |
+
type_embed = self.type_embed[0].expand(B, L, -1)
|
102 |
+
if x_previous is not None:
|
103 |
+
x = torch.cat((x, x_previous), dim=1)
|
104 |
+
pos_embed = torch.cat((pos_embed, pos_embed), dim=1)
|
105 |
+
prev_type_embed = self.type_embed[1].expand(B, L, -1)
|
106 |
+
type_embed = torch.cat((type_embed, prev_type_embed), dim=1)
|
107 |
+
|
108 |
+
# Add positional and type embeddings (used in query and key matching)
|
109 |
+
pos_and_type_embed = pos_embed + type_embed
|
110 |
+
|
111 |
+
# Positional dropout
|
112 |
+
x = self.pos_drop(x)
|
113 |
+
|
114 |
+
# Multihead attention followed by MLP
|
115 |
+
for block in self.blocks:
|
116 |
+
x = block(x=x, pos_and_type_embed=pos_and_type_embed)
|
117 |
+
x = self.norm_post(x)
|
118 |
+
|
119 |
+
return x
|
120 |
+
|
121 |
+
def _init_weights(self, m: nn.Module) -> None:
|
122 |
+
if isinstance(m, nn.Linear):
|
123 |
+
trunc_normal_(m.weight, std=.02)
|
124 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
125 |
+
nn.init.constant_(m.bias, 0)
|
126 |
+
elif isinstance(m, nn.LayerNorm):
|
127 |
+
nn.init.constant_(m.bias, 0)
|
128 |
+
nn.init.constant_(m.weight, 1.0)
|
129 |
+
|
130 |
+
|
131 |
+
class MultiHeadAttentionLayer(nn.Module):
|
132 |
+
"""
|
133 |
+
Multi-head self attention module
|
134 |
+
|
135 |
+
The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:
|
136 |
+
- Defines a custom `MultiHeadAttentionLayer` which does not only apply `self-attention` but it can be
|
137 |
+
generalised to arbitrary (query, key, value) input tuples. This feature can be valuable to process
|
138 |
+
more than 2 scans at a time.
|
139 |
+
- `Self-attention` specific use-case can still be invoked by calling the `forward_as_mhsa` method.
|
140 |
+
"""
|
141 |
+
|
142 |
+
def __init__(self,
|
143 |
+
dim: int,
|
144 |
+
num_heads: int = 8,
|
145 |
+
qkv_bias: bool = False,
|
146 |
+
attn_drop: float = 0.,
|
147 |
+
proj_drop: float = 0.) -> None:
|
148 |
+
super().__init__()
|
149 |
+
self.num_heads = num_heads
|
150 |
+
assert dim % num_heads == 0, f"The embedding dim ({dim}) must be divisible by the number of heads ({num_heads})"
|
151 |
+
head_dim = dim // num_heads
|
152 |
+
self.scale = head_dim ** -0.5
|
153 |
+
self.return_attention = False
|
154 |
+
|
155 |
+
self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
|
156 |
+
self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
|
157 |
+
self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)
|
158 |
+
|
159 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
160 |
+
self.proj = nn.Linear(dim, dim)
|
161 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
162 |
+
|
163 |
+
def forward(self, k: torch.Tensor, q: torch.Tensor, v: torch.Tensor) -> MultiHeadAttentionOutput:
|
164 |
+
B, N, C = v.shape
|
165 |
+
assert C % self.num_heads == 0, \
|
166 |
+
f"The embedding dim ({C}) must be divisible by the number of heads ({self.num_heads})"
|
167 |
+
|
168 |
+
w_q = self.proj_q(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
169 |
+
w_k = self.proj_k(k).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
170 |
+
w_v = self.proj_v(v).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
171 |
+
|
172 |
+
attn = (w_q @ w_k.transpose(-2, -1)) * self.scale
|
173 |
+
attn = attn.softmax(dim=-1)
|
174 |
+
attn = self.attn_drop(attn)
|
175 |
+
|
176 |
+
o = (attn @ w_v).transpose(1, 2).reshape(B, N, C)
|
177 |
+
o = self.proj(o)
|
178 |
+
o = self.proj_drop(o)
|
179 |
+
|
180 |
+
attention_output = attn if self.return_attention else None
|
181 |
+
|
182 |
+
return MultiHeadAttentionOutput(mha_output=o, attention=attention_output)
|
183 |
+
|
184 |
+
def forward_as_mhsa(self, input: torch.Tensor) -> MultiHeadAttentionOutput:
|
185 |
+
return self(k=input, q=input, v=input)
|
186 |
+
|
187 |
+
|
188 |
+
class Block(nn.Module):
|
189 |
+
"""
|
190 |
+
Encapsulates multi-layer perceptron and multi-head self attention modules into a block.
|
191 |
+
|
192 |
+
The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:
|
193 |
+
- This implementation uses spatio-temporal positional embeddings instead of 2D positional embeddings only,
|
194 |
+
and they are taken into account within the forward pass of each ViT block.
|
195 |
+
- Utilises the custom defined `MultiHeadAttentionLayer` which does not apply `self-attention` only but can be
|
196 |
+
generalised to arbitrary (query, key, value) tuples. This can be valuable to process more than 2 scans.
|
197 |
+
|
198 |
+
Positional and type embeddings are handled in a similar fashion as DETR object localisation paper
|
199 |
+
https://alcinos.github.io/detr_page/, where a fixed set of sine/cos positional embeddings are used
|
200 |
+
in an additive manner to Q and K tensors.
|
201 |
+
"""
|
202 |
+
|
203 |
+
def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 1., qkv_bias: bool = False, drop: float = 0.,
|
204 |
+
attn_drop: float = 0., drop_path: float = 0., act_layer: Callable = nn.GELU,
|
205 |
+
norm_layer: Callable = nn.LayerNorm) -> None:
|
206 |
+
super().__init__()
|
207 |
+
self.norm1 = norm_layer(dim)
|
208 |
+
self.attn = MultiHeadAttentionLayer(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias,
|
209 |
+
attn_drop=attn_drop, proj_drop=drop)
|
210 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
211 |
+
self.norm2 = norm_layer(dim)
|
212 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
213 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
214 |
+
|
215 |
+
def with_pos_and_type_embed(self, tensor: torch.Tensor, emb: Optional[torch.Tensor]) -> torch.Tensor:
|
216 |
+
# Add positional embeddings to key and query tensors
|
217 |
+
return tensor if emb is None else tensor + emb
|
218 |
+
|
219 |
+
def forward(self, x: torch.Tensor, pos_and_type_embed: Optional[torch.Tensor]) -> torch.Tensor:
|
220 |
+
x_with_emb = self.with_pos_and_type_embed(self.norm1(x), emb=pos_and_type_embed)
|
221 |
+
x = x + self.drop_path(self.attn.forward_as_mhsa(x_with_emb).mha_output)
|
222 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
223 |
+
|
224 |
+
return x
|
225 |
+
|
226 |
+
|
227 |
+
class SinePositionEmbedding():
|
228 |
+
"""
|
229 |
+
This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
|
230 |
+
need paper, generalized to work on images.
|
231 |
+
"""
|
232 |
+
|
233 |
+
def __init__(self,
|
234 |
+
embedding_dim: int = 64,
|
235 |
+
temperature: int = 10000,
|
236 |
+
normalize: bool = False,
|
237 |
+
scale: float = None) -> None:
|
238 |
+
super().__init__()
|
239 |
+
self.embedding_dim = embedding_dim
|
240 |
+
self.temperature = temperature
|
241 |
+
self.normalize = normalize
|
242 |
+
if scale is not None and normalize is False:
|
243 |
+
raise ValueError("normalize should be True if scale is passed")
|
244 |
+
if scale is None:
|
245 |
+
scale = 2 * math.pi
|
246 |
+
self.scale = scale
|
247 |
+
|
248 |
+
def __call__(self, mask: torch.Tensor) -> torch.Tensor:
|
249 |
+
assert mask is not None, "No pixel mask provided"
|
250 |
+
B, H, W = mask.shape
|
251 |
+
y_embed = mask.cumsum(1, dtype=torch.float32)
|
252 |
+
x_embed = mask.cumsum(2, dtype=torch.float32)
|
253 |
+
if self.normalize:
|
254 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
|
255 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
|
256 |
+
|
257 |
+
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32)
|
258 |
+
dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)
|
259 |
+
|
260 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
261 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
262 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
263 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
264 |
+
pos = torch.cat((pos_y, pos_x), dim=3).view(B, H * W, self.embedding_dim * 2)
|
265 |
+
|
266 |
+
return pos
|
LLAVA_Biovil/biovil_t/types.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -------------------------------------------------------------------------------------------
|
2 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
3 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
4 |
+
# -------------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
|
7 |
+
from __future__ import annotations
|
8 |
+
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from enum import Enum, unique
|
11 |
+
from typing import List
|
12 |
+
|
13 |
+
import torch
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class ImageModelOutput():
|
18 |
+
img_embedding: torch.Tensor
|
19 |
+
patch_embeddings: torch.Tensor
|
20 |
+
projected_global_embedding: torch.Tensor
|
21 |
+
class_logits: torch.Tensor
|
22 |
+
projected_patch_embeddings: torch.Tensor
|
23 |
+
|
24 |
+
|
25 |
+
@unique
|
26 |
+
class ImageEncoderType(str, Enum):
|
27 |
+
RESNET18 = "resnet18"
|
28 |
+
RESNET50 = "resnet50"
|
29 |
+
RESNET18_MULTI_IMAGE = "resnet18_multi_image"
|
30 |
+
RESNET50_MULTI_IMAGE = "resnet50_multi_image"
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def get_members(cls, multi_image_encoders_only: bool) -> List[ImageEncoderType]:
|
34 |
+
if multi_image_encoders_only:
|
35 |
+
return [cls.RESNET18_MULTI_IMAGE, cls.RESNET50_MULTI_IMAGE]
|
36 |
+
else:
|
37 |
+
return [member for member in cls]
|
LLAVA_Biovil/cog.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration for Cog ⚙️
|
2 |
+
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
|
3 |
+
|
4 |
+
build:
|
5 |
+
gpu: true
|
6 |
+
|
7 |
+
python_version: "3.11"
|
8 |
+
|
9 |
+
python_packages:
|
10 |
+
- "torch==2.0.1"
|
11 |
+
- "accelerate==0.21.0"
|
12 |
+
- "bitsandbytes==0.41.0"
|
13 |
+
- "deepspeed==0.9.5"
|
14 |
+
- "einops-exts==0.0.4"
|
15 |
+
- "einops==0.6.1"
|
16 |
+
- "gradio==3.35.2"
|
17 |
+
- "gradio_client==0.2.9"
|
18 |
+
- "httpx==0.24.0"
|
19 |
+
- "markdown2==2.4.10"
|
20 |
+
- "numpy==1.26.0"
|
21 |
+
- "peft==0.4.0"
|
22 |
+
- "scikit-learn==1.2.2"
|
23 |
+
- "sentencepiece==0.1.99"
|
24 |
+
- "shortuuid==1.0.11"
|
25 |
+
- "timm==0.6.13"
|
26 |
+
- "tokenizers==0.13.3"
|
27 |
+
- "torch==2.0.1"
|
28 |
+
- "torchvision==0.15.2"
|
29 |
+
- "transformers==4.31.0"
|
30 |
+
- "wandb==0.15.12"
|
31 |
+
- "wavedrom==2.0.3.post3"
|
32 |
+
- "Pygments==2.16.1"
|
33 |
+
run:
|
34 |
+
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget
|
35 |
+
|
36 |
+
# predict.py defines how predictions are run on your model
|
37 |
+
predict: "predict.py:Predictor"
|
LLAVA_Biovil/install.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
step 1: clone Llava
|
2 |
+
step 2: git clone https://github.com/Dao-AILab/flash-attention.git
|
3 |
+
step 3: conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
|
4 |
+
step 4: pip install -e .
|
5 |
+
step 5: pip install -e ".[train]"
|
6 |
+
step 6: in flash attention folder, run: python setup.py install
|
LLAVA_Biovil/llava/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model import LlavaLlamaForCausalLM
|
LLAVA_Biovil/llava/constants.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
IMAGE_TOKEN_INDEX = -200
|
9 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
10 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
11 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
12 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
13 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
LLAVA_Biovil/llava/conversation.py
ADDED
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
|
6 |
+
class SeparatorStyle(Enum):
|
7 |
+
"""Different separator style."""
|
8 |
+
SINGLE = auto()
|
9 |
+
TWO = auto()
|
10 |
+
MPT = auto()
|
11 |
+
PLAIN = auto()
|
12 |
+
LLAMA_2 = auto()
|
13 |
+
|
14 |
+
|
15 |
+
@dataclasses.dataclass
|
16 |
+
class Conversation:
|
17 |
+
"""A class that keeps all conversation history."""
|
18 |
+
system: str
|
19 |
+
roles: List[str]
|
20 |
+
messages: List[List[str]]
|
21 |
+
offset: int
|
22 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
23 |
+
sep: str = "###"
|
24 |
+
sep2: str = None
|
25 |
+
version: str = "Unknown"
|
26 |
+
|
27 |
+
skip_next: bool = False
|
28 |
+
|
29 |
+
def get_prompt(self):
|
30 |
+
messages = self.messages
|
31 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
32 |
+
messages = self.messages.copy()
|
33 |
+
init_role, init_msg = messages[0].copy()
|
34 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
35 |
+
if 'mmtag' in self.version:
|
36 |
+
messages[0] = (init_role, init_msg)
|
37 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
38 |
+
messages.insert(1, (self.roles[1], "Received."))
|
39 |
+
else:
|
40 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
41 |
+
|
42 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
43 |
+
ret = self.system + self.sep
|
44 |
+
for role, message in messages:
|
45 |
+
if message:
|
46 |
+
if type(message) is tuple:
|
47 |
+
message, _, _ = message
|
48 |
+
ret += role + ": " + message + self.sep
|
49 |
+
else:
|
50 |
+
ret += role + ":"
|
51 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
52 |
+
seps = [self.sep, self.sep2]
|
53 |
+
ret = self.system + seps[0]
|
54 |
+
for i, (role, message) in enumerate(messages):
|
55 |
+
if message:
|
56 |
+
if type(message) is tuple:
|
57 |
+
message, _, _ = message
|
58 |
+
ret += role + ": " + message + seps[i % 2]
|
59 |
+
else:
|
60 |
+
ret += role + ":"
|
61 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
62 |
+
ret = self.system + self.sep
|
63 |
+
for role, message in messages:
|
64 |
+
if message:
|
65 |
+
if type(message) is tuple:
|
66 |
+
message, _, _ = message
|
67 |
+
ret += role + message + self.sep
|
68 |
+
else:
|
69 |
+
ret += role
|
70 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
71 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
72 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
73 |
+
ret = ""
|
74 |
+
|
75 |
+
for i, (role, message) in enumerate(messages):
|
76 |
+
if i == 0:
|
77 |
+
assert message, "first message should not be none"
|
78 |
+
assert role == self.roles[0], "first message should come from user"
|
79 |
+
if message:
|
80 |
+
if type(message) is tuple:
|
81 |
+
message, _, _ = message
|
82 |
+
if i == 0: message = wrap_sys(self.system) + message
|
83 |
+
if i % 2 == 0:
|
84 |
+
message = wrap_inst(message)
|
85 |
+
ret += self.sep + message
|
86 |
+
else:
|
87 |
+
ret += " " + message + " " + self.sep2
|
88 |
+
else:
|
89 |
+
ret += ""
|
90 |
+
ret = ret.lstrip(self.sep)
|
91 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
92 |
+
seps = [self.sep, self.sep2]
|
93 |
+
ret = self.system
|
94 |
+
for i, (role, message) in enumerate(messages):
|
95 |
+
if message:
|
96 |
+
if type(message) is tuple:
|
97 |
+
message, _, _ = message
|
98 |
+
ret += message + seps[i % 2]
|
99 |
+
else:
|
100 |
+
ret += ""
|
101 |
+
else:
|
102 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
103 |
+
|
104 |
+
return ret
|
105 |
+
|
106 |
+
def append_message(self, role, message):
|
107 |
+
self.messages.append([role, message])
|
108 |
+
|
109 |
+
def get_images(self, return_pil=False):
|
110 |
+
images = []
|
111 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
112 |
+
if i % 2 == 0:
|
113 |
+
if type(msg) is tuple:
|
114 |
+
import base64
|
115 |
+
from io import BytesIO
|
116 |
+
from PIL import Image
|
117 |
+
msg, image, image_process_mode = msg
|
118 |
+
if image_process_mode == "Pad":
|
119 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
120 |
+
width, height = pil_img.size
|
121 |
+
if width == height:
|
122 |
+
return pil_img
|
123 |
+
elif width > height:
|
124 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
125 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
126 |
+
return result
|
127 |
+
else:
|
128 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
129 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
130 |
+
return result
|
131 |
+
image = expand2square(image)
|
132 |
+
elif image_process_mode in ["Default", "Crop"]:
|
133 |
+
pass
|
134 |
+
elif image_process_mode == "Resize":
|
135 |
+
image = image.resize((336, 336))
|
136 |
+
else:
|
137 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
138 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
139 |
+
aspect_ratio = max_hw / min_hw
|
140 |
+
max_len, min_len = 800, 400
|
141 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
142 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
143 |
+
W, H = image.size
|
144 |
+
if longest_edge != max(image.size):
|
145 |
+
if H > W:
|
146 |
+
H, W = longest_edge, shortest_edge
|
147 |
+
else:
|
148 |
+
H, W = shortest_edge, longest_edge
|
149 |
+
image = image.resize((W, H))
|
150 |
+
if return_pil:
|
151 |
+
images.append(image)
|
152 |
+
else:
|
153 |
+
buffered = BytesIO()
|
154 |
+
image.save(buffered, format="PNG")
|
155 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
156 |
+
images.append(img_b64_str)
|
157 |
+
return images
|
158 |
+
|
159 |
+
def to_gradio_chatbot(self):
|
160 |
+
ret = []
|
161 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
162 |
+
if i % 2 == 0:
|
163 |
+
if type(msg) is tuple:
|
164 |
+
import base64
|
165 |
+
from io import BytesIO
|
166 |
+
msg, image, image_process_mode = msg
|
167 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
168 |
+
aspect_ratio = max_hw / min_hw
|
169 |
+
max_len, min_len = 800, 400
|
170 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
171 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
172 |
+
W, H = image.size
|
173 |
+
if H > W:
|
174 |
+
H, W = longest_edge, shortest_edge
|
175 |
+
else:
|
176 |
+
H, W = shortest_edge, longest_edge
|
177 |
+
image = image.resize((W, H))
|
178 |
+
buffered = BytesIO()
|
179 |
+
image.save(buffered, format="JPEG")
|
180 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
181 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
182 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
183 |
+
ret.append([msg, None])
|
184 |
+
else:
|
185 |
+
ret.append([msg, None])
|
186 |
+
else:
|
187 |
+
ret[-1][-1] = msg
|
188 |
+
return ret
|
189 |
+
|
190 |
+
def copy(self):
|
191 |
+
return Conversation(
|
192 |
+
system=self.system,
|
193 |
+
roles=self.roles,
|
194 |
+
messages=[[x, y] for x, y in self.messages],
|
195 |
+
offset=self.offset,
|
196 |
+
sep_style=self.sep_style,
|
197 |
+
sep=self.sep,
|
198 |
+
sep2=self.sep2,
|
199 |
+
version=self.version)
|
200 |
+
|
201 |
+
def dict(self):
|
202 |
+
if len(self.get_images()) > 0:
|
203 |
+
return {
|
204 |
+
"system": self.system,
|
205 |
+
"roles": self.roles,
|
206 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
207 |
+
"offset": self.offset,
|
208 |
+
"sep": self.sep,
|
209 |
+
"sep2": self.sep2,
|
210 |
+
}
|
211 |
+
return {
|
212 |
+
"system": self.system,
|
213 |
+
"roles": self.roles,
|
214 |
+
"messages": self.messages,
|
215 |
+
"offset": self.offset,
|
216 |
+
"sep": self.sep,
|
217 |
+
"sep2": self.sep2,
|
218 |
+
}
|
219 |
+
|
220 |
+
|
221 |
+
conv_vicuna_v0 = Conversation(
|
222 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
223 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
224 |
+
roles=("Human", "Assistant"),
|
225 |
+
messages=(
|
226 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
227 |
+
("Assistant",
|
228 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
229 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
230 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
231 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
232 |
+
"renewable and non-renewable energy sources:\n"
|
233 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
234 |
+
"energy sources are finite and will eventually run out.\n"
|
235 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
236 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
237 |
+
"and other negative effects.\n"
|
238 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
239 |
+
"have lower operational costs than non-renewable sources.\n"
|
240 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
241 |
+
"locations than non-renewable sources.\n"
|
242 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
243 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
244 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
245 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
246 |
+
),
|
247 |
+
offset=2,
|
248 |
+
sep_style=SeparatorStyle.SINGLE,
|
249 |
+
sep="###",
|
250 |
+
)
|
251 |
+
|
252 |
+
conv_vicuna_v1 = Conversation(
|
253 |
+
# system="A chat between a curious user and an artificial intelligence assistant. "
|
254 |
+
# "The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
255 |
+
system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
|
256 |
+
"The assistant gives professional, detailed, and polite answers to the user's questions.",
|
257 |
+
roles=("USER", "ASSISTANT"),
|
258 |
+
version="v1",
|
259 |
+
messages=[],
|
260 |
+
offset=0,
|
261 |
+
sep_style=SeparatorStyle.TWO,
|
262 |
+
sep=" ",
|
263 |
+
sep2="</s>",
|
264 |
+
)
|
265 |
+
|
266 |
+
conv_llava_med = Conversation(
|
267 |
+
system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
|
268 |
+
"The assistant gives professional, detailed, and polite answers to the user's questions.",
|
269 |
+
roles=("USER", "ASSISTANT"),
|
270 |
+
version="v1",
|
271 |
+
messages=[],
|
272 |
+
offset=2,
|
273 |
+
sep_style=SeparatorStyle.TWO,
|
274 |
+
sep="###",
|
275 |
+
sep2="</s>"
|
276 |
+
)
|
277 |
+
|
278 |
+
simple_conv_multimodal = Conversation(
|
279 |
+
system="You are LLaVA-Med, a large language and vision assistant trained by a group of researchers at Microsoft, based on the general domain LLaVA architecture."
|
280 |
+
"You are able to understand the visual content that the user provides, and assist the user with a variety of medical and clinical tasks using natural language."
|
281 |
+
"Follow the instructions carefully and explain your answers in detail.",
|
282 |
+
roles=("Human", "Assistant"),
|
283 |
+
messages=(
|
284 |
+
("Human", "Hi!"),
|
285 |
+
("Assistant", "Hi there! How can I help you today?\n")
|
286 |
+
),
|
287 |
+
offset=2,
|
288 |
+
sep_style=SeparatorStyle.SINGLE,
|
289 |
+
sep="###",
|
290 |
+
)
|
291 |
+
|
292 |
+
conv_llama_2 = Conversation(
|
293 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
294 |
+
|
295 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
296 |
+
roles=("USER", "ASSISTANT"),
|
297 |
+
version="llama_v2",
|
298 |
+
messages=(),
|
299 |
+
offset=0,
|
300 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
301 |
+
sep="<s>",
|
302 |
+
sep2="</s>",
|
303 |
+
)
|
304 |
+
|
305 |
+
conv_llava_llama_2 = Conversation(
|
306 |
+
system="You are a helpful language and vision assistant. "
|
307 |
+
"You are able to understand the visual content that the user provides, "
|
308 |
+
"and assist the user with a variety of tasks using natural language.",
|
309 |
+
roles=("USER", "ASSISTANT"),
|
310 |
+
version="llama_v2",
|
311 |
+
messages=(),
|
312 |
+
offset=0,
|
313 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
314 |
+
sep="<s>",
|
315 |
+
sep2="</s>",
|
316 |
+
)
|
317 |
+
|
318 |
+
conv_mpt = Conversation(
|
319 |
+
system="""<|im_start|>system
|
320 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
321 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
322 |
+
version="mpt",
|
323 |
+
messages=(),
|
324 |
+
offset=0,
|
325 |
+
sep_style=SeparatorStyle.MPT,
|
326 |
+
sep="<|im_end|>",
|
327 |
+
)
|
328 |
+
|
329 |
+
conv_llava_plain = Conversation(
|
330 |
+
system="",
|
331 |
+
roles=("", ""),
|
332 |
+
messages=(
|
333 |
+
),
|
334 |
+
offset=0,
|
335 |
+
sep_style=SeparatorStyle.PLAIN,
|
336 |
+
sep="\n",
|
337 |
+
)
|
338 |
+
|
339 |
+
conv_llava_v0 = Conversation(
|
340 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
341 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
342 |
+
roles=("Human", "Assistant"),
|
343 |
+
messages=(
|
344 |
+
),
|
345 |
+
offset=0,
|
346 |
+
sep_style=SeparatorStyle.SINGLE,
|
347 |
+
sep="###",
|
348 |
+
)
|
349 |
+
|
350 |
+
conv_llava_v0_mmtag = Conversation(
|
351 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
352 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
353 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
354 |
+
roles=("Human", "Assistant"),
|
355 |
+
messages=(
|
356 |
+
),
|
357 |
+
offset=0,
|
358 |
+
sep_style=SeparatorStyle.SINGLE,
|
359 |
+
sep="###",
|
360 |
+
version="v0_mmtag",
|
361 |
+
)
|
362 |
+
|
363 |
+
conv_llava_v1 = Conversation(
|
364 |
+
# system="A chat between a curious human and an artificial intelligence assistant. "
|
365 |
+
# "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
366 |
+
system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
|
367 |
+
"The assistant gives professional, detailed, and polite answers to the user's questions.",
|
368 |
+
roles=("USER", "ASSISTANT"),
|
369 |
+
version="v1",
|
370 |
+
messages=(),
|
371 |
+
offset=0,
|
372 |
+
sep_style=SeparatorStyle.TWO,
|
373 |
+
sep=" ",
|
374 |
+
sep2="</s>",
|
375 |
+
)
|
376 |
+
|
377 |
+
|
378 |
+
conv_llava_v1_mmtag = Conversation(
|
379 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
380 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
381 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
382 |
+
roles=("USER", "ASSISTANT"),
|
383 |
+
messages=(),
|
384 |
+
offset=0,
|
385 |
+
sep_style=SeparatorStyle.TWO,
|
386 |
+
sep=" ",
|
387 |
+
sep2="</s>",
|
388 |
+
version="v1_mmtag",
|
389 |
+
)
|
390 |
+
|
391 |
+
default_conversation = conv_vicuna_v1
|
392 |
+
conv_templates = {
|
393 |
+
"default": conv_vicuna_v0,
|
394 |
+
"v0": conv_vicuna_v0,
|
395 |
+
"v1": conv_vicuna_v1,
|
396 |
+
"llava_med": conv_llava_med,
|
397 |
+
"vicuna_v1": conv_vicuna_v1,
|
398 |
+
"llama_2": conv_llama_2,
|
399 |
+
|
400 |
+
"plain": conv_llava_plain,
|
401 |
+
"v0_plain": conv_llava_plain,
|
402 |
+
"llava_v0": conv_llava_v0,
|
403 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
404 |
+
"llava_v1": conv_llava_v1,
|
405 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
406 |
+
"llava_llama_2": conv_llava_llama_2,
|
407 |
+
"multimodal": simple_conv_multimodal,
|
408 |
+
|
409 |
+
"mpt": conv_mpt,
|
410 |
+
}
|
411 |
+
|
412 |
+
|
413 |
+
if __name__ == "__main__":
|
414 |
+
print(default_conversation.get_prompt())
|
LLAVA_Biovil/llava/eval/__init__.py
ADDED
File without changes
|
LLAVA_Biovil/llava/eval/eval_gpt_review.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import openai
|
6 |
+
import tqdm
|
7 |
+
import ray
|
8 |
+
import time
|
9 |
+
|
10 |
+
NUM_SECONDS_TO_SLEEP = 3
|
11 |
+
|
12 |
+
@ray.remote(num_cpus=4)
|
13 |
+
def get_eval(content: str, max_tokens: int):
|
14 |
+
while True:
|
15 |
+
try:
|
16 |
+
response = openai.ChatCompletion.create(
|
17 |
+
model='gpt-4',
|
18 |
+
messages=[{
|
19 |
+
'role': 'system',
|
20 |
+
'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
|
21 |
+
}, {
|
22 |
+
'role': 'user',
|
23 |
+
'content': content,
|
24 |
+
}],
|
25 |
+
temperature=0.2, # TODO: figure out which temperature is best for evaluation
|
26 |
+
max_tokens=max_tokens,
|
27 |
+
)
|
28 |
+
break
|
29 |
+
except openai.error.RateLimitError:
|
30 |
+
pass
|
31 |
+
except Exception as e:
|
32 |
+
print(e)
|
33 |
+
time.sleep(NUM_SECONDS_TO_SLEEP)
|
34 |
+
|
35 |
+
print('success!')
|
36 |
+
return response['choices'][0]['message']['content']
|
37 |
+
|
38 |
+
|
39 |
+
def parse_score(review):
|
40 |
+
try:
|
41 |
+
score_pair = review.split('\n')[0]
|
42 |
+
score_pair = score_pair.replace(',', ' ')
|
43 |
+
sp = score_pair.split(' ')
|
44 |
+
if len(sp) == 2:
|
45 |
+
return [float(sp[0]), float(sp[1])]
|
46 |
+
else:
|
47 |
+
print('error', review)
|
48 |
+
return [-1, -1]
|
49 |
+
except Exception as e:
|
50 |
+
print(e)
|
51 |
+
print('error', review)
|
52 |
+
return [-1, -1]
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == '__main__':
|
56 |
+
parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
|
57 |
+
parser.add_argument('-q', '--question')
|
58 |
+
# parser.add_argument('-a', '--answer')
|
59 |
+
parser.add_argument('-a', '--answer-list', nargs='+', default=[])
|
60 |
+
parser.add_argument('-r', '--rule')
|
61 |
+
parser.add_argument('-o', '--output')
|
62 |
+
parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
|
63 |
+
args = parser.parse_args()
|
64 |
+
|
65 |
+
ray.init()
|
66 |
+
|
67 |
+
f_q = open(os.path.expanduser(args.question))
|
68 |
+
f_ans1 = open(os.path.expanduser(args.answer_list[0]))
|
69 |
+
f_ans2 = open(os.path.expanduser(args.answer_list[1]))
|
70 |
+
rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
|
71 |
+
|
72 |
+
review_file = open(f'{args.output}', 'w')
|
73 |
+
|
74 |
+
js_list = []
|
75 |
+
handles = []
|
76 |
+
idx = 0
|
77 |
+
for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
|
78 |
+
# if idx == 1:
|
79 |
+
# break
|
80 |
+
|
81 |
+
ques = json.loads(ques_js)
|
82 |
+
ans1 = json.loads(ans1_js)
|
83 |
+
ans2 = json.loads(ans2_js)
|
84 |
+
|
85 |
+
category = json.loads(ques_js)['category']
|
86 |
+
if category in rule_dict:
|
87 |
+
rule = rule_dict[category]
|
88 |
+
else:
|
89 |
+
rule = rule_dict['default']
|
90 |
+
prompt = rule['prompt']
|
91 |
+
role = rule['role']
|
92 |
+
content = (f'[Question]\n{ques["text"]}\n\n'
|
93 |
+
f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
|
94 |
+
f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
|
95 |
+
f'[System]\n{prompt}\n\n')
|
96 |
+
js_list.append({
|
97 |
+
'id': idx+1,
|
98 |
+
'question_id': ques['question_id'],
|
99 |
+
'answer1_id': ans1['answer_id'],
|
100 |
+
'answer2_id': ans2['answer_id'],
|
101 |
+
'category': category})
|
102 |
+
idx += 1
|
103 |
+
handles.append(get_eval.remote(content, args.max_tokens))
|
104 |
+
# To avoid the rate limit set by OpenAI
|
105 |
+
time.sleep(NUM_SECONDS_TO_SLEEP)
|
106 |
+
|
107 |
+
reviews = ray.get(handles)
|
108 |
+
for idx, review in enumerate(reviews):
|
109 |
+
scores = parse_score(review)
|
110 |
+
js_list[idx]['content'] = review
|
111 |
+
js_list[idx]['tuple'] = scores
|
112 |
+
review_file.write(json.dumps(js_list[idx]) + '\n')
|
113 |
+
review_file.close()
|
LLAVA_Biovil/llava/eval/eval_gpt_review_bench.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import openai
|
6 |
+
import time
|
7 |
+
|
8 |
+
NUM_SECONDS_TO_SLEEP = 0.5
|
9 |
+
|
10 |
+
|
11 |
+
def get_eval(content: str, max_tokens: int):
|
12 |
+
while True:
|
13 |
+
try:
|
14 |
+
response = openai.ChatCompletion.create(
|
15 |
+
model='gpt-4-0314',
|
16 |
+
messages=[{
|
17 |
+
'role': 'system',
|
18 |
+
'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
|
19 |
+
}, {
|
20 |
+
'role': 'user',
|
21 |
+
'content': content,
|
22 |
+
}],
|
23 |
+
temperature=0.2, # TODO: figure out which temperature is best for evaluation
|
24 |
+
max_tokens=max_tokens,
|
25 |
+
)
|
26 |
+
break
|
27 |
+
except openai.error.RateLimitError:
|
28 |
+
pass
|
29 |
+
except Exception as e:
|
30 |
+
print(e)
|
31 |
+
time.sleep(NUM_SECONDS_TO_SLEEP)
|
32 |
+
|
33 |
+
return response['choices'][0]['message']['content']
|
34 |
+
|
35 |
+
|
36 |
+
def parse_score(review):
|
37 |
+
try:
|
38 |
+
score_pair = review.split('\n')[0]
|
39 |
+
score_pair = score_pair.replace(',', ' ')
|
40 |
+
sp = score_pair.split(' ')
|
41 |
+
if len(sp) == 2:
|
42 |
+
return [float(sp[0]), float(sp[1])]
|
43 |
+
else:
|
44 |
+
print('error', review)
|
45 |
+
return [-1, -1]
|
46 |
+
except Exception as e:
|
47 |
+
print(e)
|
48 |
+
print('error', review)
|
49 |
+
return [-1, -1]
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == '__main__':
|
53 |
+
parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
|
54 |
+
parser.add_argument('-q', '--question')
|
55 |
+
parser.add_argument('-c', '--context')
|
56 |
+
parser.add_argument('-a', '--answer-list', nargs='+', default=[])
|
57 |
+
parser.add_argument('-r', '--rule')
|
58 |
+
parser.add_argument('-o', '--output')
|
59 |
+
parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
|
60 |
+
args = parser.parse_args()
|
61 |
+
|
62 |
+
f_q = open(os.path.expanduser(args.question))
|
63 |
+
f_ans1 = open(os.path.expanduser(args.answer_list[0]))
|
64 |
+
f_ans2 = open(os.path.expanduser(args.answer_list[1]))
|
65 |
+
rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
|
66 |
+
|
67 |
+
if os.path.isfile(os.path.expanduser(args.output)):
|
68 |
+
cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
|
69 |
+
else:
|
70 |
+
cur_reviews = []
|
71 |
+
|
72 |
+
review_file = open(f'{args.output}', 'a')
|
73 |
+
|
74 |
+
context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
|
75 |
+
image_to_context = {context['image']: context for context in context_list}
|
76 |
+
|
77 |
+
handles = []
|
78 |
+
idx = 0
|
79 |
+
for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
|
80 |
+
ques = json.loads(ques_js)
|
81 |
+
ans1 = json.loads(ans1_js)
|
82 |
+
ans2 = json.loads(ans2_js)
|
83 |
+
|
84 |
+
inst = image_to_context[ques['image']]
|
85 |
+
|
86 |
+
if isinstance(inst['caption'], list):
|
87 |
+
cap_str = '\n'.join(inst['caption'])
|
88 |
+
else:
|
89 |
+
cap_str = inst['caption']
|
90 |
+
|
91 |
+
category = 'llava_bench_' + json.loads(ques_js)['category']
|
92 |
+
if category in rule_dict:
|
93 |
+
rule = rule_dict[category]
|
94 |
+
else:
|
95 |
+
assert False, f"Visual QA category not found in rule file: {category}."
|
96 |
+
prompt = rule['prompt']
|
97 |
+
role = rule['role']
|
98 |
+
content = (f'[Context]\n{cap_str}\n\n'
|
99 |
+
f'[Question]\n{ques["text"]}\n\n'
|
100 |
+
f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
|
101 |
+
f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
|
102 |
+
f'[System]\n{prompt}\n\n')
|
103 |
+
cur_js = {
|
104 |
+
'id': idx+1,
|
105 |
+
'question_id': ques['question_id'],
|
106 |
+
'answer1_id': ans1.get('answer_id', ans1['question_id']),
|
107 |
+
'answer2_id': ans2.get('answer_id', ans2['answer_id']),
|
108 |
+
'category': category
|
109 |
+
}
|
110 |
+
if idx >= len(cur_reviews):
|
111 |
+
review = get_eval(content, args.max_tokens)
|
112 |
+
scores = parse_score(review)
|
113 |
+
cur_js['content'] = review
|
114 |
+
cur_js['tuple'] = scores
|
115 |
+
review_file.write(json.dumps(cur_js) + '\n')
|
116 |
+
review_file.flush()
|
117 |
+
else:
|
118 |
+
print(f'Skipping {idx} as we already have it.')
|
119 |
+
idx += 1
|
120 |
+
print(idx)
|
121 |
+
review_file.close()
|
LLAVA_Biovil/llava/eval/eval_gpt_review_visual.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import openai
|
6 |
+
import time
|
7 |
+
|
8 |
+
NUM_SECONDS_TO_SLEEP = 0.5
|
9 |
+
|
10 |
+
|
11 |
+
def get_eval(content: str, max_tokens: int):
|
12 |
+
while True:
|
13 |
+
try:
|
14 |
+
response = openai.ChatCompletion.create(
|
15 |
+
model='gpt-4-0314',
|
16 |
+
messages=[{
|
17 |
+
'role': 'system',
|
18 |
+
'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
|
19 |
+
}, {
|
20 |
+
'role': 'user',
|
21 |
+
'content': content,
|
22 |
+
}],
|
23 |
+
temperature=0.2, # TODO: figure out which temperature is best for evaluation
|
24 |
+
max_tokens=max_tokens,
|
25 |
+
)
|
26 |
+
break
|
27 |
+
except openai.error.RateLimitError:
|
28 |
+
pass
|
29 |
+
except Exception as e:
|
30 |
+
print(e)
|
31 |
+
time.sleep(NUM_SECONDS_TO_SLEEP)
|
32 |
+
|
33 |
+
return response['choices'][0]['message']['content']
|
34 |
+
|
35 |
+
|
36 |
+
def parse_score(review):
|
37 |
+
try:
|
38 |
+
score_pair = review.split('\n')[0]
|
39 |
+
score_pair = score_pair.replace(',', ' ')
|
40 |
+
sp = score_pair.split(' ')
|
41 |
+
if len(sp) == 2:
|
42 |
+
return [float(sp[0]), float(sp[1])]
|
43 |
+
else:
|
44 |
+
print('error', review)
|
45 |
+
return [-1, -1]
|
46 |
+
except Exception as e:
|
47 |
+
print(e)
|
48 |
+
print('error', review)
|
49 |
+
return [-1, -1]
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == '__main__':
|
53 |
+
parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
|
54 |
+
parser.add_argument('-q', '--question')
|
55 |
+
parser.add_argument('-c', '--context')
|
56 |
+
parser.add_argument('-a', '--answer-list', nargs='+', default=[])
|
57 |
+
parser.add_argument('-r', '--rule')
|
58 |
+
parser.add_argument('-o', '--output')
|
59 |
+
parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
|
60 |
+
args = parser.parse_args()
|
61 |
+
|
62 |
+
f_q = open(os.path.expanduser(args.question))
|
63 |
+
f_ans1 = open(os.path.expanduser(args.answer_list[0]))
|
64 |
+
f_ans2 = open(os.path.expanduser(args.answer_list[1]))
|
65 |
+
rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
|
66 |
+
|
67 |
+
if os.path.isfile(os.path.expanduser(args.output)):
|
68 |
+
cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
|
69 |
+
else:
|
70 |
+
cur_reviews = []
|
71 |
+
|
72 |
+
review_file = open(f'{args.output}', 'a')
|
73 |
+
|
74 |
+
context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
|
75 |
+
image_to_context = {context['image']: context for context in context_list}
|
76 |
+
|
77 |
+
handles = []
|
78 |
+
idx = 0
|
79 |
+
for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
|
80 |
+
ques = json.loads(ques_js)
|
81 |
+
ans1 = json.loads(ans1_js)
|
82 |
+
ans2 = json.loads(ans2_js)
|
83 |
+
|
84 |
+
inst = image_to_context[ques['image']]
|
85 |
+
cap_str = '\n'.join(inst['captions'])
|
86 |
+
box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
|
87 |
+
|
88 |
+
category = json.loads(ques_js)['category']
|
89 |
+
if category in rule_dict:
|
90 |
+
rule = rule_dict[category]
|
91 |
+
else:
|
92 |
+
assert False, f"Visual QA category not found in rule file: {category}."
|
93 |
+
prompt = rule['prompt']
|
94 |
+
role = rule['role']
|
95 |
+
content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
|
96 |
+
f'[Question]\n{ques["text"]}\n\n'
|
97 |
+
f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
|
98 |
+
f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
|
99 |
+
f'[System]\n{prompt}\n\n')
|
100 |
+
cur_js = {
|
101 |
+
'id': idx+1,
|
102 |
+
'question_id': ques['question_id'],
|
103 |
+
'answer1_id': ans1.get('answer_id', ans1['question_id']),
|
104 |
+
'answer2_id': ans2.get('answer_id', ans2['answer_id']),
|
105 |
+
'category': category
|
106 |
+
}
|
107 |
+
if idx >= len(cur_reviews):
|
108 |
+
review = get_eval(content, args.max_tokens)
|
109 |
+
scores = parse_score(review)
|
110 |
+
cur_js['content'] = review
|
111 |
+
cur_js['tuple'] = scores
|
112 |
+
review_file.write(json.dumps(cur_js) + '\n')
|
113 |
+
review_file.flush()
|
114 |
+
else:
|
115 |
+
print(f'Skipping {idx} as we already have it.')
|
116 |
+
idx += 1
|
117 |
+
print(idx)
|
118 |
+
review_file.close()
|
LLAVA_Biovil/llava/eval/eval_pope.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
def eval_pope(answers, label_file):
|
6 |
+
label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
|
7 |
+
|
8 |
+
for answer in answers:
|
9 |
+
text = answer['text']
|
10 |
+
|
11 |
+
# Only keep the first sentence
|
12 |
+
if text.find('.') != -1:
|
13 |
+
text = text.split('.')[0]
|
14 |
+
|
15 |
+
text = text.replace(',', '')
|
16 |
+
words = text.split(' ')
|
17 |
+
if 'No' in words or 'not' in words or 'no' in words:
|
18 |
+
answer['text'] = 'no'
|
19 |
+
else:
|
20 |
+
answer['text'] = 'yes'
|
21 |
+
|
22 |
+
for i in range(len(label_list)):
|
23 |
+
if label_list[i] == 'no':
|
24 |
+
label_list[i] = 0
|
25 |
+
else:
|
26 |
+
label_list[i] = 1
|
27 |
+
|
28 |
+
pred_list = []
|
29 |
+
for answer in answers:
|
30 |
+
if answer['text'] == 'no':
|
31 |
+
pred_list.append(0)
|
32 |
+
else:
|
33 |
+
pred_list.append(1)
|
34 |
+
|
35 |
+
pos = 1
|
36 |
+
neg = 0
|
37 |
+
yes_ratio = pred_list.count(1) / len(pred_list)
|
38 |
+
|
39 |
+
TP, TN, FP, FN = 0, 0, 0, 0
|
40 |
+
for pred, label in zip(pred_list, label_list):
|
41 |
+
if pred == pos and label == pos:
|
42 |
+
TP += 1
|
43 |
+
elif pred == pos and label == neg:
|
44 |
+
FP += 1
|
45 |
+
elif pred == neg and label == neg:
|
46 |
+
TN += 1
|
47 |
+
elif pred == neg and label == pos:
|
48 |
+
FN += 1
|
49 |
+
|
50 |
+
print('TP\tFP\tTN\tFN\t')
|
51 |
+
print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
|
52 |
+
|
53 |
+
precision = float(TP) / float(TP + FP)
|
54 |
+
recall = float(TP) / float(TP + FN)
|
55 |
+
f1 = 2*precision*recall / (precision + recall)
|
56 |
+
acc = (TP + TN) / (TP + TN + FP + FN)
|
57 |
+
print('Accuracy: {}'.format(acc))
|
58 |
+
print('Precision: {}'.format(precision))
|
59 |
+
print('Recall: {}'.format(recall))
|
60 |
+
print('F1 score: {}'.format(f1))
|
61 |
+
print('Yes ratio: {}'.format(yes_ratio))
|
62 |
+
print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
|
63 |
+
|
64 |
+
if __name__ == "__main__":
|
65 |
+
parser = argparse.ArgumentParser()
|
66 |
+
parser.add_argument("--annotation-dir", type=str)
|
67 |
+
parser.add_argument("--question-file", type=str)
|
68 |
+
parser.add_argument("--result-file", type=str)
|
69 |
+
args = parser.parse_args()
|
70 |
+
|
71 |
+
questions = [json.loads(line) for line in open(args.question_file)]
|
72 |
+
questions = {question['question_id']: question for question in questions}
|
73 |
+
answers = [json.loads(q) for q in open(args.result_file)]
|
74 |
+
for file in os.listdir(args.annotation_dir):
|
75 |
+
assert file.startswith('coco_pope_')
|
76 |
+
assert file.endswith('.json')
|
77 |
+
category = file[10:-5]
|
78 |
+
cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
|
79 |
+
print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
|
80 |
+
eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
|
81 |
+
print("====================================")
|
LLAVA_Biovil/llava/eval/eval_science_qa.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import random
|
6 |
+
|
7 |
+
|
8 |
+
def get_args():
|
9 |
+
parser = argparse.ArgumentParser()
|
10 |
+
parser.add_argument('--base-dir', type=str)
|
11 |
+
parser.add_argument('--result-file', type=str)
|
12 |
+
parser.add_argument('--output-file', type=str)
|
13 |
+
parser.add_argument('--output-result', type=str)
|
14 |
+
parser.add_argument('--split', type=str, default='test')
|
15 |
+
parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
|
16 |
+
return parser.parse_args()
|
17 |
+
|
18 |
+
|
19 |
+
def convert_caps(results):
|
20 |
+
fakecaps = []
|
21 |
+
for result in results:
|
22 |
+
image_id = result['question_id']
|
23 |
+
caption = result['text']
|
24 |
+
fakecaps.append({"image_id": int(image_id), "caption": caption})
|
25 |
+
return fakecaps
|
26 |
+
|
27 |
+
|
28 |
+
def get_pred_idx(prediction, choices, options):
|
29 |
+
"""
|
30 |
+
Get the index (e.g. 2) from the prediction (e.g. 'C')
|
31 |
+
"""
|
32 |
+
if prediction in options[:len(choices)]:
|
33 |
+
return options.index(prediction)
|
34 |
+
else:
|
35 |
+
return -1
|
36 |
+
return random.choice(range(len(choices)))
|
37 |
+
|
38 |
+
|
39 |
+
if __name__ == "__main__":
|
40 |
+
args = get_args()
|
41 |
+
|
42 |
+
base_dir = args.base_dir
|
43 |
+
split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
|
44 |
+
problems = json.load(open(os.path.join(base_dir, "problems.json")))
|
45 |
+
predictions = [json.loads(line) for line in open(args.result_file)]
|
46 |
+
predictions = {pred['question_id']: pred for pred in predictions}
|
47 |
+
split_problems = {idx: problems[idx] for idx in split_indices}
|
48 |
+
|
49 |
+
results = {'correct': [], 'incorrect': []}
|
50 |
+
sqa_results = {}
|
51 |
+
sqa_results['acc'] = None
|
52 |
+
sqa_results['correct'] = None
|
53 |
+
sqa_results['count'] = None
|
54 |
+
sqa_results['results'] = {}
|
55 |
+
sqa_results['outputs'] = {}
|
56 |
+
|
57 |
+
for prob_id, prob in split_problems.items():
|
58 |
+
if prob_id not in predictions:
|
59 |
+
pred = {'text': 'FAILED', 'prompt': 'Unknown'}
|
60 |
+
pred_text = 'FAILED'
|
61 |
+
else:
|
62 |
+
pred = predictions[prob_id]
|
63 |
+
pred_text = pred['text']
|
64 |
+
|
65 |
+
if pred_text in args.options:
|
66 |
+
answer = pred_text
|
67 |
+
elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
|
68 |
+
answer = pred_text[0]
|
69 |
+
else:
|
70 |
+
pattern = re.compile(r'The answer is ([A-Z]).')
|
71 |
+
res = pattern.findall(pred_text)
|
72 |
+
if len(res) == 1:
|
73 |
+
answer = res[0] # 'A', 'B', ...
|
74 |
+
else:
|
75 |
+
answer = "FAILED"
|
76 |
+
|
77 |
+
pred_idx = get_pred_idx(answer, prob['choices'], args.options)
|
78 |
+
|
79 |
+
analysis = {
|
80 |
+
'question_id': prob_id,
|
81 |
+
'parsed_ans': answer,
|
82 |
+
'ground_truth': args.options[prob['answer']],
|
83 |
+
'question': pred['prompt'],
|
84 |
+
'pred': pred_text,
|
85 |
+
'is_multimodal': '<image>' in pred['prompt'],
|
86 |
+
}
|
87 |
+
|
88 |
+
sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
|
89 |
+
sqa_results['outputs'][prob_id] = pred_text
|
90 |
+
|
91 |
+
if pred_idx == prob['answer']:
|
92 |
+
results['correct'].append(analysis)
|
93 |
+
else:
|
94 |
+
results['incorrect'].append(analysis)
|
95 |
+
|
96 |
+
correct = len(results['correct'])
|
97 |
+
total = len(results['correct']) + len(results['incorrect'])
|
98 |
+
|
99 |
+
###### IMG ######
|
100 |
+
multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
|
101 |
+
multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
|
102 |
+
multimodal_total = multimodal_correct + multimodal_incorrect
|
103 |
+
###### IMG ######
|
104 |
+
|
105 |
+
print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
|
106 |
+
|
107 |
+
sqa_results['acc'] = correct / total * 100
|
108 |
+
sqa_results['correct'] = correct
|
109 |
+
sqa_results['count'] = total
|
110 |
+
|
111 |
+
with open(args.output_file, 'w') as f:
|
112 |
+
json.dump(results, f, indent=2)
|
113 |
+
with open(args.output_result, 'w') as f:
|
114 |
+
json.dump(sqa_results, f, indent=2)
|
LLAVA_Biovil/llava/eval/eval_science_qa_gpt4.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import random
|
6 |
+
from collections import defaultdict
|
7 |
+
|
8 |
+
|
9 |
+
def get_args():
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
parser.add_argument('--base-dir', type=str)
|
12 |
+
parser.add_argument('--gpt4-result', type=str)
|
13 |
+
parser.add_argument('--our-result', type=str)
|
14 |
+
parser.add_argument('--split', type=str, default='test')
|
15 |
+
parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
|
16 |
+
return parser.parse_args()
|
17 |
+
|
18 |
+
|
19 |
+
def convert_caps(results):
|
20 |
+
fakecaps = []
|
21 |
+
for result in results:
|
22 |
+
image_id = result['question_id']
|
23 |
+
caption = result['text']
|
24 |
+
fakecaps.append({"image_id": int(image_id), "caption": caption})
|
25 |
+
return fakecaps
|
26 |
+
|
27 |
+
|
28 |
+
def get_pred_idx(prediction, choices, options):
|
29 |
+
"""
|
30 |
+
Get the index (e.g. 2) from the prediction (e.g. 'C')
|
31 |
+
"""
|
32 |
+
if prediction in options[:len(choices)]:
|
33 |
+
return options.index(prediction)
|
34 |
+
else:
|
35 |
+
return random.choice(range(len(choices)))
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == "__main__":
|
39 |
+
args = get_args()
|
40 |
+
|
41 |
+
base_dir = args.base_dir
|
42 |
+
split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
|
43 |
+
problems = json.load(open(os.path.join(base_dir, "problems.json")))
|
44 |
+
our_predictions = [json.loads(line) for line in open(args.our_result)]
|
45 |
+
our_predictions = {pred['question_id']: pred for pred in our_predictions}
|
46 |
+
split_problems = {idx: problems[idx] for idx in split_indices}
|
47 |
+
|
48 |
+
gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
|
49 |
+
|
50 |
+
results = defaultdict(lambda: 0)
|
51 |
+
|
52 |
+
for prob_id, prob in split_problems.items():
|
53 |
+
if prob_id not in our_predictions:
|
54 |
+
continue
|
55 |
+
if prob_id not in gpt4_predictions:
|
56 |
+
continue
|
57 |
+
our_pred = our_predictions[prob_id]['text']
|
58 |
+
gpt4_pred = gpt4_predictions[prob_id]
|
59 |
+
|
60 |
+
pattern = re.compile(r'The answer is ([A-Z]).')
|
61 |
+
our_res = pattern.findall(our_pred)
|
62 |
+
if len(our_res) == 1:
|
63 |
+
our_answer = our_res[0] # 'A', 'B', ...
|
64 |
+
else:
|
65 |
+
our_answer = "FAILED"
|
66 |
+
gpt4_res = pattern.findall(gpt4_pred)
|
67 |
+
if len(gpt4_res) == 1:
|
68 |
+
gpt4_answer = gpt4_res[0] # 'A', 'B', ...
|
69 |
+
else:
|
70 |
+
gpt4_answer = "FAILED"
|
71 |
+
|
72 |
+
our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
|
73 |
+
gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
|
74 |
+
|
75 |
+
if gpt4_answer == 'FAILED':
|
76 |
+
results['gpt4_failed'] += 1
|
77 |
+
# continue
|
78 |
+
gpt4_pred_idx = our_pred_idx
|
79 |
+
# if our_pred_idx != prob['answer']:
|
80 |
+
# print(our_predictions[prob_id]['prompt'])
|
81 |
+
# print('-----------------')
|
82 |
+
# print(f'LECTURE: {prob["lecture"]}')
|
83 |
+
# print(f'SOLUTION: {prob["solution"]}')
|
84 |
+
# print('=====================')
|
85 |
+
else:
|
86 |
+
# continue
|
87 |
+
pass
|
88 |
+
# gpt4_pred_idx = our_pred_idx
|
89 |
+
|
90 |
+
if gpt4_pred_idx == prob['answer']:
|
91 |
+
results['correct'] += 1
|
92 |
+
else:
|
93 |
+
results['incorrect'] += 1
|
94 |
+
|
95 |
+
|
96 |
+
if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
|
97 |
+
results['correct_upperbound'] += 1
|
98 |
+
|
99 |
+
correct = results['correct']
|
100 |
+
total = results['correct'] + results['incorrect']
|
101 |
+
print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
|
102 |
+
print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
|
103 |
+
print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
|
104 |
+
|
LLAVA_Biovil/llava/eval/eval_science_qa_gpt4_requery.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import random
|
6 |
+
from collections import defaultdict
|
7 |
+
|
8 |
+
|
9 |
+
def get_args():
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
parser.add_argument('--base-dir', type=str)
|
12 |
+
parser.add_argument('--gpt4-result', type=str)
|
13 |
+
parser.add_argument('--requery-result', type=str)
|
14 |
+
parser.add_argument('--our-result', type=str)
|
15 |
+
parser.add_argument('--output-result', type=str)
|
16 |
+
parser.add_argument('--split', type=str, default='test')
|
17 |
+
parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
|
18 |
+
return parser.parse_args()
|
19 |
+
|
20 |
+
|
21 |
+
def convert_caps(results):
|
22 |
+
fakecaps = []
|
23 |
+
for result in results:
|
24 |
+
image_id = result['question_id']
|
25 |
+
caption = result['text']
|
26 |
+
fakecaps.append({"image_id": int(image_id), "caption": caption})
|
27 |
+
return fakecaps
|
28 |
+
|
29 |
+
|
30 |
+
def get_pred_idx(prediction, choices, options):
|
31 |
+
"""
|
32 |
+
Get the index (e.g. 2) from the prediction (e.g. 'C')
|
33 |
+
"""
|
34 |
+
if prediction in options[:len(choices)]:
|
35 |
+
return options.index(prediction)
|
36 |
+
else:
|
37 |
+
return random.choice(range(len(choices)))
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
args = get_args()
|
42 |
+
|
43 |
+
base_dir = args.base_dir
|
44 |
+
split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
|
45 |
+
problems = json.load(open(os.path.join(base_dir, "problems.json")))
|
46 |
+
our_predictions = [json.loads(line) for line in open(args.our_result)]
|
47 |
+
our_predictions = {pred['question_id']: pred for pred in our_predictions}
|
48 |
+
split_problems = {idx: problems[idx] for idx in split_indices}
|
49 |
+
|
50 |
+
requery_predictions = [json.loads(line) for line in open(args.requery_result)]
|
51 |
+
requery_predictions = {pred['question_id']: pred for pred in requery_predictions}
|
52 |
+
|
53 |
+
gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
|
54 |
+
|
55 |
+
results = defaultdict(lambda: 0)
|
56 |
+
|
57 |
+
sqa_results = {}
|
58 |
+
sqa_results['acc'] = None
|
59 |
+
sqa_results['correct'] = None
|
60 |
+
sqa_results['count'] = None
|
61 |
+
sqa_results['results'] = {}
|
62 |
+
sqa_results['outputs'] = {}
|
63 |
+
|
64 |
+
for prob_id, prob in split_problems.items():
|
65 |
+
if prob_id not in our_predictions:
|
66 |
+
assert False
|
67 |
+
if prob_id not in gpt4_predictions:
|
68 |
+
assert False
|
69 |
+
our_pred = our_predictions[prob_id]['text']
|
70 |
+
gpt4_pred = gpt4_predictions[prob_id]
|
71 |
+
if prob_id not in requery_predictions:
|
72 |
+
results['missing_requery'] += 1
|
73 |
+
requery_pred = "MISSING"
|
74 |
+
else:
|
75 |
+
requery_pred = requery_predictions[prob_id]['text']
|
76 |
+
|
77 |
+
pattern = re.compile(r'The answer is ([A-Z]).')
|
78 |
+
our_res = pattern.findall(our_pred)
|
79 |
+
if len(our_res) == 1:
|
80 |
+
our_answer = our_res[0] # 'A', 'B', ...
|
81 |
+
else:
|
82 |
+
our_answer = "FAILED"
|
83 |
+
|
84 |
+
requery_res = pattern.findall(requery_pred)
|
85 |
+
if len(requery_res) == 1:
|
86 |
+
requery_answer = requery_res[0] # 'A', 'B', ...
|
87 |
+
else:
|
88 |
+
requery_answer = "FAILED"
|
89 |
+
|
90 |
+
gpt4_res = pattern.findall(gpt4_pred)
|
91 |
+
if len(gpt4_res) == 1:
|
92 |
+
gpt4_answer = gpt4_res[0] # 'A', 'B', ...
|
93 |
+
else:
|
94 |
+
gpt4_answer = "FAILED"
|
95 |
+
|
96 |
+
our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
|
97 |
+
gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
|
98 |
+
requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)
|
99 |
+
|
100 |
+
results['total'] += 1
|
101 |
+
|
102 |
+
if gpt4_answer == 'FAILED':
|
103 |
+
results['gpt4_failed'] += 1
|
104 |
+
if gpt4_pred_idx == prob['answer']:
|
105 |
+
results['gpt4_correct'] += 1
|
106 |
+
if our_pred_idx == prob['answer']:
|
107 |
+
results['gpt4_ourvisual_correct'] += 1
|
108 |
+
elif gpt4_pred_idx == prob['answer']:
|
109 |
+
results['gpt4_correct'] += 1
|
110 |
+
results['gpt4_ourvisual_correct'] += 1
|
111 |
+
|
112 |
+
if our_pred_idx == prob['answer']:
|
113 |
+
results['our_correct'] += 1
|
114 |
+
|
115 |
+
if requery_answer == 'FAILED':
|
116 |
+
sqa_results['results'][prob_id] = our_pred_idx
|
117 |
+
if our_pred_idx == prob['answer']:
|
118 |
+
results['requery_correct'] += 1
|
119 |
+
else:
|
120 |
+
sqa_results['results'][prob_id] = requery_pred_idx
|
121 |
+
if requery_pred_idx == prob['answer']:
|
122 |
+
results['requery_correct'] += 1
|
123 |
+
else:
|
124 |
+
print(f"""
|
125 |
+
Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
|
126 |
+
Our ({our_answer}): {our_pred}
|
127 |
+
GPT-4 ({gpt4_answer}): {gpt4_pred}
|
128 |
+
Requery ({requery_answer}): {requery_pred}
|
129 |
+
print("=====================================")
|
130 |
+
""")
|
131 |
+
|
132 |
+
if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
|
133 |
+
results['correct_upperbound'] += 1
|
134 |
+
|
135 |
+
total = results['total']
|
136 |
+
print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%')
|
137 |
+
print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%')
|
138 |
+
print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
|
139 |
+
print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%')
|
140 |
+
print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%')
|
141 |
+
print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
|
142 |
+
|
143 |
+
sqa_results['acc'] = results["requery_correct"] / total * 100
|
144 |
+
sqa_results['correct'] = results["requery_correct"]
|
145 |
+
sqa_results['count'] = total
|
146 |
+
|
147 |
+
with open(args.output_result, 'w') as f:
|
148 |
+
json.dump(sqa_results, f, indent=2)
|
149 |
+
|
LLAVA_Biovil/llava/eval/eval_textvqa.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
|
6 |
+
from LLAV.llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
|
7 |
+
|
8 |
+
|
9 |
+
def get_args():
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
parser.add_argument('--annotation-file', type=str)
|
12 |
+
parser.add_argument('--result-file', type=str)
|
13 |
+
parser.add_argument('--result-dir', type=str)
|
14 |
+
return parser.parse_args()
|
15 |
+
|
16 |
+
|
17 |
+
def prompt_processor(prompt):
|
18 |
+
if prompt.startswith('OCR tokens: '):
|
19 |
+
pattern = r"Question: (.*?) Short answer:"
|
20 |
+
match = re.search(pattern, prompt, re.DOTALL)
|
21 |
+
question = match.group(1)
|
22 |
+
elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
|
23 |
+
if prompt.startswith('Reference OCR token:'):
|
24 |
+
question = prompt.split('\n')[1]
|
25 |
+
else:
|
26 |
+
question = prompt.split('\n')[0]
|
27 |
+
elif len(prompt.split('\n')) == 2:
|
28 |
+
question = prompt.split('\n')[0]
|
29 |
+
else:
|
30 |
+
assert False
|
31 |
+
|
32 |
+
return question.lower()
|
33 |
+
|
34 |
+
|
35 |
+
def eval_single(annotation_file, result_file):
|
36 |
+
experiment_name = os.path.splitext(os.path.basename(result_file))[0]
|
37 |
+
print(experiment_name)
|
38 |
+
annotations = json.load(open(annotation_file))['data']
|
39 |
+
annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
|
40 |
+
results = [json.loads(line) for line in open(result_file)]
|
41 |
+
|
42 |
+
pred_list = []
|
43 |
+
for result in results:
|
44 |
+
annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
|
45 |
+
pred_list.append({
|
46 |
+
"pred_answer": result['text'],
|
47 |
+
"gt_answers": annotation['answers'],
|
48 |
+
})
|
49 |
+
|
50 |
+
evaluator = TextVQAAccuracyEvaluator()
|
51 |
+
print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
args = get_args()
|
56 |
+
|
57 |
+
if args.result_file is not None:
|
58 |
+
eval_single(args.annotation_file, args.result_file)
|
59 |
+
|
60 |
+
if args.result_dir is not None:
|
61 |
+
for result_file in sorted(os.listdir(args.result_dir)):
|
62 |
+
if not result_file.endswith('.jsonl'):
|
63 |
+
print(f'Skipping {result_file}')
|
64 |
+
continue
|
65 |
+
eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
|
LLAVA_Biovil/llava/eval/generate_webpage_data_from_table.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Generate json file for webpage."""
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
|
6 |
+
# models = ['llama', 'alpaca', 'gpt35', 'bard']
|
7 |
+
models = ['vicuna']
|
8 |
+
|
9 |
+
|
10 |
+
def read_jsonl(path: str, key: str=None):
|
11 |
+
data = []
|
12 |
+
with open(os.path.expanduser(path)) as f:
|
13 |
+
for line in f:
|
14 |
+
if not line:
|
15 |
+
continue
|
16 |
+
data.append(json.loads(line))
|
17 |
+
if key is not None:
|
18 |
+
data.sort(key=lambda x: x[key])
|
19 |
+
data = {item[key]: item for item in data}
|
20 |
+
return data
|
21 |
+
|
22 |
+
|
23 |
+
def trim_hanging_lines(s: str, n: int) -> str:
|
24 |
+
s = s.strip()
|
25 |
+
for _ in range(n):
|
26 |
+
s = s.split('\n', 1)[1].strip()
|
27 |
+
return s
|
28 |
+
|
29 |
+
|
30 |
+
if __name__ == '__main__':
|
31 |
+
questions = read_jsonl('table/question.jsonl', key='question_id')
|
32 |
+
|
33 |
+
# alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
|
34 |
+
# bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
|
35 |
+
# gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
|
36 |
+
# llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
|
37 |
+
vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
|
38 |
+
ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
|
39 |
+
|
40 |
+
review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
|
41 |
+
# review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
|
42 |
+
# review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
|
43 |
+
# review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
|
44 |
+
# review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
|
45 |
+
|
46 |
+
records = []
|
47 |
+
for qid in questions.keys():
|
48 |
+
r = {
|
49 |
+
'id': qid,
|
50 |
+
'category': questions[qid]['category'],
|
51 |
+
'question': questions[qid]['text'],
|
52 |
+
'answers': {
|
53 |
+
# 'alpaca': alpaca_answers[qid]['text'],
|
54 |
+
# 'llama': llama_answers[qid]['text'],
|
55 |
+
# 'bard': bard_answers[qid]['text'],
|
56 |
+
# 'gpt35': gpt35_answers[qid]['text'],
|
57 |
+
'vicuna': vicuna_answers[qid]['text'],
|
58 |
+
'ours': ours_answers[qid]['text'],
|
59 |
+
},
|
60 |
+
'evaluations': {
|
61 |
+
# 'alpaca': review_alpaca[qid]['text'],
|
62 |
+
# 'llama': review_llama[qid]['text'],
|
63 |
+
# 'bard': review_bard[qid]['text'],
|
64 |
+
'vicuna': review_vicuna[qid]['content'],
|
65 |
+
# 'gpt35': review_gpt35[qid]['text'],
|
66 |
+
},
|
67 |
+
'scores': {
|
68 |
+
'vicuna': review_vicuna[qid]['tuple'],
|
69 |
+
# 'alpaca': review_alpaca[qid]['score'],
|
70 |
+
# 'llama': review_llama[qid]['score'],
|
71 |
+
# 'bard': review_bard[qid]['score'],
|
72 |
+
# 'gpt35': review_gpt35[qid]['score'],
|
73 |
+
},
|
74 |
+
}
|
75 |
+
|
76 |
+
# cleanup data
|
77 |
+
cleaned_evals = {}
|
78 |
+
for k, v in r['evaluations'].items():
|
79 |
+
v = v.strip()
|
80 |
+
lines = v.split('\n')
|
81 |
+
# trim the first line if it's a pair of numbers
|
82 |
+
if re.match(r'\d+[, ]+\d+', lines[0]):
|
83 |
+
lines = lines[1:]
|
84 |
+
v = '\n'.join(lines)
|
85 |
+
cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
|
86 |
+
|
87 |
+
r['evaluations'] = cleaned_evals
|
88 |
+
records.append(r)
|
89 |
+
|
90 |
+
# Reorder the records, this is optional
|
91 |
+
for r in records:
|
92 |
+
if r['id'] <= 20:
|
93 |
+
r['id'] += 60
|
94 |
+
else:
|
95 |
+
r['id'] -= 20
|
96 |
+
for r in records:
|
97 |
+
if r['id'] <= 50:
|
98 |
+
r['id'] += 10
|
99 |
+
elif 50 < r['id'] <= 60:
|
100 |
+
r['id'] -= 50
|
101 |
+
for r in records:
|
102 |
+
if r['id'] == 7:
|
103 |
+
r['id'] = 1
|
104 |
+
elif r['id'] < 7:
|
105 |
+
r['id'] += 1
|
106 |
+
|
107 |
+
records.sort(key=lambda x: x['id'])
|
108 |
+
|
109 |
+
# Write to file
|
110 |
+
with open('webpage/data.json', 'w') as f:
|
111 |
+
json.dump({'questions': records, 'models': models}, f, indent=2)
|
LLAVA_Biovil/llava/eval/m4c_evaluator.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import re
|
3 |
+
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
|
7 |
+
class EvalAIAnswerProcessor:
|
8 |
+
"""
|
9 |
+
Processes an answer similar to Eval AI
|
10 |
+
copied from
|
11 |
+
https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
|
12 |
+
"""
|
13 |
+
|
14 |
+
CONTRACTIONS = {
|
15 |
+
"aint": "ain't",
|
16 |
+
"arent": "aren't",
|
17 |
+
"cant": "can't",
|
18 |
+
"couldve": "could've",
|
19 |
+
"couldnt": "couldn't",
|
20 |
+
"couldn'tve": "couldn't've",
|
21 |
+
"couldnt've": "couldn't've",
|
22 |
+
"didnt": "didn't",
|
23 |
+
"doesnt": "doesn't",
|
24 |
+
"dont": "don't",
|
25 |
+
"hadnt": "hadn't",
|
26 |
+
"hadnt've": "hadn't've",
|
27 |
+
"hadn'tve": "hadn't've",
|
28 |
+
"hasnt": "hasn't",
|
29 |
+
"havent": "haven't",
|
30 |
+
"hed": "he'd",
|
31 |
+
"hed've": "he'd've",
|
32 |
+
"he'dve": "he'd've",
|
33 |
+
"hes": "he's",
|
34 |
+
"howd": "how'd",
|
35 |
+
"howll": "how'll",
|
36 |
+
"hows": "how's",
|
37 |
+
"Id've": "I'd've",
|
38 |
+
"I'dve": "I'd've",
|
39 |
+
"Im": "I'm",
|
40 |
+
"Ive": "I've",
|
41 |
+
"isnt": "isn't",
|
42 |
+
"itd": "it'd",
|
43 |
+
"itd've": "it'd've",
|
44 |
+
"it'dve": "it'd've",
|
45 |
+
"itll": "it'll",
|
46 |
+
"let's": "let's",
|
47 |
+
"maam": "ma'am",
|
48 |
+
"mightnt": "mightn't",
|
49 |
+
"mightnt've": "mightn't've",
|
50 |
+
"mightn'tve": "mightn't've",
|
51 |
+
"mightve": "might've",
|
52 |
+
"mustnt": "mustn't",
|
53 |
+
"mustve": "must've",
|
54 |
+
"neednt": "needn't",
|
55 |
+
"notve": "not've",
|
56 |
+
"oclock": "o'clock",
|
57 |
+
"oughtnt": "oughtn't",
|
58 |
+
"ow's'at": "'ow's'at",
|
59 |
+
"'ows'at": "'ow's'at",
|
60 |
+
"'ow'sat": "'ow's'at",
|
61 |
+
"shant": "shan't",
|
62 |
+
"shed've": "she'd've",
|
63 |
+
"she'dve": "she'd've",
|
64 |
+
"she's": "she's",
|
65 |
+
"shouldve": "should've",
|
66 |
+
"shouldnt": "shouldn't",
|
67 |
+
"shouldnt've": "shouldn't've",
|
68 |
+
"shouldn'tve": "shouldn't've",
|
69 |
+
"somebody'd": "somebodyd",
|
70 |
+
"somebodyd've": "somebody'd've",
|
71 |
+
"somebody'dve": "somebody'd've",
|
72 |
+
"somebodyll": "somebody'll",
|
73 |
+
"somebodys": "somebody's",
|
74 |
+
"someoned": "someone'd",
|
75 |
+
"someoned've": "someone'd've",
|
76 |
+
"someone'dve": "someone'd've",
|
77 |
+
"someonell": "someone'll",
|
78 |
+
"someones": "someone's",
|
79 |
+
"somethingd": "something'd",
|
80 |
+
"somethingd've": "something'd've",
|
81 |
+
"something'dve": "something'd've",
|
82 |
+
"somethingll": "something'll",
|
83 |
+
"thats": "that's",
|
84 |
+
"thered": "there'd",
|
85 |
+
"thered've": "there'd've",
|
86 |
+
"there'dve": "there'd've",
|
87 |
+
"therere": "there're",
|
88 |
+
"theres": "there's",
|
89 |
+
"theyd": "they'd",
|
90 |
+
"theyd've": "they'd've",
|
91 |
+
"they'dve": "they'd've",
|
92 |
+
"theyll": "they'll",
|
93 |
+
"theyre": "they're",
|
94 |
+
"theyve": "they've",
|
95 |
+
"twas": "'twas",
|
96 |
+
"wasnt": "wasn't",
|
97 |
+
"wed've": "we'd've",
|
98 |
+
"we'dve": "we'd've",
|
99 |
+
"weve": "we've",
|
100 |
+
"werent": "weren't",
|
101 |
+
"whatll": "what'll",
|
102 |
+
"whatre": "what're",
|
103 |
+
"whats": "what's",
|
104 |
+
"whatve": "what've",
|
105 |
+
"whens": "when's",
|
106 |
+
"whered": "where'd",
|
107 |
+
"wheres": "where's",
|
108 |
+
"whereve": "where've",
|
109 |
+
"whod": "who'd",
|
110 |
+
"whod've": "who'd've",
|
111 |
+
"who'dve": "who'd've",
|
112 |
+
"wholl": "who'll",
|
113 |
+
"whos": "who's",
|
114 |
+
"whove": "who've",
|
115 |
+
"whyll": "why'll",
|
116 |
+
"whyre": "why're",
|
117 |
+
"whys": "why's",
|
118 |
+
"wont": "won't",
|
119 |
+
"wouldve": "would've",
|
120 |
+
"wouldnt": "wouldn't",
|
121 |
+
"wouldnt've": "wouldn't've",
|
122 |
+
"wouldn'tve": "wouldn't've",
|
123 |
+
"yall": "y'all",
|
124 |
+
"yall'll": "y'all'll",
|
125 |
+
"y'allll": "y'all'll",
|
126 |
+
"yall'd've": "y'all'd've",
|
127 |
+
"y'alld've": "y'all'd've",
|
128 |
+
"y'all'dve": "y'all'd've",
|
129 |
+
"youd": "you'd",
|
130 |
+
"youd've": "you'd've",
|
131 |
+
"you'dve": "you'd've",
|
132 |
+
"youll": "you'll",
|
133 |
+
"youre": "you're",
|
134 |
+
"youve": "you've",
|
135 |
+
}
|
136 |
+
|
137 |
+
NUMBER_MAP = {
|
138 |
+
"none": "0",
|
139 |
+
"zero": "0",
|
140 |
+
"one": "1",
|
141 |
+
"two": "2",
|
142 |
+
"three": "3",
|
143 |
+
"four": "4",
|
144 |
+
"five": "5",
|
145 |
+
"six": "6",
|
146 |
+
"seven": "7",
|
147 |
+
"eight": "8",
|
148 |
+
"nine": "9",
|
149 |
+
"ten": "10",
|
150 |
+
}
|
151 |
+
ARTICLES = ["a", "an", "the"]
|
152 |
+
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
|
153 |
+
COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
|
154 |
+
PUNCTUATIONS = [
|
155 |
+
";",
|
156 |
+
r"/",
|
157 |
+
"[",
|
158 |
+
"]",
|
159 |
+
'"',
|
160 |
+
"{",
|
161 |
+
"}",
|
162 |
+
"(",
|
163 |
+
")",
|
164 |
+
"=",
|
165 |
+
"+",
|
166 |
+
"\\",
|
167 |
+
"_",
|
168 |
+
"-",
|
169 |
+
">",
|
170 |
+
"<",
|
171 |
+
"@",
|
172 |
+
"`",
|
173 |
+
",",
|
174 |
+
"?",
|
175 |
+
"!",
|
176 |
+
]
|
177 |
+
|
178 |
+
def __init__(self, *args, **kwargs):
|
179 |
+
pass
|
180 |
+
|
181 |
+
def word_tokenize(self, word):
|
182 |
+
word = word.lower()
|
183 |
+
word = word.replace(",", "").replace("?", "").replace("'s", " 's")
|
184 |
+
return word.strip()
|
185 |
+
|
186 |
+
def process_punctuation(self, in_text):
|
187 |
+
out_text = in_text
|
188 |
+
for p in self.PUNCTUATIONS:
|
189 |
+
if (p + " " in in_text or " " + p in in_text) or (
|
190 |
+
re.search(self.COMMA_STRIP, in_text) is not None
|
191 |
+
):
|
192 |
+
out_text = out_text.replace(p, "")
|
193 |
+
else:
|
194 |
+
out_text = out_text.replace(p, " ")
|
195 |
+
out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
|
196 |
+
return out_text
|
197 |
+
|
198 |
+
def process_digit_article(self, in_text):
|
199 |
+
out_text = []
|
200 |
+
temp_text = in_text.lower().split()
|
201 |
+
for word in temp_text:
|
202 |
+
word = self.NUMBER_MAP.setdefault(word, word)
|
203 |
+
if word not in self.ARTICLES:
|
204 |
+
out_text.append(word)
|
205 |
+
else:
|
206 |
+
pass
|
207 |
+
for word_id, word in enumerate(out_text):
|
208 |
+
if word in self.CONTRACTIONS:
|
209 |
+
out_text[word_id] = self.CONTRACTIONS[word]
|
210 |
+
out_text = " ".join(out_text)
|
211 |
+
return out_text
|
212 |
+
|
213 |
+
def __call__(self, item):
|
214 |
+
item = self.word_tokenize(item)
|
215 |
+
item = item.replace("\n", " ").replace("\t", " ").strip()
|
216 |
+
item = self.process_punctuation(item)
|
217 |
+
item = self.process_digit_article(item)
|
218 |
+
return item
|
219 |
+
|
220 |
+
|
221 |
+
class TextVQAAccuracyEvaluator:
|
222 |
+
def __init__(self):
|
223 |
+
self.answer_processor = EvalAIAnswerProcessor()
|
224 |
+
|
225 |
+
def _compute_answer_scores(self, raw_answers):
|
226 |
+
"""
|
227 |
+
compute the accuracy (soft score) of human answers
|
228 |
+
"""
|
229 |
+
answers = [self.answer_processor(a) for a in raw_answers]
|
230 |
+
assert len(answers) == 10
|
231 |
+
gt_answers = list(enumerate(answers))
|
232 |
+
unique_answers = set(answers)
|
233 |
+
unique_answer_scores = {}
|
234 |
+
|
235 |
+
for unique_answer in unique_answers:
|
236 |
+
accs = []
|
237 |
+
for gt_answer in gt_answers:
|
238 |
+
other_answers = [item for item in gt_answers if item != gt_answer]
|
239 |
+
matching_answers = [
|
240 |
+
item for item in other_answers if item[1] == unique_answer
|
241 |
+
]
|
242 |
+
acc = min(1, float(len(matching_answers)) / 3)
|
243 |
+
accs.append(acc)
|
244 |
+
unique_answer_scores[unique_answer] = sum(accs) / len(accs)
|
245 |
+
|
246 |
+
return unique_answer_scores
|
247 |
+
|
248 |
+
def eval_pred_list(self, pred_list):
|
249 |
+
pred_scores = []
|
250 |
+
for entry in tqdm(pred_list):
|
251 |
+
pred_answer = self.answer_processor(entry["pred_answer"])
|
252 |
+
unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
|
253 |
+
score = unique_answer_scores.get(pred_answer, 0.0)
|
254 |
+
pred_scores.append(score)
|
255 |
+
|
256 |
+
accuracy = sum(pred_scores) / len(pred_scores)
|
257 |
+
return accuracy
|
258 |
+
|
259 |
+
|
260 |
+
class STVQAAccuracyEvaluator:
|
261 |
+
def __init__(self):
|
262 |
+
self.answer_processor = EvalAIAnswerProcessor()
|
263 |
+
|
264 |
+
def eval_pred_list(self, pred_list):
|
265 |
+
pred_scores = []
|
266 |
+
for entry in pred_list:
|
267 |
+
pred_answer = self.answer_processor(entry["pred_answer"])
|
268 |
+
gts = [self.answer_processor(a) for a in entry["gt_answers"]]
|
269 |
+
score = 1.0 if pred_answer in gts else 0.0
|
270 |
+
pred_scores.append(score)
|
271 |
+
|
272 |
+
accuracy = sum(pred_scores) / len(pred_scores)
|
273 |
+
return accuracy
|
274 |
+
|
275 |
+
|
276 |
+
class STVQAANLSEvaluator:
|
277 |
+
def __init__(self):
|
278 |
+
import editdistance # install with `pip install editdistance`
|
279 |
+
|
280 |
+
self.get_edit_distance = editdistance.eval
|
281 |
+
|
282 |
+
def get_anls(self, s1, s2):
|
283 |
+
s1 = s1.lower().strip()
|
284 |
+
s2 = s2.lower().strip()
|
285 |
+
iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
|
286 |
+
anls = iou if iou >= 0.5 else 0.0
|
287 |
+
return anls
|
288 |
+
|
289 |
+
def eval_pred_list(self, pred_list):
|
290 |
+
pred_scores = []
|
291 |
+
for entry in pred_list:
|
292 |
+
anls = max(
|
293 |
+
self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
|
294 |
+
)
|
295 |
+
pred_scores.append(anls)
|
296 |
+
|
297 |
+
accuracy = sum(pred_scores) / len(pred_scores)
|
298 |
+
return accuracy
|
299 |
+
|
300 |
+
|
301 |
+
class TextCapsBleu4Evaluator:
|
302 |
+
def __init__(self):
|
303 |
+
# The following script requires Java 1.8.0 and pycocotools installed.
|
304 |
+
# The pycocoevalcap can be installed with pip as
|
305 |
+
# pip install git+https://github.com/ronghanghu/coco-caption.git@python23
|
306 |
+
# Original pycocoevalcap code is at https://github.com/tylin/coco-caption
|
307 |
+
# but has no python3 support yet.
|
308 |
+
try:
|
309 |
+
from pycocoevalcap.bleu.bleu import Bleu
|
310 |
+
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
|
311 |
+
except ModuleNotFoundError:
|
312 |
+
print(
|
313 |
+
"Please install pycocoevalcap module using "
|
314 |
+
"pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa
|
315 |
+
)
|
316 |
+
raise
|
317 |
+
|
318 |
+
self.tokenizer = PTBTokenizer()
|
319 |
+
self.scorer = Bleu(4)
|
320 |
+
|
321 |
+
def eval_pred_list(self, pred_list):
|
322 |
+
# Create reference and hypotheses captions.
|
323 |
+
gts = {}
|
324 |
+
res = {}
|
325 |
+
for idx, entry in enumerate(pred_list):
|
326 |
+
gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
|
327 |
+
res[idx] = [{"caption": entry["pred_answer"]}]
|
328 |
+
|
329 |
+
gts = self.tokenizer.tokenize(gts)
|
330 |
+
res = self.tokenizer.tokenize(res)
|
331 |
+
score, _ = self.scorer.compute_score(gts, res)
|
332 |
+
|
333 |
+
bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
|
334 |
+
return bleu4
|
LLAVA_Biovil/llava/eval/model_qa.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
from tqdm import tqdm
|
7 |
+
import shortuuid
|
8 |
+
|
9 |
+
from LLAV.llava.conversation import default_conversation
|
10 |
+
from LLAV.llava.utils import disable_torch_init
|
11 |
+
|
12 |
+
|
13 |
+
# new stopping implementation
|
14 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
15 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
16 |
+
self.keywords = keywords
|
17 |
+
self.tokenizer = tokenizer
|
18 |
+
self.start_len = None
|
19 |
+
self.input_ids = input_ids
|
20 |
+
|
21 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
22 |
+
if self.start_len is None:
|
23 |
+
self.start_len = self.input_ids.shape[1]
|
24 |
+
else:
|
25 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
|
26 |
+
for keyword in self.keywords:
|
27 |
+
if keyword in outputs:
|
28 |
+
return True
|
29 |
+
return False
|
30 |
+
|
31 |
+
|
32 |
+
@torch.inference_mode()
|
33 |
+
def eval_model(model_name, questions_file, answers_file):
|
34 |
+
# Model
|
35 |
+
disable_torch_init()
|
36 |
+
model_name = os.path.expanduser(model_name)
|
37 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
38 |
+
model = AutoModelForCausalLM.from_pretrained(model_name,
|
39 |
+
torch_dtype=torch.float16).cuda()
|
40 |
+
|
41 |
+
|
42 |
+
ques_file = open(os.path.expanduser(questions_file), "r")
|
43 |
+
ans_file = open(os.path.expanduser(answers_file), "w")
|
44 |
+
for i, line in enumerate(tqdm(ques_file)):
|
45 |
+
idx = json.loads(line)["question_id"]
|
46 |
+
qs = json.loads(line)["text"]
|
47 |
+
cat = json.loads(line)["category"]
|
48 |
+
conv = default_conversation.copy()
|
49 |
+
conv.append_message(conv.roles[0], qs)
|
50 |
+
prompt = conv.get_prompt()
|
51 |
+
inputs = tokenizer([prompt])
|
52 |
+
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
53 |
+
stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids)
|
54 |
+
output_ids = model.generate(
|
55 |
+
input_ids,
|
56 |
+
do_sample=True,
|
57 |
+
use_cache=True,
|
58 |
+
temperature=0.7,
|
59 |
+
max_new_tokens=1024,
|
60 |
+
stopping_criteria=[stopping_criteria])
|
61 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
62 |
+
try:
|
63 |
+
index = outputs.index(conv.sep, len(prompt))
|
64 |
+
except ValueError:
|
65 |
+
outputs += conv.sep
|
66 |
+
index = outputs.index(conv.sep, len(prompt))
|
67 |
+
|
68 |
+
outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
|
69 |
+
ans_id = shortuuid.uuid()
|
70 |
+
ans_file.write(json.dumps({"question_id": idx,
|
71 |
+
"text": outputs,
|
72 |
+
"answer_id": ans_id,
|
73 |
+
"model_id": model_name,
|
74 |
+
"metadata": {}}) + "\n")
|
75 |
+
ans_file.flush()
|
76 |
+
ans_file.close()
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
parser = argparse.ArgumentParser()
|
80 |
+
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
|
81 |
+
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
|
82 |
+
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
83 |
+
args = parser.parse_args()
|
84 |
+
|
85 |
+
eval_model(args.model_name, args.question_file, args.answers_file)
|
LLAVA_Biovil/llava/eval/model_vqa.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
from tqdm import tqdm
|
6 |
+
import shortuuid
|
7 |
+
|
8 |
+
from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
9 |
+
from LLAV.llava.conversation import conv_templates, SeparatorStyle
|
10 |
+
from LLAV.llava.model.builder import load_pretrained_model
|
11 |
+
from LLAV.llava.utils import disable_torch_init
|
12 |
+
from LLAV.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
13 |
+
|
14 |
+
from PIL import Image
|
15 |
+
import math
|
16 |
+
|
17 |
+
|
18 |
+
def split_list(lst, n):
|
19 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
20 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
21 |
+
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
22 |
+
|
23 |
+
|
24 |
+
def get_chunk(lst, n, k):
|
25 |
+
chunks = split_list(lst, n)
|
26 |
+
return chunks[k]
|
27 |
+
|
28 |
+
|
29 |
+
def eval_model(args):
|
30 |
+
# Model
|
31 |
+
disable_torch_init()
|
32 |
+
model_path = os.path.expanduser(args.model_path)
|
33 |
+
model_name = get_model_name_from_path(model_path)
|
34 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
|
35 |
+
|
36 |
+
questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
|
37 |
+
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
|
38 |
+
answers_file = os.path.expanduser(args.answers_file)
|
39 |
+
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
40 |
+
ans_file = open(answers_file, "w")
|
41 |
+
for line in tqdm(questions):
|
42 |
+
idx = line["question_id"]
|
43 |
+
image_file = line["image"]
|
44 |
+
qs = line["text"]
|
45 |
+
cur_prompt = qs
|
46 |
+
if model.config.mm_use_im_start_end:
|
47 |
+
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
|
48 |
+
else:
|
49 |
+
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
50 |
+
|
51 |
+
conv = conv_templates[args.conv_mode].copy()
|
52 |
+
conv.append_message(conv.roles[0], qs)
|
53 |
+
conv.append_message(conv.roles[1], None)
|
54 |
+
prompt = conv.get_prompt()
|
55 |
+
|
56 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
57 |
+
|
58 |
+
image = Image.open(os.path.join(args.image_folder, image_file))
|
59 |
+
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
60 |
+
|
61 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
62 |
+
keywords = [stop_str]
|
63 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
64 |
+
|
65 |
+
with torch.inference_mode():
|
66 |
+
output_ids = model.generate(
|
67 |
+
input_ids,
|
68 |
+
images=image_tensor.unsqueeze(0).half().cuda(),
|
69 |
+
do_sample=True if args.temperature > 0 else False,
|
70 |
+
temperature=args.temperature,
|
71 |
+
top_p=args.top_p,
|
72 |
+
num_beams=args.num_beams,
|
73 |
+
# no_repeat_ngram_size=3,
|
74 |
+
max_new_tokens=1024,
|
75 |
+
use_cache=True)
|
76 |
+
|
77 |
+
input_token_len = input_ids.shape[1]
|
78 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
79 |
+
if n_diff_input_output > 0:
|
80 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
81 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
82 |
+
outputs = outputs.strip()
|
83 |
+
if outputs.endswith(stop_str):
|
84 |
+
outputs = outputs[:-len(stop_str)]
|
85 |
+
outputs = outputs.strip()
|
86 |
+
|
87 |
+
ans_id = shortuuid.uuid()
|
88 |
+
ans_file.write(json.dumps({"question_id": idx,
|
89 |
+
"prompt": cur_prompt,
|
90 |
+
"text": outputs,
|
91 |
+
"answer_id": ans_id,
|
92 |
+
"model_id": model_name,
|
93 |
+
"metadata": {}}) + "\n")
|
94 |
+
ans_file.flush()
|
95 |
+
ans_file.close()
|
96 |
+
|
97 |
+
if __name__ == "__main__":
|
98 |
+
parser = argparse.ArgumentParser()
|
99 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
100 |
+
parser.add_argument("--model-base", type=str, default=None)
|
101 |
+
parser.add_argument("--image-folder", type=str, default="")
|
102 |
+
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
|
103 |
+
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
104 |
+
parser.add_argument("--conv-mode", type=str, default="llava_v1")
|
105 |
+
parser.add_argument("--num-chunks", type=int, default=1)
|
106 |
+
parser.add_argument("--chunk-idx", type=int, default=0)
|
107 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
108 |
+
parser.add_argument("--top_p", type=float, default=None)
|
109 |
+
parser.add_argument("--num_beams", type=int, default=1)
|
110 |
+
args = parser.parse_args()
|
111 |
+
|
112 |
+
eval_model(args)
|
LLAVA_Biovil/llava/eval/model_vqa_loader.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
from tqdm import tqdm
|
6 |
+
import shortuuid
|
7 |
+
|
8 |
+
from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
9 |
+
from LLAV.llava.conversation import conv_templates
|
10 |
+
from LLAV.llava.model.builder import load_pretrained_model
|
11 |
+
from LLAV.llava.utils import disable_torch_init
|
12 |
+
from LLAV.llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
|
13 |
+
from torch.utils.data import Dataset, DataLoader
|
14 |
+
|
15 |
+
from PIL import Image
|
16 |
+
import math
|
17 |
+
|
18 |
+
|
19 |
+
def split_list(lst, n):
|
20 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
21 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
22 |
+
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
23 |
+
|
24 |
+
|
25 |
+
def get_chunk(lst, n, k):
|
26 |
+
chunks = split_list(lst, n)
|
27 |
+
return chunks[k]
|
28 |
+
|
29 |
+
|
30 |
+
# Custom dataset class
|
31 |
+
class CustomDataset(Dataset):
|
32 |
+
def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
|
33 |
+
self.questions = questions
|
34 |
+
self.image_folder = image_folder
|
35 |
+
self.tokenizer = tokenizer
|
36 |
+
self.image_processor = image_processor
|
37 |
+
self.model_config = model_config
|
38 |
+
|
39 |
+
def __getitem__(self, index):
|
40 |
+
line = self.questions[index]
|
41 |
+
image_file = line["image"]
|
42 |
+
qs = line["text"]
|
43 |
+
if self.model_config.mm_use_im_start_end:
|
44 |
+
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
|
45 |
+
else:
|
46 |
+
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
47 |
+
|
48 |
+
conv = conv_templates[args.conv_mode].copy()
|
49 |
+
conv.append_message(conv.roles[0], qs)
|
50 |
+
conv.append_message(conv.roles[1], None)
|
51 |
+
prompt = conv.get_prompt()
|
52 |
+
|
53 |
+
image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
|
54 |
+
image_tensor = process_images([image], self.image_processor, self.model_config)[0]
|
55 |
+
|
56 |
+
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
|
57 |
+
|
58 |
+
return input_ids, image_tensor
|
59 |
+
|
60 |
+
def __len__(self):
|
61 |
+
return len(self.questions)
|
62 |
+
|
63 |
+
|
64 |
+
# DataLoader
|
65 |
+
def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
|
66 |
+
assert batch_size == 1, "batch_size must be 1"
|
67 |
+
dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
|
68 |
+
data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
|
69 |
+
return data_loader
|
70 |
+
|
71 |
+
|
72 |
+
def eval_model(args):
|
73 |
+
# Model
|
74 |
+
disable_torch_init()
|
75 |
+
model_path = os.path.expanduser(args.model_path)
|
76 |
+
model_name = get_model_name_from_path(model_path)
|
77 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
|
78 |
+
|
79 |
+
questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
|
80 |
+
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
|
81 |
+
answers_file = os.path.expanduser(args.answers_file)
|
82 |
+
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
83 |
+
ans_file = open(answers_file, "w")
|
84 |
+
|
85 |
+
if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
|
86 |
+
args.conv_mode = args.conv_mode + '_mmtag'
|
87 |
+
print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
|
88 |
+
|
89 |
+
data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
|
90 |
+
|
91 |
+
for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
|
92 |
+
idx = line["question_id"]
|
93 |
+
cur_prompt = line["text"]
|
94 |
+
|
95 |
+
input_ids = input_ids.to(device='cuda', non_blocking=True)
|
96 |
+
|
97 |
+
with torch.inference_mode():
|
98 |
+
output_ids = model.generate(
|
99 |
+
input_ids,
|
100 |
+
images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
|
101 |
+
do_sample=True if args.temperature > 0 else False,
|
102 |
+
temperature=args.temperature,
|
103 |
+
top_p=args.top_p,
|
104 |
+
num_beams=args.num_beams,
|
105 |
+
max_new_tokens=args.max_new_tokens,
|
106 |
+
use_cache=True)
|
107 |
+
|
108 |
+
input_token_len = input_ids.shape[1]
|
109 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
110 |
+
if n_diff_input_output > 0:
|
111 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
112 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
113 |
+
outputs = outputs.strip()
|
114 |
+
|
115 |
+
ans_id = shortuuid.uuid()
|
116 |
+
ans_file.write(json.dumps({"question_id": idx,
|
117 |
+
"prompt": cur_prompt,
|
118 |
+
"text": outputs,
|
119 |
+
"answer_id": ans_id,
|
120 |
+
"model_id": model_name,
|
121 |
+
"metadata": {}}) + "\n")
|
122 |
+
# ans_file.flush()
|
123 |
+
ans_file.close()
|
124 |
+
|
125 |
+
if __name__ == "__main__":
|
126 |
+
parser = argparse.ArgumentParser()
|
127 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
128 |
+
parser.add_argument("--model-base", type=str, default=None)
|
129 |
+
parser.add_argument("--image-folder", type=str, default="")
|
130 |
+
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
|
131 |
+
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
132 |
+
parser.add_argument("--conv-mode", type=str, default="llava_v1")
|
133 |
+
parser.add_argument("--num-chunks", type=int, default=1)
|
134 |
+
parser.add_argument("--chunk-idx", type=int, default=0)
|
135 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
136 |
+
parser.add_argument("--top_p", type=float, default=None)
|
137 |
+
parser.add_argument("--num_beams", type=int, default=1)
|
138 |
+
parser.add_argument("--max_new_tokens", type=int, default=128)
|
139 |
+
args = parser.parse_args()
|
140 |
+
|
141 |
+
eval_model(args)
|
LLAVA_Biovil/llava/eval/model_vqa_mmbench.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import pandas as pd
|
6 |
+
from tqdm import tqdm
|
7 |
+
import shortuuid
|
8 |
+
|
9 |
+
from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
10 |
+
from LLAV.llava.conversation import conv_templates, SeparatorStyle
|
11 |
+
from LLAV.llava.model.builder import load_pretrained_model
|
12 |
+
from LLAV.llava.utils import disable_torch_init
|
13 |
+
from LLAV.llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path
|
14 |
+
|
15 |
+
import math
|
16 |
+
|
17 |
+
|
18 |
+
all_options = ['A', 'B', 'C', 'D']
|
19 |
+
|
20 |
+
|
21 |
+
def split_list(lst, n):
|
22 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
23 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
24 |
+
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
25 |
+
|
26 |
+
|
27 |
+
def get_chunk(lst, n, k):
|
28 |
+
chunks = split_list(lst, n)
|
29 |
+
return chunks[k]
|
30 |
+
|
31 |
+
|
32 |
+
def is_none(value):
|
33 |
+
if value is None:
|
34 |
+
return True
|
35 |
+
if type(value) is float and math.isnan(value):
|
36 |
+
return True
|
37 |
+
if type(value) is str and value.lower() == 'nan':
|
38 |
+
return True
|
39 |
+
if type(value) is str and value.lower() == 'none':
|
40 |
+
return True
|
41 |
+
return False
|
42 |
+
|
43 |
+
def get_options(row, options):
|
44 |
+
parsed_options = []
|
45 |
+
for option in options:
|
46 |
+
option_value = row[option]
|
47 |
+
if is_none(option_value):
|
48 |
+
break
|
49 |
+
parsed_options.append(option_value)
|
50 |
+
return parsed_options
|
51 |
+
|
52 |
+
|
53 |
+
def eval_model(args):
|
54 |
+
# Model
|
55 |
+
disable_torch_init()
|
56 |
+
model_path = os.path.expanduser(args.model_path)
|
57 |
+
model_name = get_model_name_from_path(model_path)
|
58 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
|
59 |
+
|
60 |
+
questions = pd.read_table(os.path.expanduser(args.question_file))
|
61 |
+
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
|
62 |
+
answers_file = os.path.expanduser(args.answers_file)
|
63 |
+
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
64 |
+
ans_file = open(answers_file, "w")
|
65 |
+
|
66 |
+
if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
|
67 |
+
args.conv_mode = args.conv_mode + '_mmtag'
|
68 |
+
print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
|
69 |
+
|
70 |
+
for index, row in tqdm(questions.iterrows(), total=len(questions)):
|
71 |
+
options = get_options(row, all_options)
|
72 |
+
cur_option_char = all_options[:len(options)]
|
73 |
+
|
74 |
+
if args.all_rounds:
|
75 |
+
num_rounds = len(options)
|
76 |
+
else:
|
77 |
+
num_rounds = 1
|
78 |
+
|
79 |
+
for round_idx in range(num_rounds):
|
80 |
+
idx = row['index']
|
81 |
+
question = row['question']
|
82 |
+
hint = row['hint']
|
83 |
+
image = load_image_from_base64(row['image'])
|
84 |
+
if not is_none(hint):
|
85 |
+
question = hint + '\n' + question
|
86 |
+
for option_char, option in zip(all_options[:len(options)], options):
|
87 |
+
question = question + '\n' + option_char + '. ' + option
|
88 |
+
qs = cur_prompt = question
|
89 |
+
if model.config.mm_use_im_start_end:
|
90 |
+
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
|
91 |
+
else:
|
92 |
+
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
93 |
+
|
94 |
+
if args.single_pred_prompt:
|
95 |
+
if args.lang == 'cn':
|
96 |
+
qs = qs + '\n' + "请直接回答选项字母。"
|
97 |
+
else:
|
98 |
+
qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
|
99 |
+
|
100 |
+
conv = conv_templates[args.conv_mode].copy()
|
101 |
+
conv.append_message(conv.roles[0], qs)
|
102 |
+
conv.append_message(conv.roles[1], None)
|
103 |
+
prompt = conv.get_prompt()
|
104 |
+
|
105 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
106 |
+
|
107 |
+
image_tensor = process_images([image], image_processor, model.config)[0]
|
108 |
+
# image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
109 |
+
|
110 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
111 |
+
|
112 |
+
with torch.inference_mode():
|
113 |
+
output_ids = model.generate(
|
114 |
+
input_ids,
|
115 |
+
images=image_tensor.unsqueeze(0).half().cuda(),
|
116 |
+
do_sample=True if args.temperature > 0 else False,
|
117 |
+
temperature=args.temperature,
|
118 |
+
top_p=args.top_p,
|
119 |
+
num_beams=args.num_beams,
|
120 |
+
# no_repeat_ngram_size=3,
|
121 |
+
max_new_tokens=1024,
|
122 |
+
use_cache=True)
|
123 |
+
|
124 |
+
input_token_len = input_ids.shape[1]
|
125 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
126 |
+
if n_diff_input_output > 0:
|
127 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
128 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
129 |
+
outputs = outputs.strip()
|
130 |
+
if outputs.endswith(stop_str):
|
131 |
+
outputs = outputs[:-len(stop_str)]
|
132 |
+
outputs = outputs.strip()
|
133 |
+
|
134 |
+
ans_id = shortuuid.uuid()
|
135 |
+
ans_file.write(json.dumps({"question_id": idx,
|
136 |
+
"round_id": round_idx,
|
137 |
+
"prompt": cur_prompt,
|
138 |
+
"text": outputs,
|
139 |
+
"options": options,
|
140 |
+
"option_char": cur_option_char,
|
141 |
+
"answer_id": ans_id,
|
142 |
+
"model_id": model_name,
|
143 |
+
"metadata": {}}) + "\n")
|
144 |
+
ans_file.flush()
|
145 |
+
|
146 |
+
# rotate options
|
147 |
+
options = options[1:] + options[:1]
|
148 |
+
cur_option_char = cur_option_char[1:] + cur_option_char[:1]
|
149 |
+
ans_file.close()
|
150 |
+
|
151 |
+
if __name__ == "__main__":
|
152 |
+
parser = argparse.ArgumentParser()
|
153 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
154 |
+
parser.add_argument("--model-base", type=str, default=None)
|
155 |
+
parser.add_argument("--image-folder", type=str, default="")
|
156 |
+
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
|
157 |
+
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
158 |
+
parser.add_argument("--conv-mode", type=str, default="llava_v1")
|
159 |
+
parser.add_argument("--num-chunks", type=int, default=1)
|
160 |
+
parser.add_argument("--chunk-idx", type=int, default=0)
|
161 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
162 |
+
parser.add_argument("--top_p", type=float, default=None)
|
163 |
+
parser.add_argument("--num_beams", type=int, default=1)
|
164 |
+
parser.add_argument("--all-rounds", action="store_true")
|
165 |
+
parser.add_argument("--single-pred-prompt", action="store_true")
|
166 |
+
parser.add_argument("--lang", type=str, default="en")
|
167 |
+
args = parser.parse_args()
|
168 |
+
|
169 |
+
eval_model(args)
|
LLAVA_Biovil/llava/eval/model_vqa_qbench.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
import json
|
5 |
+
|
6 |
+
from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
7 |
+
from LLAV.llava.conversation import conv_templates, SeparatorStyle
|
8 |
+
from LLAV.llava.model.builder import load_pretrained_model
|
9 |
+
from LLAV.llava.utils import disable_torch_init
|
10 |
+
from LLAV.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
11 |
+
|
12 |
+
import requests
|
13 |
+
from PIL import Image
|
14 |
+
from io import BytesIO
|
15 |
+
|
16 |
+
|
17 |
+
def load_image(image_file):
|
18 |
+
if image_file.startswith('http') or image_file.startswith('https'):
|
19 |
+
response = requests.get(image_file)
|
20 |
+
image = Image.open(BytesIO(response.content)).convert('RGB')
|
21 |
+
else:
|
22 |
+
image = Image.open(image_file).convert('RGB')
|
23 |
+
return image
|
24 |
+
|
25 |
+
|
26 |
+
def eval_model(args):
|
27 |
+
# Model
|
28 |
+
disable_torch_init()
|
29 |
+
|
30 |
+
model_name = get_model_name_from_path(args.model_path)
|
31 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, True)
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
with open(args.questions_file) as f:
|
37 |
+
llvqa_data = json.load(f)
|
38 |
+
|
39 |
+
for i, llddata in enumerate(tqdm(llvqa_data)):
|
40 |
+
filename = llddata["img_path"]
|
41 |
+
if args.lang == "en":
|
42 |
+
message = llddata["question"] + "\nChoose between one of the options as follows:\n"
|
43 |
+
elif args.lang == "zh":
|
44 |
+
message = llddata["question"] + "\在下列选项中选择一个:\n"
|
45 |
+
else:
|
46 |
+
raise NotImplementedError("Q-Bench does not support languages other than English (en) and Chinese (zh) yet. Contact us (https://github.com/VQAssessment/Q-Bench/) to convert Q-Bench into more languages.")
|
47 |
+
for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
|
48 |
+
message += f"{choice} {ans}\n"
|
49 |
+
qs = message
|
50 |
+
|
51 |
+
if model.config.mm_use_im_start_end:
|
52 |
+
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
|
53 |
+
else:
|
54 |
+
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
55 |
+
|
56 |
+
if 'llama-2' in model_name.lower():
|
57 |
+
conv_mode = "llava_llama_2"
|
58 |
+
elif "v1" in model_name.lower():
|
59 |
+
conv_mode = "llava_v1"
|
60 |
+
elif "mpt" in model_name.lower():
|
61 |
+
conv_mode = "mpt"
|
62 |
+
else:
|
63 |
+
conv_mode = "llava_v0"
|
64 |
+
|
65 |
+
if args.conv_mode is not None and conv_mode != args.conv_mode:
|
66 |
+
print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
|
67 |
+
else:
|
68 |
+
args.conv_mode = conv_mode
|
69 |
+
|
70 |
+
conv = conv_templates[args.conv_mode].copy()
|
71 |
+
conv.append_message(conv.roles[0], qs)
|
72 |
+
conv.append_message(conv.roles[1], None)
|
73 |
+
prompt = conv.get_prompt()
|
74 |
+
|
75 |
+
image = load_image(args.image_folder + filename)
|
76 |
+
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
|
77 |
+
|
78 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
79 |
+
|
80 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
81 |
+
keywords = [stop_str]
|
82 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
83 |
+
|
84 |
+
|
85 |
+
with torch.inference_mode():
|
86 |
+
output_ids = model.generate(
|
87 |
+
input_ids,
|
88 |
+
images=image_tensor,
|
89 |
+
num_beams=1,
|
90 |
+
do_sample=False,
|
91 |
+
temperature=0,
|
92 |
+
max_new_tokens=1024,
|
93 |
+
use_cache=True,
|
94 |
+
stopping_criteria=[stopping_criteria])
|
95 |
+
|
96 |
+
input_token_len = input_ids.shape[1]
|
97 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
98 |
+
if n_diff_input_output > 0:
|
99 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
100 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
101 |
+
outputs = outputs.strip()
|
102 |
+
if outputs.endswith(stop_str):
|
103 |
+
outputs = outputs[:-len(stop_str)]
|
104 |
+
outputs = outputs.strip()
|
105 |
+
llddata["response"] = outputs
|
106 |
+
with open(args.answers_file, "a") as wf:
|
107 |
+
json.dump(llddata, wf)
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
parser = argparse.ArgumentParser()
|
111 |
+
parser.add_argument("--model-path", type=str, default="llava-v1.5")
|
112 |
+
parser.add_argument("--model-base", type=str, default=None)
|
113 |
+
parser.add_argument("--image-folder", type=str, default="./playground/data/qbench/images_llvisionqa")
|
114 |
+
parser.add_argument("--questions-file", type=str, default="./playground/data/qbench/llvisionqa_dev.json")
|
115 |
+
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
116 |
+
parser.add_argument("--conv-mode", type=str, default="llava_v1")
|
117 |
+
parser.add_argument("--lang", type=str, default="en")
|
118 |
+
args = parser.parse_args()
|
119 |
+
|
120 |
+
eval_model(args)
|
LLAVA_Biovil/llava/eval/model_vqa_science.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
from tqdm import tqdm
|
6 |
+
import shortuuid
|
7 |
+
|
8 |
+
from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
9 |
+
from LLAV.llava.conversation import conv_templates, SeparatorStyle
|
10 |
+
from LLAV.llava.model.builder import load_pretrained_model
|
11 |
+
from LLAV.llava.utils import disable_torch_init
|
12 |
+
from LLAV.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
13 |
+
|
14 |
+
from PIL import Image
|
15 |
+
import math
|
16 |
+
|
17 |
+
|
18 |
+
def split_list(lst, n):
|
19 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
20 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
21 |
+
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
22 |
+
|
23 |
+
|
24 |
+
def get_chunk(lst, n, k):
|
25 |
+
chunks = split_list(lst, n)
|
26 |
+
return chunks[k]
|
27 |
+
|
28 |
+
|
29 |
+
def eval_model(args):
|
30 |
+
# Model
|
31 |
+
disable_torch_init()
|
32 |
+
model_path = os.path.expanduser(args.model_path)
|
33 |
+
model_name = get_model_name_from_path(model_path)
|
34 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
|
35 |
+
|
36 |
+
questions = json.load(open(os.path.expanduser(args.question_file), "r"))
|
37 |
+
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
|
38 |
+
answers_file = os.path.expanduser(args.answers_file)
|
39 |
+
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
40 |
+
ans_file = open(answers_file, "w")
|
41 |
+
for i, line in enumerate(tqdm(questions)):
|
42 |
+
idx = line["id"]
|
43 |
+
question = line['conversations'][0]
|
44 |
+
qs = question['value'].replace('<image>', '').strip()
|
45 |
+
cur_prompt = qs
|
46 |
+
|
47 |
+
if 'image' in line:
|
48 |
+
image_file = line["image"]
|
49 |
+
image = Image.open(os.path.join(args.image_folder, image_file))
|
50 |
+
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
51 |
+
images = image_tensor.unsqueeze(0).half().cuda()
|
52 |
+
if getattr(model.config, 'mm_use_im_start_end', False):
|
53 |
+
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
|
54 |
+
else:
|
55 |
+
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
56 |
+
cur_prompt = '<image>' + '\n' + cur_prompt
|
57 |
+
else:
|
58 |
+
images = None
|
59 |
+
|
60 |
+
if args.single_pred_prompt:
|
61 |
+
qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
|
62 |
+
cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
|
63 |
+
|
64 |
+
conv = conv_templates[args.conv_mode].copy()
|
65 |
+
conv.append_message(conv.roles[0], qs)
|
66 |
+
conv.append_message(conv.roles[1], None)
|
67 |
+
prompt = conv.get_prompt()
|
68 |
+
|
69 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
70 |
+
|
71 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
72 |
+
keywords = [stop_str]
|
73 |
+
stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] if conv.version == "v0" else None
|
74 |
+
|
75 |
+
with torch.inference_mode():
|
76 |
+
output_ids = model.generate(
|
77 |
+
input_ids,
|
78 |
+
images=images,
|
79 |
+
do_sample=True if args.temperature > 0 else False,
|
80 |
+
temperature=args.temperature,
|
81 |
+
max_new_tokens=1024,
|
82 |
+
use_cache=True,
|
83 |
+
stopping_criteria=stopping_criteria,
|
84 |
+
)
|
85 |
+
|
86 |
+
input_token_len = input_ids.shape[1]
|
87 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
88 |
+
if n_diff_input_output > 0:
|
89 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
90 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
91 |
+
outputs = outputs.strip()
|
92 |
+
if outputs.endswith(stop_str):
|
93 |
+
outputs = outputs[:-len(stop_str)]
|
94 |
+
outputs = outputs.strip()
|
95 |
+
|
96 |
+
# prompt for answer
|
97 |
+
if args.answer_prompter:
|
98 |
+
outputs_reasoning = outputs
|
99 |
+
input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
100 |
+
|
101 |
+
with torch.inference_mode():
|
102 |
+
output_ids = model.generate(
|
103 |
+
input_ids,
|
104 |
+
images=images,
|
105 |
+
do_sample=True if args.temperature > 0 else False,
|
106 |
+
temperature=args.temperature,
|
107 |
+
max_new_tokens=64,
|
108 |
+
use_cache=True,
|
109 |
+
stopping_criteria=[stopping_criteria])
|
110 |
+
|
111 |
+
input_token_len = input_ids.shape[1]
|
112 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
113 |
+
if n_diff_input_output > 0:
|
114 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
115 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
116 |
+
outputs = outputs.strip()
|
117 |
+
if outputs.endswith(stop_str):
|
118 |
+
outputs = outputs[:-len(stop_str)]
|
119 |
+
outputs = outputs.strip()
|
120 |
+
outputs = outputs_reasoning + '\n The answer is ' + outputs
|
121 |
+
|
122 |
+
ans_id = shortuuid.uuid()
|
123 |
+
ans_file.write(json.dumps({"question_id": idx,
|
124 |
+
"prompt": cur_prompt,
|
125 |
+
"text": outputs,
|
126 |
+
"answer_id": ans_id,
|
127 |
+
"model_id": model_name,
|
128 |
+
"metadata": {}}) + "\n")
|
129 |
+
ans_file.flush()
|
130 |
+
ans_file.close()
|
131 |
+
|
132 |
+
if __name__ == "__main__":
|
133 |
+
parser = argparse.ArgumentParser()
|
134 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
135 |
+
parser.add_argument("--model-base", type=str, default=None)
|
136 |
+
parser.add_argument("--image-folder", type=str, default="")
|
137 |
+
parser.add_argument("--question-file", type=str, default="tables/question.json")
|
138 |
+
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
139 |
+
parser.add_argument("--conv-mode", type=str, default="llava_v0")
|
140 |
+
parser.add_argument("--num-chunks", type=int, default=1)
|
141 |
+
parser.add_argument("--chunk-idx", type=int, default=0)
|
142 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
143 |
+
parser.add_argument("--answer-prompter", action="store_true")
|
144 |
+
parser.add_argument("--single-pred-prompt", action="store_true")
|
145 |
+
args = parser.parse_args()
|
146 |
+
|
147 |
+
eval_model(args)
|
LLAVA_Biovil/llava/eval/qa_baseline_gpt35.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Generate answers with GPT-3.5"""
|
2 |
+
# Note: you need to be using OpenAI Python v0.27.0 for the code below to work
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
import concurrent.futures
|
8 |
+
|
9 |
+
import openai
|
10 |
+
import tqdm
|
11 |
+
import shortuuid
|
12 |
+
|
13 |
+
MODEL = 'gpt-3.5-turbo'
|
14 |
+
MODEL_ID = 'gpt-3.5-turbo:20230327'
|
15 |
+
|
16 |
+
def get_answer(question_id: int, question: str, max_tokens: int):
|
17 |
+
ans = {
|
18 |
+
'answer_id': shortuuid.uuid(),
|
19 |
+
'question_id': question_id,
|
20 |
+
'model_id': MODEL_ID,
|
21 |
+
}
|
22 |
+
for _ in range(3):
|
23 |
+
try:
|
24 |
+
response = openai.ChatCompletion.create(
|
25 |
+
model=MODEL,
|
26 |
+
messages=[{
|
27 |
+
'role': 'system',
|
28 |
+
'content': 'You are a helpful assistant.'
|
29 |
+
}, {
|
30 |
+
'role': 'user',
|
31 |
+
'content': question,
|
32 |
+
}],
|
33 |
+
max_tokens=max_tokens,
|
34 |
+
)
|
35 |
+
ans['text'] = response['choices'][0]['message']['content']
|
36 |
+
return ans
|
37 |
+
except Exception as e:
|
38 |
+
print('[ERROR]', e)
|
39 |
+
ans['text'] = '#ERROR#'
|
40 |
+
time.sleep(1)
|
41 |
+
return ans
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == '__main__':
|
45 |
+
parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
|
46 |
+
parser.add_argument('-q', '--question')
|
47 |
+
parser.add_argument('-o', '--output')
|
48 |
+
parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
|
49 |
+
args = parser.parse_args()
|
50 |
+
|
51 |
+
questions_dict = {}
|
52 |
+
with open(os.path.expanduser(args.question)) as f:
|
53 |
+
for line in f:
|
54 |
+
if not line:
|
55 |
+
continue
|
56 |
+
q = json.loads(line)
|
57 |
+
questions_dict[q['question_id']] = q['text']
|
58 |
+
|
59 |
+
answers = []
|
60 |
+
|
61 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
|
62 |
+
futures = []
|
63 |
+
for qid, question in questions_dict.items():
|
64 |
+
future = executor.submit(get_answer, qid, question, args.max_tokens)
|
65 |
+
futures.append(future)
|
66 |
+
|
67 |
+
for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
|
68 |
+
answers.append(future.result())
|
69 |
+
|
70 |
+
answers.sort(key=lambda x: x['question_id'])
|
71 |
+
|
72 |
+
with open(os.path.expanduser(args.output), 'w') as f:
|
73 |
+
table = [json.dumps(ans) for ans in answers]
|
74 |
+
f.write('\n'.join(table))
|
LLAVA_Biovil/llava/eval/run_llava.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from LLAV.llava.constants import (
|
5 |
+
IMAGE_TOKEN_INDEX,
|
6 |
+
DEFAULT_IMAGE_TOKEN,
|
7 |
+
DEFAULT_IM_START_TOKEN,
|
8 |
+
DEFAULT_IM_END_TOKEN,
|
9 |
+
IMAGE_PLACEHOLDER,
|
10 |
+
)
|
11 |
+
from LLAV.llava.conversation import conv_templates, SeparatorStyle
|
12 |
+
from LLAV.llava.model.builder import load_pretrained_model
|
13 |
+
from LLAV.llava.utils import disable_torch_init
|
14 |
+
from LLAV.llava.mm_utils import (
|
15 |
+
process_images,
|
16 |
+
tokenizer_image_token,
|
17 |
+
get_model_name_from_path,
|
18 |
+
KeywordsStoppingCriteria,
|
19 |
+
)
|
20 |
+
|
21 |
+
import requests
|
22 |
+
from PIL import Image
|
23 |
+
from io import BytesIO
|
24 |
+
import re
|
25 |
+
|
26 |
+
|
27 |
+
def image_parser(args):
|
28 |
+
out = args.image_file.split(args.sep)
|
29 |
+
return out
|
30 |
+
|
31 |
+
|
32 |
+
def load_image(image_file):
|
33 |
+
if image_file.startswith("http") or image_file.startswith("https"):
|
34 |
+
response = requests.get(image_file)
|
35 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
36 |
+
else:
|
37 |
+
image = Image.open(image_file).convert("RGB")
|
38 |
+
return image
|
39 |
+
|
40 |
+
|
41 |
+
def load_images(image_files):
|
42 |
+
out = []
|
43 |
+
for image_file in image_files:
|
44 |
+
image = load_image(image_file)
|
45 |
+
out.append(image)
|
46 |
+
return out
|
47 |
+
|
48 |
+
|
49 |
+
def eval_model(args):
|
50 |
+
# Model
|
51 |
+
disable_torch_init()
|
52 |
+
|
53 |
+
model_name = get_model_name_from_path(args.model_path)
|
54 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
55 |
+
args.model_path, args.model_base, model_name
|
56 |
+
)
|
57 |
+
|
58 |
+
qs = args.query
|
59 |
+
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
|
60 |
+
if IMAGE_PLACEHOLDER in qs:
|
61 |
+
if model.config.mm_use_im_start_end:
|
62 |
+
qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
|
63 |
+
else:
|
64 |
+
qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
|
65 |
+
else:
|
66 |
+
if model.config.mm_use_im_start_end:
|
67 |
+
qs = image_token_se + "\n" + qs
|
68 |
+
else:
|
69 |
+
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
|
70 |
+
|
71 |
+
if "llama-2" in model_name.lower():
|
72 |
+
conv_mode = "llava_llama_2"
|
73 |
+
elif "v1" in model_name.lower():
|
74 |
+
conv_mode = "llava_v1"
|
75 |
+
elif "mpt" in model_name.lower():
|
76 |
+
conv_mode = "mpt"
|
77 |
+
else:
|
78 |
+
conv_mode = "llava_v0"
|
79 |
+
|
80 |
+
if args.conv_mode is not None and conv_mode != args.conv_mode:
|
81 |
+
print(
|
82 |
+
"[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
|
83 |
+
conv_mode, args.conv_mode, args.conv_mode
|
84 |
+
)
|
85 |
+
)
|
86 |
+
else:
|
87 |
+
args.conv_mode = conv_mode
|
88 |
+
|
89 |
+
conv = conv_templates[args.conv_mode].copy()
|
90 |
+
conv.append_message(conv.roles[0], qs)
|
91 |
+
conv.append_message(conv.roles[1], None)
|
92 |
+
prompt = conv.get_prompt()
|
93 |
+
|
94 |
+
image_files = image_parser(args)
|
95 |
+
images = load_images(image_files)
|
96 |
+
images_tensor = process_images(
|
97 |
+
images,
|
98 |
+
image_processor,
|
99 |
+
model.config
|
100 |
+
).to(model.device, dtype=torch.float16)
|
101 |
+
|
102 |
+
input_ids = (
|
103 |
+
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
104 |
+
.unsqueeze(0)
|
105 |
+
.cuda()
|
106 |
+
)
|
107 |
+
|
108 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
109 |
+
keywords = [stop_str]
|
110 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
111 |
+
|
112 |
+
with torch.inference_mode():
|
113 |
+
output_ids = model.generate(
|
114 |
+
input_ids,
|
115 |
+
images=images_tensor,
|
116 |
+
do_sample=True if args.temperature > 0 else False,
|
117 |
+
temperature=args.temperature,
|
118 |
+
top_p=args.top_p,
|
119 |
+
num_beams=args.num_beams,
|
120 |
+
max_new_tokens=args.max_new_tokens,
|
121 |
+
use_cache=True,
|
122 |
+
stopping_criteria=[stopping_criteria],
|
123 |
+
)
|
124 |
+
|
125 |
+
input_token_len = input_ids.shape[1]
|
126 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
127 |
+
if n_diff_input_output > 0:
|
128 |
+
print(
|
129 |
+
f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
|
130 |
+
)
|
131 |
+
outputs = tokenizer.batch_decode(
|
132 |
+
output_ids[:, input_token_len:], skip_special_tokens=True
|
133 |
+
)[0]
|
134 |
+
outputs = outputs.strip()
|
135 |
+
if outputs.endswith(stop_str):
|
136 |
+
outputs = outputs[: -len(stop_str)]
|
137 |
+
outputs = outputs.strip()
|
138 |
+
print(outputs)
|
139 |
+
|
140 |
+
|
141 |
+
if __name__ == "__main__":
|
142 |
+
parser = argparse.ArgumentParser()
|
143 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
144 |
+
parser.add_argument("--model-base", type=str, default=None)
|
145 |
+
parser.add_argument("--image-file", type=str, required=True)
|
146 |
+
parser.add_argument("--query", type=str, required=True)
|
147 |
+
parser.add_argument("--conv-mode", type=str, default=None)
|
148 |
+
parser.add_argument("--sep", type=str, default=",")
|
149 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
150 |
+
parser.add_argument("--top_p", type=float, default=None)
|
151 |
+
parser.add_argument("--num_beams", type=int, default=1)
|
152 |
+
parser.add_argument("--max_new_tokens", type=int, default=512)
|
153 |
+
args = parser.parse_args()
|
154 |
+
|
155 |
+
eval_model(args)
|
LLAVA_Biovil/llava/eval/summarize_gpt_review.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from collections import defaultdict
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
def parse_args():
|
10 |
+
parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
|
11 |
+
parser.add_argument('-d', '--dir', default=None)
|
12 |
+
parser.add_argument('-v', '--version', default=None)
|
13 |
+
parser.add_argument('-s', '--select', nargs='*', default=None)
|
14 |
+
parser.add_argument('-f', '--files', nargs='*', default=[])
|
15 |
+
parser.add_argument('-i', '--ignore', nargs='*', default=[])
|
16 |
+
return parser.parse_args()
|
17 |
+
|
18 |
+
|
19 |
+
if __name__ == '__main__':
|
20 |
+
args = parse_args()
|
21 |
+
|
22 |
+
if args.ignore is not None:
|
23 |
+
args.ignore = [int(x) for x in args.ignore]
|
24 |
+
|
25 |
+
if len(args.files) > 0:
|
26 |
+
review_files = args.files
|
27 |
+
else:
|
28 |
+
review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
|
29 |
+
|
30 |
+
for review_file in sorted(review_files):
|
31 |
+
config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
|
32 |
+
if args.select is not None and any(x not in config for x in args.select):
|
33 |
+
continue
|
34 |
+
if '0613' in config:
|
35 |
+
version = '0613'
|
36 |
+
else:
|
37 |
+
version = '0314'
|
38 |
+
if args.version is not None and args.version != version:
|
39 |
+
continue
|
40 |
+
scores = defaultdict(list)
|
41 |
+
print(config)
|
42 |
+
with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
|
43 |
+
for review_str in f:
|
44 |
+
review = json.loads(review_str)
|
45 |
+
if review['question_id'] in args.ignore:
|
46 |
+
continue
|
47 |
+
if 'category' in review:
|
48 |
+
scores[review['category']].append(review['tuple'])
|
49 |
+
scores['all'].append(review['tuple'])
|
50 |
+
else:
|
51 |
+
if 'tuple' in review:
|
52 |
+
scores['all'].append(review['tuple'])
|
53 |
+
else:
|
54 |
+
scores['all'].append(review['score'])
|
55 |
+
for k, v in sorted(scores.items()):
|
56 |
+
stats = np.asarray(v).mean(0).tolist()
|
57 |
+
stats = [round(x, 3) for x in stats]
|
58 |
+
# print(k, stats, round(stats[1]/stats[0]*100, 1))
|
59 |
+
print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
|
60 |
+
print('=================================')
|
LLAVA_Biovil/llava/eval/webpage/figures/alpaca.png
ADDED
LLAVA_Biovil/llava/eval/webpage/figures/bard.jpg
ADDED
LLAVA_Biovil/llava/eval/webpage/figures/chatgpt.svg
ADDED
LLAVA_Biovil/llava/eval/webpage/figures/llama.jpg
ADDED
LLAVA_Biovil/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg
ADDED
LLAVA_Biovil/llava/eval/webpage/figures/vicuna.jpeg
ADDED
LLAVA_Biovil/llava/eval/webpage/index.html
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<title>Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots</title>
|
7 |
+
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css">
|
8 |
+
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
|
9 |
+
<link rel="stylesheet" href="styles.css">
|
10 |
+
</head>
|
11 |
+
|
12 |
+
<body>
|
13 |
+
<nav class="navbar navbar-expand-lg navbar-dark bg-dark">
|
14 |
+
<a class="navbar-brand" href="#">🏔️ Vicuna Evaluation Examples</a>
|
15 |
+
<button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbarNav" aria-controls="navbarNav" aria-expanded="false" aria-label="Toggle navigation">
|
16 |
+
<span class="navbar-toggler-icon"></span>
|
17 |
+
</button>
|
18 |
+
<div class="collapse navbar-collapse" id="navbarNav">
|
19 |
+
<ul class="navbar-nav mr-auto">
|
20 |
+
<li class="nav-item">
|
21 |
+
<a class="nav-link" href="https://chat.lmsys.org/">Demo</a>
|
22 |
+
</li>
|
23 |
+
<li class="nav-item">
|
24 |
+
<a class="nav-link" href="https://vicuna.lmsys.org">Blog</a>
|
25 |
+
</li>
|
26 |
+
<li class="nav-item">
|
27 |
+
<a class="nav-link" href="https://github.com/lm-sys/FastChat">Github</a>
|
28 |
+
</li>
|
29 |
+
</ul>
|
30 |
+
</div>
|
31 |
+
</nav>
|
32 |
+
|
33 |
+
<div class="container mt-5">
|
34 |
+
<h2 class="text-center mb-5">Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots</h2>
|
35 |
+
|
36 |
+
<!-- Selection -->
|
37 |
+
<div class="form-row">
|
38 |
+
<div class="form-group col-md-2">
|
39 |
+
<label for="category-select">Category</label>
|
40 |
+
<select class="form-control" id="category-select"></select>
|
41 |
+
</div>
|
42 |
+
<div class="form-group col-md-8">
|
43 |
+
<label for="question-select">Question</label>
|
44 |
+
<select class="form-control" id="question-select"></select>
|
45 |
+
</div>
|
46 |
+
<div class="form-group col-md-2">
|
47 |
+
<div class="col-md-2"><label> </label></div>
|
48 |
+
<div class="btn-group" role="group" aria-label="Left and Right Controller">
|
49 |
+
<button type="button" class="form-control btn btn-primary" id="prev-question"><i class="material-icons">keyboard_arrow_left</i></button>
|
50 |
+
<button type="button" class="form-control btn btn-primary" id="next-question"><i class="material-icons">keyboard_arrow_right</i></button>
|
51 |
+
</div>
|
52 |
+
</div>
|
53 |
+
</div>
|
54 |
+
|
55 |
+
<!-- "Battle" -->
|
56 |
+
<div class="row mb-4" style="justify-content: center;">
|
57 |
+
<div class="col" style="display: flex; justify-content: center; align-items: center;">
|
58 |
+
<label class="adjustable-font-size" id="other-score-label">*/10</label>
|
59 |
+
</div>
|
60 |
+
<div class="col">
|
61 |
+
<div class="vertical-flex-layout">
|
62 |
+
<img class="shadow figure-img img-fluid" src="" alt="other logo" width="150" id="other-model-figure">
|
63 |
+
</div>
|
64 |
+
</div>
|
65 |
+
<div class="col">
|
66 |
+
<div class="vertical-flex-layout">
|
67 |
+
<!-- from: https://fonts.google.com/icons?icon.query=battle&selected=Material+Symbols+Outlined:swords:FILL@0;wght@300;GRAD@0;opsz@48&icon.style=Outlined -->
|
68 |
+
<img class="figure-img img-fluid" src="figures/swords_FILL0_wght300_GRAD0_opsz48.svg" width="60" height="60">
|
69 |
+
</div>
|
70 |
+
</div>
|
71 |
+
<div class="col">
|
72 |
+
<div class="vertical-flex-layout">
|
73 |
+
<img class="shadow figure-img img-fluid" src="figures/vicuna.jpeg" alt="vicuna logo" width="150" id="our-model-figure">
|
74 |
+
</div>
|
75 |
+
</div>
|
76 |
+
<div class="col" style="display: flex; justify-content: center; align-items: center;">
|
77 |
+
<label class="adjustable-font-size" id="our-score-label">*/10</label>
|
78 |
+
</div>
|
79 |
+
</div>
|
80 |
+
|
81 |
+
<!-- Question Card -->
|
82 |
+
<div class="card mb-4">
|
83 |
+
<div class="card-body" id="selected-question"></div>
|
84 |
+
</div>
|
85 |
+
|
86 |
+
<!-- Answer Cards -->
|
87 |
+
<div class="row">
|
88 |
+
<div class="col-md-6">
|
89 |
+
<div class="card mb-4 expandable-card">
|
90 |
+
<div class="card-header" style="padding-bottom: 0.2rem" id="other-model-header-bg">
|
91 |
+
<div class="row">
|
92 |
+
<div class="col-md-5" style="align-items: center; display: flex;">
|
93 |
+
<label id="other-model-header">Assistant #1</label>
|
94 |
+
</div>
|
95 |
+
<div class="col-md-7">
|
96 |
+
<select class="form-control" id="model-select" style="height: fit-content; margin-top: -0.3rem;"></select>
|
97 |
+
</div>
|
98 |
+
</div>
|
99 |
+
</div>
|
100 |
+
<div class="card-body">
|
101 |
+
<div class="card-text-container">
|
102 |
+
<div class="card-text" id="other-model-answer"></div>
|
103 |
+
</div>
|
104 |
+
<div class="btn btn-primary expand-btn" style="display:flex;"></div>
|
105 |
+
</div>
|
106 |
+
</div>
|
107 |
+
</div>
|
108 |
+
<div class="col-md-6">
|
109 |
+
<div class="card mb-4 expandable-card">
|
110 |
+
<div class="card-header" id="our-model-header">
|
111 |
+
Assistant #2 (Vicuna, our model)
|
112 |
+
</div>
|
113 |
+
<div class="card-body">
|
114 |
+
<div class="card-text-container">
|
115 |
+
<div class="card-text" id="our-model-answer"></div>
|
116 |
+
</div>
|
117 |
+
<div class="btn btn-primary expand-btn" style="display:flex;"></div>
|
118 |
+
</div>
|
119 |
+
</div>
|
120 |
+
</div>
|
121 |
+
</div>
|
122 |
+
|
123 |
+
<!-- Evaluation -->
|
124 |
+
<div class="card expandable-card">
|
125 |
+
<div class="card-header" style="background-color: #c9c9f2;" id="evaluation-header">GPT-4 Evaluation</div>
|
126 |
+
<div class="card-body">
|
127 |
+
<div class="card-text-container">
|
128 |
+
<div class="card-text" id="evaluation-result"></div>
|
129 |
+
</div>
|
130 |
+
<div class="btn btn-primary expand-btn" style="display:flex;"></div>
|
131 |
+
</div>
|
132 |
+
</div>
|
133 |
+
</div>
|
134 |
+
|
135 |
+
<div class="container-fluid bg-light py-2">
|
136 |
+
<div class="text-center">
|
137 |
+
<small class="text-muted">This website is co-authored with <a href="https://openai.com" target="_blank">GPT-4</a>.</small>
|
138 |
+
</div>
|
139 |
+
</div>
|
140 |
+
|
141 |
+
<!-- Marked.js -->
|
142 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/lib/marked.umd.min.js"></script>
|
143 |
+
<!-- Bootstrap and Popper.js JavaScript dependencies -->
|
144 |
+
<script src="https://code.jquery.com/jquery-3.5.1.slim.min.js"></script>
|
145 |
+
<script src="https://cdn.jsdelivr.net/npm/@popperjs/[email protected]/dist/umd/popper.min.js"></script>
|
146 |
+
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js"></script>
|
147 |
+
|
148 |
+
<script src="script.js"></script>
|
149 |
+
<script>
|
150 |
+
// Fetch the JSON file
|
151 |
+
fetch('data.json')
|
152 |
+
.then(response => response.json())
|
153 |
+
.then(json_data => {
|
154 |
+
// Populate the models and questions.
|
155 |
+
populateModels(json_data.models);
|
156 |
+
populateQuestions(json_data.questions);
|
157 |
+
displayQuestion(currentQuestionIndex);
|
158 |
+
}).catch(error => console.error(error));
|
159 |
+
</script>
|
160 |
+
</body>
|
161 |
+
|
162 |
+
</html>
|
LLAVA_Biovil/llava/eval/webpage/script.js
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Description: Script for the evaluation webpage.
|
2 |
+
|
3 |
+
let currentQuestionIndex = 1;
|
4 |
+
|
5 |
+
// Store the model name mapping for later use.
|
6 |
+
modelNameMapping = {
|
7 |
+
"gpt35": "ChatGPT-3.5",
|
8 |
+
"gpt4": "GPT-4",
|
9 |
+
"alpaca": "Alpaca-13b",
|
10 |
+
"vicuna": "Vicuna-13b",
|
11 |
+
"llama": "LLaMA-13b",
|
12 |
+
"bard": "Bard",
|
13 |
+
};
|
14 |
+
|
15 |
+
modelFigureMapping = {
|
16 |
+
"vicuna": "figures/vicuna.jpeg",
|
17 |
+
// Image from: https://commons.wikimedia.org/wiki/File:ChatGPT_logo.svg
|
18 |
+
"gpt35": "figures/chatgpt.svg",
|
19 |
+
// Image from: https://www.reddit.com/r/logodesign/comments/1128aat/google_ai_bard_logo_design/
|
20 |
+
"bard": "figures/bard.jpg",
|
21 |
+
// Image from: https://crfm.stanford.edu/2023/03/13/alpaca.html
|
22 |
+
"alpaca": "figures/alpaca.png",
|
23 |
+
// Image adapted from https://commons.wikimedia.org/wiki/File:Llama_on_Machu_Picchu.jpg
|
24 |
+
"llama": "figures/llama.jpg",
|
25 |
+
}
|
26 |
+
|
27 |
+
// Store the question data in a mapping for later use.
|
28 |
+
questionMapping = {};
|
29 |
+
// Store the question ids in a mapping for later use.
|
30 |
+
categoryMapping = {};
|
31 |
+
// Store the number of questions for later use.
|
32 |
+
questionsCount = 0;
|
33 |
+
|
34 |
+
|
35 |
+
function text2Markdown(text) {
|
36 |
+
// Normalize the text for markdown rendering.
|
37 |
+
text = text.trim().replaceAll('\n\n', '\n').replaceAll('\n', '\n\n');
|
38 |
+
return marked.parse(text);
|
39 |
+
}
|
40 |
+
|
41 |
+
function capitalizeFirstChar(str) {
|
42 |
+
if (!str || str.length === 0) {
|
43 |
+
return str;
|
44 |
+
}
|
45 |
+
return str.charAt(0).toUpperCase() + str.slice(1);
|
46 |
+
}
|
47 |
+
|
48 |
+
function updateQuestionSelect(question_id) {
|
49 |
+
const select = document.getElementById('question-select');
|
50 |
+
// Clear the question select.
|
51 |
+
select.innerHTML = '';
|
52 |
+
// Populate the question select.
|
53 |
+
category = questionMapping[question_id].category;
|
54 |
+
categoryMapping[category].forEach(question_id => {
|
55 |
+
const question = questionMapping[question_id];
|
56 |
+
const option = document.createElement('option');
|
57 |
+
option.value = question_id;
|
58 |
+
option.textContent = 'Q' + question_id.toString() + ': ' + question.question;
|
59 |
+
select.appendChild(option);
|
60 |
+
});
|
61 |
+
select.value = question_id;
|
62 |
+
}
|
63 |
+
|
64 |
+
function updateModelSelect() {
|
65 |
+
const select = document.getElementById('model-select');
|
66 |
+
img_path = modelFigureMapping[select.value];
|
67 |
+
document.getElementById('other-model-figure').src = img_path;
|
68 |
+
}
|
69 |
+
|
70 |
+
function populateModels(models) {
|
71 |
+
const select = document.getElementById('model-select');
|
72 |
+
models.forEach(model => {
|
73 |
+
const option = document.createElement('option');
|
74 |
+
option.value = model;
|
75 |
+
option.textContent = modelNameMapping[model];
|
76 |
+
select.appendChild(option);
|
77 |
+
});
|
78 |
+
updateModelSelect();
|
79 |
+
}
|
80 |
+
|
81 |
+
function populateQuestions(questions) {
|
82 |
+
const category_select = document.getElementById('category-select');
|
83 |
+
|
84 |
+
questionsCount = questions.length;
|
85 |
+
questions.forEach(question => {
|
86 |
+
const option = document.createElement('option');
|
87 |
+
// Store the question data in a mapping for later use.
|
88 |
+
questionMapping[question.id] = {
|
89 |
+
category: question.category,
|
90 |
+
question: question.question,
|
91 |
+
answers: question.answers,
|
92 |
+
evaluations: question.evaluations,
|
93 |
+
scores: question.scores,
|
94 |
+
};
|
95 |
+
// Store the question id in the category mapping.
|
96 |
+
if (question.category in categoryMapping) {
|
97 |
+
categoryMapping[question.category].push(question.id);
|
98 |
+
} else {
|
99 |
+
categoryMapping[question.category] = [question.id];
|
100 |
+
const category_option = document.createElement('option');
|
101 |
+
category_option.value = question.category;
|
102 |
+
category_option.textContent = capitalizeFirstChar(question.category);
|
103 |
+
category_select.appendChild(category_option);
|
104 |
+
}
|
105 |
+
});
|
106 |
+
// Set the default category.
|
107 |
+
updateQuestionSelect(currentQuestionIndex);
|
108 |
+
}
|
109 |
+
|
110 |
+
function displayQuestion(index) {
|
111 |
+
const question = questionMapping[index].question;
|
112 |
+
document.getElementById('selected-question').innerHTML = text2Markdown('**Question:** ' + question);
|
113 |
+
displayAnswers(index);
|
114 |
+
}
|
115 |
+
|
116 |
+
function displayAnswers(index) {
|
117 |
+
const question = questionMapping[index];
|
118 |
+
const otherModel = document.getElementById('model-select').value;
|
119 |
+
// render the answers with markdown
|
120 |
+
document.getElementById('other-model-answer').innerHTML = text2Markdown(question.answers[otherModel]);
|
121 |
+
document.getElementById('our-model-answer').innerHTML = text2Markdown(question.answers.vicuna);
|
122 |
+
|
123 |
+
// Display evaluation
|
124 |
+
score = question.scores[otherModel];
|
125 |
+
score_text = modelNameMapping[otherModel] + " " + score[0] + "/10, Vicuna-13b " + score[1] + "/10";
|
126 |
+
document.getElementById('evaluation-header').textContent = "GPT-4 Evaluation" + " (Score: " + score_text + ")";
|
127 |
+
document.getElementById('evaluation-result').innerHTML = text2Markdown(question.evaluations[otherModel]);
|
128 |
+
|
129 |
+
// Update model names
|
130 |
+
let assistant1_title = "Assistant #1"; // (" + modelNameMapping[otherModel] + ")";
|
131 |
+
let assistant2_title = "Assistant #2 (Vicuna-13b, our model)";
|
132 |
+
// Update scores/labels.
|
133 |
+
let assistant1_score_label = score[0].toString() + '/10';
|
134 |
+
let assistant2_score_label = score[1].toString() + '/10';
|
135 |
+
|
136 |
+
const colorRed ='#fa9'; // '#eb978d';
|
137 |
+
// const colorGreen = '#c9f2c9';
|
138 |
+
const colorBlue = '#8ef'; // '#71dbf9';
|
139 |
+
const colorYellow = '#fe7'; // '#fada57';
|
140 |
+
let otherModelHeaderColor = '';
|
141 |
+
let ourModelHeaderColor = '';
|
142 |
+
// Update the winner.
|
143 |
+
if (score[0] == score[1]) {
|
144 |
+
assistant1_title = '🏆 ' + assistant1_title;
|
145 |
+
assistant1_score_label = '🏆 ' + assistant1_score_label;
|
146 |
+
assistant2_title = '🏆 ' + assistant2_title;
|
147 |
+
assistant2_score_label = '🏆 ' + assistant2_score_label;
|
148 |
+
otherModelHeaderColor = colorYellow;
|
149 |
+
ourModelHeaderColor = colorYellow;
|
150 |
+
} else if (score[0] > score[1]) {
|
151 |
+
assistant1_title = '🏆 ' + assistant1_title;
|
152 |
+
assistant1_score_label = '🏆 ' + assistant1_score_label;
|
153 |
+
otherModelHeaderColor = colorBlue;
|
154 |
+
ourModelHeaderColor = colorRed;
|
155 |
+
} else if (score[0] < score[1]) {
|
156 |
+
assistant2_title = '🏆 ' + assistant2_title;
|
157 |
+
assistant2_score_label = '🏆 ' + assistant2_score_label;
|
158 |
+
otherModelHeaderColor = colorRed;
|
159 |
+
ourModelHeaderColor = colorBlue;
|
160 |
+
}
|
161 |
+
|
162 |
+
document.getElementById('other-model-header-bg').style.backgroundColor = otherModelHeaderColor;
|
163 |
+
document.getElementById('our-model-header').style.backgroundColor = ourModelHeaderColor;
|
164 |
+
|
165 |
+
document.getElementById('other-model-header').textContent = assistant1_title;
|
166 |
+
document.getElementById('our-model-header').textContent = assistant2_title;
|
167 |
+
|
168 |
+
document.getElementById('other-score-label').textContent = assistant1_score_label;
|
169 |
+
document.getElementById('our-score-label').textContent = assistant2_score_label;
|
170 |
+
|
171 |
+
// Update expand buttons visibility for both cards after displaying answers
|
172 |
+
// Reset the expanded state and update expand buttons visibility for both cards after displaying answers
|
173 |
+
document.querySelectorAll('.expandable-card').forEach(card => {
|
174 |
+
card.classList.remove('expanded');
|
175 |
+
updateExpandButtonVisibility(card);
|
176 |
+
const expandBtn = card.querySelector('.expand-btn');
|
177 |
+
expandBtn.innerHTML = '<i class="material-icons" style="pointer-events: none">keyboard_arrow_down</i> Show more'; // .textContent = 'Show more';
|
178 |
+
});
|
179 |
+
}
|
180 |
+
|
181 |
+
document.getElementById('question-select').addEventListener('change', e => {
|
182 |
+
currentQuestionIndex = parseInt(e.target.value);
|
183 |
+
displayQuestion(currentQuestionIndex);
|
184 |
+
});
|
185 |
+
|
186 |
+
document.getElementById('category-select').addEventListener('change', e => {
|
187 |
+
let currentCategory = e.target.value;
|
188 |
+
const questionIds = categoryMapping[currentCategory];
|
189 |
+
currentQuestionIndex = questionIds[0];
|
190 |
+
updateQuestionSelect(currentQuestionIndex);
|
191 |
+
displayQuestion(currentQuestionIndex);
|
192 |
+
});
|
193 |
+
|
194 |
+
// Update expand buttons whenever the model is changed
|
195 |
+
document.getElementById('model-select').addEventListener('change', () => {
|
196 |
+
displayAnswers(currentQuestionIndex);
|
197 |
+
document.querySelectorAll('.expandable-card').forEach(card => {
|
198 |
+
updateExpandButtonVisibility(card);
|
199 |
+
});
|
200 |
+
updateModelSelect();
|
201 |
+
});
|
202 |
+
|
203 |
+
function switchQuestionAndCategory() {
|
204 |
+
document.getElementById('question-select').value = currentQuestionIndex;
|
205 |
+
old_category = document.getElementById('category-select').value;
|
206 |
+
new_category = questionMapping[currentQuestionIndex].category;
|
207 |
+
if (old_category != new_category) {
|
208 |
+
document.getElementById('category-select').value = new_category;
|
209 |
+
updateQuestionSelect(currentQuestionIndex);
|
210 |
+
}
|
211 |
+
displayQuestion(currentQuestionIndex);
|
212 |
+
}
|
213 |
+
|
214 |
+
document.getElementById('prev-question').addEventListener('click', () => {
|
215 |
+
// Question index starts from 1.
|
216 |
+
currentQuestionIndex = Math.max(1, currentQuestionIndex - 1);
|
217 |
+
switchQuestionAndCategory();
|
218 |
+
});
|
219 |
+
|
220 |
+
document.getElementById('next-question').addEventListener('click', () => {
|
221 |
+
// Question index starts from 1.
|
222 |
+
currentQuestionIndex = Math.min(questionsCount, currentQuestionIndex + 1);
|
223 |
+
switchQuestionAndCategory();
|
224 |
+
});
|
225 |
+
|
226 |
+
function updateExpandButtonVisibility(card) {
|
227 |
+
const cardTextContainer = card.querySelector('.card-text-container');
|
228 |
+
const expandBtn = card.querySelector('.expand-btn');
|
229 |
+
if (cardTextContainer.scrollHeight > cardTextContainer.offsetHeight) {
|
230 |
+
expandBtn.style.display = 'flex';
|
231 |
+
} else {
|
232 |
+
expandBtn.style.display = 'none';
|
233 |
+
card.classList.add('expanded');
|
234 |
+
}
|
235 |
+
}
|
236 |
+
|
237 |
+
document.querySelectorAll('.expand-btn').forEach(btn => {
|
238 |
+
btn.addEventListener('click', e => {
|
239 |
+
const card = e.target.closest('.expandable-card');
|
240 |
+
card.classList.toggle('expanded');
|
241 |
+
const more = '<i class="material-icons" style="pointer-events: none">keyboard_arrow_down</i> Show more';
|
242 |
+
const less = '<i class="material-icons" style="pointer-events: none">keyboard_arrow_up</i> Show less';
|
243 |
+
e.target.innerHTML = card.classList.contains('expanded') ? less : more;
|
244 |
+
});
|
245 |
+
});
|
LLAVA_Biovil/llava/eval/webpage/styles.css
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
body {
|
2 |
+
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
3 |
+
background-color: #f8f9fa;
|
4 |
+
}
|
5 |
+
|
6 |
+
.navbar-dark .navbar-nav .nav-link {
|
7 |
+
color: #f1cf68;
|
8 |
+
font-size: 1.1rem;
|
9 |
+
padding: 0.5rem 0.6rem;
|
10 |
+
}
|
11 |
+
|
12 |
+
.card-header {
|
13 |
+
font-weight: bold;
|
14 |
+
}
|
15 |
+
|
16 |
+
.card {
|
17 |
+
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
18 |
+
transition: 0.3s;
|
19 |
+
}
|
20 |
+
|
21 |
+
.card:hover {
|
22 |
+
box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
|
23 |
+
}
|
24 |
+
|
25 |
+
button {
|
26 |
+
transition: background-color 0.3s;
|
27 |
+
}
|
28 |
+
|
29 |
+
button:hover {
|
30 |
+
background-color: #007bff;
|
31 |
+
}
|
32 |
+
|
33 |
+
@media (max-width: 767px) {
|
34 |
+
.form-row .form-group {
|
35 |
+
margin-bottom: 10px;
|
36 |
+
}
|
37 |
+
}
|
38 |
+
|
39 |
+
/* Extra styles */
|
40 |
+
|
41 |
+
.expandable-card .card-text-container {
|
42 |
+
max-height: 200px;
|
43 |
+
overflow-y: hidden;
|
44 |
+
position: relative;
|
45 |
+
}
|
46 |
+
|
47 |
+
.expandable-card.expanded .card-text-container {
|
48 |
+
max-height: none;
|
49 |
+
}
|
50 |
+
|
51 |
+
.expand-btn {
|
52 |
+
position: relative;
|
53 |
+
display: none;
|
54 |
+
background-color: rgba(255, 255, 255, 0.8);
|
55 |
+
color: #510c75;
|
56 |
+
border-color: transparent;
|
57 |
+
}
|
58 |
+
|
59 |
+
.expand-btn:hover {
|
60 |
+
background-color: rgba(200, 200, 200, 0.8);
|
61 |
+
text-decoration: none;
|
62 |
+
border-color: transparent;
|
63 |
+
color: #510c75;
|
64 |
+
}
|
65 |
+
|
66 |
+
.expand-btn:focus {
|
67 |
+
outline: none;
|
68 |
+
text-decoration: none;
|
69 |
+
}
|
70 |
+
|
71 |
+
.expandable-card:not(.expanded) .card-text-container:after {
|
72 |
+
content: "";
|
73 |
+
position: absolute;
|
74 |
+
bottom: 0;
|
75 |
+
left: 0;
|
76 |
+
width: 100%;
|
77 |
+
height: 90px;
|
78 |
+
background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1));
|
79 |
+
}
|
80 |
+
|
81 |
+
.expandable-card:not(.expanded) .expand-btn {
|
82 |
+
margin-top: -40px;
|
83 |
+
}
|
84 |
+
|
85 |
+
.card-body {
|
86 |
+
padding-bottom: 5px;
|
87 |
+
}
|
88 |
+
|
89 |
+
.vertical-flex-layout {
|
90 |
+
justify-content: center;
|
91 |
+
align-items: center;
|
92 |
+
height: 100%;
|
93 |
+
display: flex;
|
94 |
+
flex-direction: column;
|
95 |
+
gap: 5px;
|
96 |
+
}
|
97 |
+
|
98 |
+
.figure-img {
|
99 |
+
max-width: 100%;
|
100 |
+
height: auto;
|
101 |
+
}
|
102 |
+
|
103 |
+
.adjustable-font-size {
|
104 |
+
font-size: calc(0.5rem + 2vw);
|
105 |
+
}
|
LLAVA_Biovil/llava/mm_utils.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from io import BytesIO
|
3 |
+
import base64
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from transformers import StoppingCriteria
|
8 |
+
from llava.constants import IMAGE_TOKEN_INDEX
|
9 |
+
|
10 |
+
|
11 |
+
def load_image_from_base64(image):
|
12 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
13 |
+
|
14 |
+
def remap_to_uint8(array: np.ndarray, percentiles=None) -> np.ndarray:
|
15 |
+
"""Remap values in input so the output range is :math:`[0, 255]`.
|
16 |
+
|
17 |
+
Percentiles can be used to specify the range of values to remap.
|
18 |
+
This is useful to discard outliers in the input data.
|
19 |
+
|
20 |
+
:param array: Input array.
|
21 |
+
:param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
|
22 |
+
Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
|
23 |
+
:returns: Array with ``0`` and ``255`` as minimum and maximum values.
|
24 |
+
"""
|
25 |
+
array = array.astype(float)
|
26 |
+
if percentiles is not None:
|
27 |
+
len_percentiles = len(percentiles)
|
28 |
+
if len_percentiles != 2:
|
29 |
+
message = (
|
30 |
+
'The value for percentiles should be a sequence of length 2,'
|
31 |
+
f' but has length {len_percentiles}'
|
32 |
+
)
|
33 |
+
raise ValueError(message)
|
34 |
+
a, b = percentiles
|
35 |
+
if a >= b:
|
36 |
+
raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
|
37 |
+
if a < 0 or b > 100:
|
38 |
+
raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
|
39 |
+
cutoff: np.ndarray = np.percentile(array, percentiles)
|
40 |
+
array = np.clip(array, *cutoff)
|
41 |
+
array -= array.min()
|
42 |
+
array /= array.max()
|
43 |
+
array *= 255
|
44 |
+
return array.astype(np.uint8)
|
45 |
+
def load_image_from_base64_biovil(image):
|
46 |
+
image = Image.open(BytesIO(base64.b64decode(image)))
|
47 |
+
image = remap_to_uint8(np.array(image))
|
48 |
+
return Image.fromarray(image).convert("L")
|
49 |
+
|
50 |
+
def expand2square(pil_img, background_color):
|
51 |
+
width, height = pil_img.size
|
52 |
+
if width == height:
|
53 |
+
return pil_img
|
54 |
+
elif width > height:
|
55 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
56 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
57 |
+
return result
|
58 |
+
else:
|
59 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
60 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
61 |
+
return result
|
62 |
+
|
63 |
+
def process_images(images, image_processor, model_cfg):
|
64 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
65 |
+
new_images = []
|
66 |
+
if image_aspect_ratio == 'pad':
|
67 |
+
for image in images:
|
68 |
+
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
|
69 |
+
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
70 |
+
new_images.append(image)
|
71 |
+
else:
|
72 |
+
return image_processor(images, return_tensors='pt')['pixel_values']
|
73 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
74 |
+
new_images = torch.stack(new_images, dim=0)
|
75 |
+
return new_images
|
76 |
+
|
77 |
+
def process_image_biovil(images, image_processor):
|
78 |
+
new_images = []
|
79 |
+
for image in images:
|
80 |
+
image = image_processor(image)
|
81 |
+
new_images.append(image)
|
82 |
+
|
83 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
84 |
+
new_images = torch.stack(new_images, dim=0)
|
85 |
+
return new_images
|
86 |
+
|
87 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
88 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
89 |
+
|
90 |
+
def insert_separator(X, sep):
|
91 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
92 |
+
|
93 |
+
input_ids = []
|
94 |
+
offset = 0
|
95 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
96 |
+
offset = 1
|
97 |
+
input_ids.append(prompt_chunks[0][0])
|
98 |
+
|
99 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
100 |
+
input_ids.extend(x[offset:])
|
101 |
+
|
102 |
+
if return_tensors is not None:
|
103 |
+
if return_tensors == 'pt':
|
104 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
105 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
106 |
+
return input_ids
|
107 |
+
|
108 |
+
|
109 |
+
def get_model_name_from_path(model_path):
|
110 |
+
model_path = model_path.strip("/")
|
111 |
+
model_paths = model_path.split("/")
|
112 |
+
if model_paths[-1].startswith('checkpoint-'):
|
113 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
114 |
+
else:
|
115 |
+
return model_paths[-1]
|
116 |
+
|
117 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
118 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
119 |
+
self.keywords = keywords
|
120 |
+
self.keyword_ids = []
|
121 |
+
self.max_keyword_len = 0
|
122 |
+
for keyword in keywords:
|
123 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
124 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
125 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
126 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
127 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
128 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
129 |
+
self.tokenizer = tokenizer
|
130 |
+
self.start_len = input_ids.shape[1]
|
131 |
+
|
132 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
133 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
134 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
135 |
+
for keyword_id in self.keyword_ids:
|
136 |
+
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
|
137 |
+
return True
|
138 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
139 |
+
for keyword in self.keywords:
|
140 |
+
if keyword in outputs:
|
141 |
+
return True
|
142 |
+
return False
|
143 |
+
|
144 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
145 |
+
outputs = []
|
146 |
+
for i in range(output_ids.shape[0]):
|
147 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
148 |
+
return all(outputs)
|