Spaces:
No application file
No application file
Upload 244 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- lm-watermarking-main/.gitignore +129 -0
- lm-watermarking-main/LICENSE.md +201 -0
- lm-watermarking-main/MANIFEST.in +7 -0
- lm-watermarking-main/README.md +119 -0
- lm-watermarking-main/__pycache__/alternative_prf_schemes.cpython-310.pyc +0 -0
- lm-watermarking-main/__pycache__/demo_watermark.cpython-310.pyc +0 -0
- lm-watermarking-main/__pycache__/demo_watermark.cpython-39.pyc +0 -0
- lm-watermarking-main/__pycache__/extended_watermark_processor.cpython-310.pyc +0 -0
- lm-watermarking-main/__pycache__/extended_watermark_processor.cpython-39.pyc +0 -0
- lm-watermarking-main/__pycache__/homoglyphs.cpython-310.pyc +0 -0
- lm-watermarking-main/__pycache__/normalizers.cpython-310.pyc +0 -0
- lm-watermarking-main/alternative_prf_schemes.py +184 -0
- lm-watermarking-main/app.py +52 -0
- lm-watermarking-main/demo_watermark.py +702 -0
- lm-watermarking-main/experiments/README.md +91 -0
- lm-watermarking-main/experiments/io_utils.py +116 -0
- lm-watermarking-main/experiments/launch.py +222 -0
- lm-watermarking-main/experiments/process_rows.py +250 -0
- lm-watermarking-main/experiments/run_watermarking.py +705 -0
- lm-watermarking-main/experiments/submitit_utils.py +79 -0
- lm-watermarking-main/experiments/watermark.py +820 -0
- lm-watermarking-main/experiments/watermarking_analysis.ipynb +2049 -0
- lm-watermarking-main/experiments/watermarking_example_finding.ipynb +1007 -0
- lm-watermarking-main/extended_watermark_processor.py +625 -0
- lm-watermarking-main/homoglyph_data/__init__.py +40 -0
- lm-watermarking-main/homoglyph_data/categories.json +0 -0
- lm-watermarking-main/homoglyph_data/confusables_sept2022.json +0 -0
- lm-watermarking-main/homoglyph_data/languages.json +34 -0
- lm-watermarking-main/homoglyphs.py +265 -0
- lm-watermarking-main/normalizers.py +208 -0
- lm-watermarking-main/pyproject.toml +6 -0
- lm-watermarking-main/requirements.txt +6 -0
- lm-watermarking-main/setup.cfg +68 -0
- lm-watermarking-main/watermark_processor.py +282 -0
- lm-watermarking-main/watermark_reliability_release/PIPELINE.md +154 -0
- lm-watermarking-main/watermark_reliability_release/README.md +27 -0
- lm-watermarking-main/watermark_reliability_release/alternative_prf_schemes.py +172 -0
- lm-watermarking-main/watermark_reliability_release/attack_pipeline.py +506 -0
- lm-watermarking-main/watermark_reliability_release/broadcast_token_prefixes.py +436 -0
- lm-watermarking-main/watermark_reliability_release/detectgpt/debug.sh +77 -0
- lm-watermarking-main/watermark_reliability_release/detectgpt/detectgpt_main.py +807 -0
- lm-watermarking-main/watermark_reliability_release/detectgpt/make_plot.py +124 -0
- lm-watermarking-main/watermark_reliability_release/detectgpt/plot.sh +46 -0
- lm-watermarking-main/watermark_reliability_release/detectgpt/run_detectgpt.sh +63 -0
- lm-watermarking-main/watermark_reliability_release/evaluation_pipeline.py +1330 -0
- lm-watermarking-main/watermark_reliability_release/figure_notebooks/baseline_comparison.ipynb +0 -0
- lm-watermarking-main/watermark_reliability_release/figure_notebooks/baseline_comparison_transpose.ipynb +0 -0
- lm-watermarking-main/watermark_reliability_release/figure_notebooks/core_robustness.ipynb +0 -0
- lm-watermarking-main/watermark_reliability_release/figure_notebooks/data_model.ipynb +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
lm-watermarking-main/watermark_reliability_release/utils/data/lfqa.jsonl filter=lfs diff=lfs merge=lfs -text
|
lm-watermarking-main/.gitignore
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
lm-watermarking-main/LICENSE.md
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [2023] [Authors of 'A Watermark for Large Language Models']
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
lm-watermarking-main/MANIFEST.in
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# added manually
|
2 |
+
include *.py
|
3 |
+
include *.json
|
4 |
+
include *.md
|
5 |
+
|
6 |
+
global-exclude *.pyc
|
7 |
+
global-exclude __pycache__
|
lm-watermarking-main/README.md
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 💧 [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) 🔍
|
2 |
+
|
3 |
+
### [Demo](https://huggingface.co/spaces/tomg-group-umd/lm-watermarking) | [Paper](https://arxiv.org/abs/2301.10226)
|
4 |
+
|
5 |
+
Official implementation of the watermarking and detection algorithms presented in the paper:
|
6 |
+
|
7 |
+
"A Watermark for Large language Models" by _John Kirchenbauer*, Jonas Geiping*, Yuxin Wen, Jonathan Katz, Ian Miers, Tom Goldstein_
|
8 |
+
|
9 |
+
### Updates:
|
10 |
+
|
11 |
+
- **(6/7/23)** We're thrilled to announce the release of ["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634) Our new preprint documents a deep dive into the robustness properties of more advanced watermarks.
|
12 |
+
|
13 |
+
- **(6/9/23)** Initial code release implementing the alternate watermark and detector variants in the new work. Files located in this subdirectory: [`watermark_reliability_release`](watermark_reliability_release).
|
14 |
+
|
15 |
+
- **(9/23/23)** Update to the docs with recommendations on parameter settings. Extended implementation (recommended) available in `extended_watermark_processor.py`.
|
16 |
+
|
17 |
+
---
|
18 |
+
|
19 |
+
Implementation is based on the "logit processor" abstraction provided by the [huggingface/transformers 🤗](https://github.com/huggingface/transformers) library.
|
20 |
+
|
21 |
+
The `WatermarkLogitsProcessor` is designed to be readily compatible with any model that supports the `generate` API.
|
22 |
+
Any model that can be constructed using the `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` factories _should_ be compatible.
|
23 |
+
|
24 |
+
### Repo contents
|
25 |
+
|
26 |
+
The core implementation is defined by the `WatermarkBase`, `WatermarkLogitsProcessor`, and `WatermarkDetector` classes in the files `watermark_processor.py`, for a minimal implementation and `extended_watermark_processor.py` for the more full featured implementation (recommended).
|
27 |
+
The `demo_watermark.py` script implements a gradio demo interface as well as minimum working example in the `main` function using the minimal version.
|
28 |
+
|
29 |
+
Details about the parameters and the detection outputs are provided in the gradio app markdown blocks as well as the argparse definition.
|
30 |
+
|
31 |
+
The `homoglyphs.py` and `normalizers.py` modules implement algorithms used by the `WatermarkDetector`. `homoglyphs.py` (and its raw data in `homoglyph_data`) is an updated version of the homoglyph code from the deprecated package described here: https://github.com/life4/homoglyphs.
|
32 |
+
The `experiments` directory contains pipeline code that we used to run the original experiments in the paper. However this is stale/deprecated
|
33 |
+
in favor of the implementation in `watermark_processor.py`.
|
34 |
+
|
35 |
+
### Demo Usage
|
36 |
+
|
37 |
+
As a quickstart, the app can be launched with default args (or deployed to a [huggingface Space](https://huggingface.co/spaces)) using `app.py`
|
38 |
+
which is just a thin wrapper around the demo script.
|
39 |
+
```sh
|
40 |
+
python app.py
|
41 |
+
gradio app.py # for hot reloading
|
42 |
+
# or
|
43 |
+
python demo_watermark.py --model_name_or_path facebook/opt-6.7b
|
44 |
+
```
|
45 |
+
|
46 |
+
|
47 |
+
### How to Watermark - A short guide on watermark hyperparameters
|
48 |
+
What watermark hyperparameters are optimal for your task or for a comparison to new watermarks? We'll provide a brief overview about all important settings below, and best practices for future work. This guide represents our current understanding of optimal settings as of August 2023, and so is a bit more up to date than our ICML 2023 conference paper.
|
49 |
+
|
50 |
+
**TL;DR**: As a baseline generation setting, we suggest default values of `gamma=0.25` and `delta=2.0`. Reduce delta if text quality is negatively impacted. For the context width, h, we recommend a moderate value, i.e. h=4, and as a default PRF we recommend `selfhash`, but can use `minhash` if you want. Reduce h if more robustness against edits is required. Note however that the choice of PRF only matters if h>1. The recommended PRF and context width can be easily selected by instantiating the watermark processor and detector with `seeding_scheme="selfhash"` (a shorthand for `seeding_scheme="ff-anchored_minhash_prf-4-True-15485863"`, but do use a different base key if actually deploying). For detection, always run with `--ignore--repeated-ngrams=True`.
|
51 |
+
|
52 |
+
1) **Logit bias delta**: The magnitude of delta determines the strength of the watermark. A sufficiently large value of delta recovers a "hard" watermark that encodes 1 bit of information at every token, but this is not an advisable setting, as it strongly affects model quality. A moderate delta in the range of [0.5, 2.0] is appropriate for normal use cases, but the strength of delta is relative to the entropy of the output distribution. Models that are overconfident, such as instruction-tuned models, may benefit from choosing a larger delta value. With non-infinite delta values, the watermark strength is directly proportional to the (spike) entropy of the text and exp(delta) (see Theorem 4.2 in our paper).
|
53 |
+
|
54 |
+
2) **Context width h**: Context width is the length of the context which is taken into account when seeding the watermark at each location. The longer the context, the "more random" the red/green list partitions are, and the less detectable the watermark is. For private watermarks, this implies that the watermark is harder to discover via brute-force (with an exponential increase in hardness with increasing context width h).
|
55 |
+
In the limit of a very long context width, we approach the "undetectable" setting of https://eprint.iacr.org/2023/763. However, the longer the context width, the less "nuclear" the watermark is, and robustness to paraphrasing and other attacks decreases. In the limit of h=0, the watermark is independent of local context and, as such, it is minimally random, but maximally robust against edits (see https://arxiv.org/abs/2306.17439).
|
56 |
+
|
57 |
+
3) **Ignoring repeated ngrams**: The watermark is only pseudo-random based on the local context. Whenever local context repeats, this constitutes a violation of the assumption that the PRNG numbers used to seed the green/red partition operation are drawn iid. (See Sec.4. in our paper for details). For this reason, p-values for text with repeated n-grams (n-gram here meaning context + chosen token) will be misleading. As such, detection should be run with `--ignore-repeated-ngrams` set to `True`. An additional, detailed analysis of this effect can be found in http://arxiv.org/abs/2308.00113.
|
58 |
+
|
59 |
+
4) **Choice of pseudo-random-function** (PRF): This choice is only relevant if context width h>1 and determines the robustness of the hash of the context against edits. In our experiments we find "min"-hash PRFs to be the most performant in striking a balance between maximizing robustness and minimizing impact on text quality. In comparison to a PRF that depends on the entire context, this PRF only depends on a single, randomly chosen token from the context.
|
60 |
+
|
61 |
+
5) **Self-Hashing**: It is possible to extend the context width of the watermark onto the current token. This effectively extends the context width "for-free" by one. The only downside is that this approach requires hashing all possible next tokens, and applying the logit bias only to tokens where their inclusion in the context would produce a hash that includes this token on the green list. This is slow in the way we implement it, because we use cuda's pseudorandom number generator and a simple inner-loop implementation, but in principle has a negligible cost, compared to generating new tokens if engineered for deployment. A generalized algorithm for self-hashing can be found as Alg.1 in http://arxiv.org/abs/2306.04634.
|
62 |
+
|
63 |
+
6) **Gamma**: Gamma denotes the fraction of the vocabulary that will be in each green list. We find gamma=0.25 to be slightly more optimal empirically, but this is a minor effect and reasonable values of gamma between 0.25 and 0.75 will lead to reasonable watermark. A intuitive argument can be made for why this makes it easier to achieve a fraction of green tokens sufficiently higher than gamma to reject the null hypothesis, when you choose a lower gamma value.
|
64 |
+
|
65 |
+
7) **Base Key**: Our watermark is salted with a small base key of 15485863 (the millionth prime). If you deploy this watermark, we do not advise re-using this key.
|
66 |
+
|
67 |
+
### How to use the watermark in your own code.
|
68 |
+
|
69 |
+
Our implementation can be added into any huggingface generation pipeline as an additional `LogitProcessor`, only the classes `WatermarkLogitsProcessor` and `WatermarkDetector` from the `extended_watermark_processor.py` file are required.
|
70 |
+
|
71 |
+
Example snippet to generate watermarked text:
|
72 |
+
```python
|
73 |
+
|
74 |
+
from extended_watermark_processor import WatermarkLogitsProcessor
|
75 |
+
|
76 |
+
watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
|
77 |
+
gamma=0.25,
|
78 |
+
delta=2.0,
|
79 |
+
seeding_scheme="selfhash") #equivalent to `ff-anchored_minhash_prf-4-True-15485863`
|
80 |
+
# Note:
|
81 |
+
# You can turn off self-hashing by setting the seeding scheme to `minhash`.
|
82 |
+
|
83 |
+
tokenized_input = tokenizer(input_text).to(model.device)
|
84 |
+
# note that if the model is on cuda, then the input is on cuda
|
85 |
+
# and thus the watermarking rng is cuda-based.
|
86 |
+
# This is a different generator than the cpu-based rng in pytorch!
|
87 |
+
|
88 |
+
output_tokens = model.generate(**tokenized_input,
|
89 |
+
logits_processor=LogitsProcessorList([watermark_processor]))
|
90 |
+
|
91 |
+
# if decoder only model, then we need to isolate the
|
92 |
+
# newly generated tokens as only those are watermarked, the input/prompt is not
|
93 |
+
output_tokens = output_tokens[:,tokenized_input["input_ids"].shape[-1]:]
|
94 |
+
|
95 |
+
output_text = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]
|
96 |
+
```
|
97 |
+
|
98 |
+
Example snippet to detect watermarked text:
|
99 |
+
```python
|
100 |
+
|
101 |
+
from extended_watermark_processor import WatermarkDetector
|
102 |
+
|
103 |
+
watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
|
104 |
+
gamma=0.25, # should match original setting
|
105 |
+
seeding_scheme="selfhash", # should match original setting
|
106 |
+
device=model.device, # must match the original rng device type
|
107 |
+
tokenizer=tokenizer,
|
108 |
+
z_threshold=4.0,
|
109 |
+
normalizers=[],
|
110 |
+
ignore_repeated_ngrams=True)
|
111 |
+
|
112 |
+
score_dict = watermark_detector.detect(output_text) # or any other text of interest to analyze
|
113 |
+
```
|
114 |
+
|
115 |
+
To recover the main settings of the experiments in the original work (for historical reasons), use the seeding scheme `simple_1` and set `ignore_repeated_ngrams=False` at detection time.
|
116 |
+
|
117 |
+
|
118 |
+
### Contributing
|
119 |
+
Suggestions and PR's welcome 🙂
|
lm-watermarking-main/__pycache__/alternative_prf_schemes.cpython-310.pyc
ADDED
Binary file (4.92 kB). View file
|
|
lm-watermarking-main/__pycache__/demo_watermark.cpython-310.pyc
ADDED
Binary file (28.4 kB). View file
|
|
lm-watermarking-main/__pycache__/demo_watermark.cpython-39.pyc
ADDED
Binary file (28.5 kB). View file
|
|
lm-watermarking-main/__pycache__/extended_watermark_processor.cpython-310.pyc
ADDED
Binary file (18.3 kB). View file
|
|
lm-watermarking-main/__pycache__/extended_watermark_processor.cpython-39.pyc
ADDED
Binary file (17.8 kB). View file
|
|
lm-watermarking-main/__pycache__/homoglyphs.cpython-310.pyc
ADDED
Binary file (7.79 kB). View file
|
|
lm-watermarking-main/__pycache__/normalizers.cpython-310.pyc
ADDED
Binary file (6.86 kB). View file
|
|
lm-watermarking-main/alternative_prf_schemes.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implement other PRF functions (These all vary only how they generate a single hash from the tokens in the context).
|
2 |
+
|
3 |
+
Can be hooked into existing WatermarkLogitsProcessor as modified base class WatermarkBase, see implementation in
|
4 |
+
extended_watermark_processor.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
# coding=utf-8
|
8 |
+
# Copyright 2023 Authors of "A Watermark for Large Language Models"
|
9 |
+
# available at https://arxiv.org/abs/2301.10226
|
10 |
+
#
|
11 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
12 |
+
# you may not use this file except in compliance with the License.
|
13 |
+
# You may obtain a copy of the License at
|
14 |
+
#
|
15 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
16 |
+
#
|
17 |
+
# Unless required by applicable law or agreed to in writing, software
|
18 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
19 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
20 |
+
# See the License for the specific language governing permissions and
|
21 |
+
# limitations under the License.
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from itertools import combinations
|
25 |
+
from functools import cache
|
26 |
+
|
27 |
+
# Key properties of a hashing scheme
|
28 |
+
props = {
|
29 |
+
"prf_type": str, # string name of the underlying PRF mapping multiple token ids to a random seed
|
30 |
+
"context_width": int, # this is h in the paper, how many previous tokens should be considered for each PRF
|
31 |
+
"self_salt": bool, # Use the rules laid in robust-watermarking to use the token itself to seed and possibly reject its own list
|
32 |
+
"hash_key": int, # integer, large prime, used to move seed away from low-entrop bit sequences in PRF chosen above
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
def seeding_scheme_lookup(seeding_scheme: str):
|
37 |
+
if not isinstance(seeding_scheme, str):
|
38 |
+
raise ValueError("Seeding scheme should be a string summarizing the procedure.")
|
39 |
+
if seeding_scheme == "simple_1" or seeding_scheme == "lefthash":
|
40 |
+
# Default, simple bigram hash # alias for ff-additive_prf-1-False-15485863
|
41 |
+
prf_type = "additive_prf"
|
42 |
+
context_width = 1
|
43 |
+
self_salt = False
|
44 |
+
hash_key = 15485863
|
45 |
+
elif seeding_scheme == "algorithm-3" or seeding_scheme == "selfhash":
|
46 |
+
prf_type = "anchored_minhash_prf"
|
47 |
+
context_width = 4
|
48 |
+
self_salt = True
|
49 |
+
hash_key = 15485863
|
50 |
+
elif seeding_scheme == "minhash":
|
51 |
+
prf_type = "minhash_prf"
|
52 |
+
context_width = 4
|
53 |
+
self_salt = False
|
54 |
+
hash_key = 15485863
|
55 |
+
elif seeding_scheme == "skipgram":
|
56 |
+
prf_type = "skipgram_prf"
|
57 |
+
context_width = 5
|
58 |
+
self_salt = False
|
59 |
+
hash_key = 15485863
|
60 |
+
elif seeding_scheme.startswith("ff"): # freeform seeding scheme API - only use for experimenting
|
61 |
+
# expects strings of the form ff-additive_prf-4-True-hash or ff-additive_prf-5-True (hash key is optional)
|
62 |
+
split_scheme = seeding_scheme.split("-")
|
63 |
+
prf_type = str(split_scheme[1])
|
64 |
+
context_width = int(split_scheme[2])
|
65 |
+
self_salt = split_scheme[3] == "True"
|
66 |
+
if len(split_scheme) == 5:
|
67 |
+
hash_key = int(split_scheme[4])
|
68 |
+
else:
|
69 |
+
hash_key = 15485863
|
70 |
+
else:
|
71 |
+
raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try 'simple_1'?")
|
72 |
+
|
73 |
+
assert prf_type in prf_lookup.keys()
|
74 |
+
return prf_type, context_width, self_salt, hash_key
|
75 |
+
|
76 |
+
|
77 |
+
def multiplicative_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
78 |
+
return salt_key * input_ids.prod().item()
|
79 |
+
|
80 |
+
|
81 |
+
def additive_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
82 |
+
return salt_key * input_ids.sum().item()
|
83 |
+
|
84 |
+
|
85 |
+
def minfunc_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
86 |
+
# not a great idea for non-random input ids as in text
|
87 |
+
return salt_key * input_ids.min().item()
|
88 |
+
|
89 |
+
|
90 |
+
def simple_skip_prf(input_ids: torch.LongTensor, salt_key: int, k=2) -> int:
|
91 |
+
# k is the skip distance
|
92 |
+
return hashint(salt_key * input_ids[::k]).prod().item()
|
93 |
+
|
94 |
+
|
95 |
+
def skipgram_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
96 |
+
# maximum distance skipgram within context
|
97 |
+
return hashint(salt_key * input_ids[0]).item()
|
98 |
+
|
99 |
+
|
100 |
+
def anchored_skipgram_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
|
101 |
+
# maximum distance skipgram within context
|
102 |
+
return (hashint(salt_key * input_ids[0]) * hashint(salt_key * input_ids[anchor])).item()
|
103 |
+
|
104 |
+
|
105 |
+
def minhash_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
106 |
+
# slightly less not the greatest idea for non-random input ids as in text
|
107 |
+
return hashint(salt_key * input_ids).min().item()
|
108 |
+
|
109 |
+
|
110 |
+
def anchored_minhash_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
|
111 |
+
# Anchor to one key to produce a min over pairs again
|
112 |
+
return (salt_key * hashint(input_ids) * hashint(input_ids[anchor])).min().item()
|
113 |
+
|
114 |
+
|
115 |
+
def minskipgram_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
|
116 |
+
# min over all skipgrams in context, k=2 is all pairs
|
117 |
+
skipgrams = torch.as_tensor(list(combinations(hashint(salt_key * input_ids), 2)))
|
118 |
+
return skipgrams.prod(dim=1).min().item()
|
119 |
+
|
120 |
+
|
121 |
+
def noncomm_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
|
122 |
+
key = torch.as_tensor(salt_key, dtype=torch.long)
|
123 |
+
for entry in input_ids:
|
124 |
+
key *= hashint(key * entry)
|
125 |
+
key %= 2**32
|
126 |
+
return key.item()
|
127 |
+
|
128 |
+
|
129 |
+
def position_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
|
130 |
+
return (salt_key * input_ids * torch.arange(1, len(input_ids) + 1, device=input_ids.device)).sum().item()
|
131 |
+
|
132 |
+
|
133 |
+
prf_lookup = {
|
134 |
+
"multiplicative_prf": multiplicative_prf,
|
135 |
+
"additive_prf": additive_prf,
|
136 |
+
"minfunc_prf": minfunc_prf,
|
137 |
+
"simple_skip_prf": simple_skip_prf,
|
138 |
+
"skipgram_prf": skipgram_prf,
|
139 |
+
"anchored_skipgram_prf": anchored_skipgram_prf,
|
140 |
+
"minhash_prf": minhash_prf,
|
141 |
+
"anchored_minhash_prf": anchored_minhash_prf,
|
142 |
+
"minskipgram_prf": minskipgram_prf,
|
143 |
+
"noncomm_prf": noncomm_prf,
|
144 |
+
"position_prf": position_prf,
|
145 |
+
}
|
146 |
+
|
147 |
+
# Generate a global permute table once at startup
|
148 |
+
rng = torch.Generator(device=torch.device("cpu"))
|
149 |
+
rng.manual_seed(2971215073) # fib47 is prime
|
150 |
+
table_size = 1_000_003
|
151 |
+
fixed_table = torch.randperm(1_000_003, device=torch.device("cpu"), generator=rng) # actually faster than I thought
|
152 |
+
|
153 |
+
|
154 |
+
def hashint(integer_tensor: torch.LongTensor) -> torch.LongTensor:
|
155 |
+
"""Sane version, in the end we only need a small permutation table."""
|
156 |
+
return fixed_table[integer_tensor.cpu() % table_size] + 1 # minor cheat here, this function always return CPU values
|
157 |
+
|
158 |
+
|
159 |
+
def _hashint_avalanche_tensor(integer_tensor: torch.LongTensor):
|
160 |
+
"""http://burtleburtle.net/bob/hash/integer.html, ported into pytorch, runs on tensors. Apparently a decent avalanche."""
|
161 |
+
i = integer_tensor.to(torch.int32).clone() # or torch.int16?
|
162 |
+
i -= i << 6
|
163 |
+
i ^= i >> 17
|
164 |
+
i -= i << 9
|
165 |
+
i ^= i << 4
|
166 |
+
i -= i << 3
|
167 |
+
i ^= i << 10
|
168 |
+
i ^= i >> 15
|
169 |
+
return i.to(torch.long)
|
170 |
+
|
171 |
+
|
172 |
+
@cache
|
173 |
+
def _hashint_avalanche_int(integer: int):
|
174 |
+
"""http://burtleburtle.net/bob/hash/integer.html, runs in base python, caches based on access.
|
175 |
+
Does this make sense for signed 64bit ints?"""
|
176 |
+
i = integer % (2**32)
|
177 |
+
i -= i << 6
|
178 |
+
i ^= i >> 17
|
179 |
+
i -= i << 9
|
180 |
+
i ^= i << 4
|
181 |
+
i -= i << 3
|
182 |
+
i ^= i << 10
|
183 |
+
i ^= i >> 15
|
184 |
+
return i
|
lm-watermarking-main/app.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Authors of "A Watermark for Large Language Models"
|
3 |
+
# available at https://arxiv.org/abs/2301.10226
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
from argparse import Namespace
|
18 |
+
args = Namespace()
|
19 |
+
|
20 |
+
arg_dict = {
|
21 |
+
'run_gradio': True,
|
22 |
+
'demo_public': False,
|
23 |
+
# 'model_name_or_path': 'facebook/opt-125m',
|
24 |
+
# 'model_name_or_path': 'facebook/opt-1.3b',
|
25 |
+
# 'model_name_or_path': 'facebook/opt-2.7b',
|
26 |
+
'model_name_or_path': 'facebook/opt-6.7b',
|
27 |
+
# 'model_name_or_path': 'facebook/opt-13b',
|
28 |
+
# 'load_fp16' : True,
|
29 |
+
'load_fp16' : False,
|
30 |
+
'prompt_max_length': None,
|
31 |
+
'max_new_tokens': 200,
|
32 |
+
'generation_seed': 123,
|
33 |
+
'use_sampling': True,
|
34 |
+
'n_beams': 1,
|
35 |
+
'sampling_temp': 0.7,
|
36 |
+
'use_gpu': True,
|
37 |
+
'seeding_scheme': 'simple_1',
|
38 |
+
'gamma': 0.25,
|
39 |
+
'delta': 2.0,
|
40 |
+
'normalizers': '',
|
41 |
+
'ignore_repeated_bigrams': False,
|
42 |
+
'detection_z_threshold': 4.0,
|
43 |
+
'select_green_tokens': True,
|
44 |
+
'skip_model_load': False,
|
45 |
+
'seed_separately': True,
|
46 |
+
}
|
47 |
+
|
48 |
+
args.__dict__.update(arg_dict)
|
49 |
+
|
50 |
+
from demo_watermark import main
|
51 |
+
|
52 |
+
main(args)
|
lm-watermarking-main/demo_watermark.py
ADDED
@@ -0,0 +1,702 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Authors of "A Watermark for Large Language Models"
|
3 |
+
# available at https://arxiv.org/abs/2301.10226
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import os
|
18 |
+
import argparse
|
19 |
+
from argparse import Namespace
|
20 |
+
from pprint import pprint
|
21 |
+
from functools import partial
|
22 |
+
|
23 |
+
import numpy # for gradio hot reload
|
24 |
+
import gradio as gr
|
25 |
+
|
26 |
+
import torch
|
27 |
+
|
28 |
+
from transformers import (AutoTokenizer,
|
29 |
+
AutoModelForSeq2SeqLM,
|
30 |
+
AutoModelForCausalLM,
|
31 |
+
LogitsProcessorList)
|
32 |
+
|
33 |
+
from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
|
34 |
+
|
35 |
+
def str2bool(v):
|
36 |
+
"""Util function for user friendly boolean flag args"""
|
37 |
+
if isinstance(v, bool):
|
38 |
+
return v
|
39 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
40 |
+
return True
|
41 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
42 |
+
return False
|
43 |
+
else:
|
44 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
45 |
+
|
46 |
+
def parse_args():
|
47 |
+
"""Command line argument specification"""
|
48 |
+
|
49 |
+
parser = argparse.ArgumentParser(description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API")
|
50 |
+
|
51 |
+
parser.add_argument(
|
52 |
+
"--run_gradio",
|
53 |
+
type=str2bool,
|
54 |
+
default=True,
|
55 |
+
help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.",
|
56 |
+
)
|
57 |
+
parser.add_argument(
|
58 |
+
"--demo_public",
|
59 |
+
type=str2bool,
|
60 |
+
default=False,
|
61 |
+
help="Whether to expose the gradio demo to the internet.",
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--model_name_or_path",
|
65 |
+
type=str,
|
66 |
+
default="facebook/opt-6.7b",
|
67 |
+
help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--prompt_max_length",
|
71 |
+
type=int,
|
72 |
+
default=None,
|
73 |
+
help="Truncation length for prompt, overrides model config's max length field.",
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--max_new_tokens",
|
77 |
+
type=int,
|
78 |
+
default=200,
|
79 |
+
help="Maximmum number of new tokens to generate.",
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--generation_seed",
|
83 |
+
type=int,
|
84 |
+
default=123,
|
85 |
+
help="Seed for setting the torch global rng prior to generation.",
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--use_sampling",
|
89 |
+
type=str2bool,
|
90 |
+
default=True,
|
91 |
+
help="Whether to generate using multinomial sampling.",
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--sampling_temp",
|
95 |
+
type=float,
|
96 |
+
default=0.7,
|
97 |
+
help="Sampling temperature to use when generating using multinomial sampling.",
|
98 |
+
)
|
99 |
+
parser.add_argument(
|
100 |
+
"--n_beams",
|
101 |
+
type=int,
|
102 |
+
default=1,
|
103 |
+
help="Number of beams to use for beam search. 1 is normal greedy decoding",
|
104 |
+
)
|
105 |
+
parser.add_argument(
|
106 |
+
"--use_gpu",
|
107 |
+
type=str2bool,
|
108 |
+
default=True,
|
109 |
+
help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"--seeding_scheme",
|
113 |
+
type=str,
|
114 |
+
default="simple_1",
|
115 |
+
help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
|
116 |
+
)
|
117 |
+
parser.add_argument(
|
118 |
+
"--gamma",
|
119 |
+
type=float,
|
120 |
+
default=0.25,
|
121 |
+
help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
"--delta",
|
125 |
+
type=float,
|
126 |
+
default=2.0,
|
127 |
+
help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
|
128 |
+
)
|
129 |
+
parser.add_argument(
|
130 |
+
"--normalizers",
|
131 |
+
type=str,
|
132 |
+
default="",
|
133 |
+
help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.",
|
134 |
+
)
|
135 |
+
parser.add_argument(
|
136 |
+
"--ignore_repeated_bigrams",
|
137 |
+
type=str2bool,
|
138 |
+
default=False,
|
139 |
+
help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.",
|
140 |
+
)
|
141 |
+
parser.add_argument(
|
142 |
+
"--detection_z_threshold",
|
143 |
+
type=float,
|
144 |
+
default=4.0,
|
145 |
+
help="The test statistic threshold for the detection hypothesis test.",
|
146 |
+
)
|
147 |
+
parser.add_argument(
|
148 |
+
"--select_green_tokens",
|
149 |
+
type=str2bool,
|
150 |
+
default=True,
|
151 |
+
help="How to treat the permuation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.",
|
152 |
+
)
|
153 |
+
parser.add_argument(
|
154 |
+
"--skip_model_load",
|
155 |
+
type=str2bool,
|
156 |
+
default=False,
|
157 |
+
help="Skip the model loading to debug the interface.",
|
158 |
+
)
|
159 |
+
parser.add_argument(
|
160 |
+
"--seed_separately",
|
161 |
+
type=str2bool,
|
162 |
+
default=True,
|
163 |
+
help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--load_fp16",
|
167 |
+
type=str2bool,
|
168 |
+
default=False,
|
169 |
+
help="Whether to run model in float16 precsion.",
|
170 |
+
)
|
171 |
+
args = parser.parse_args()
|
172 |
+
return args
|
173 |
+
|
174 |
+
def load_model(args):
|
175 |
+
"""Load and return the model and tokenizer"""
|
176 |
+
|
177 |
+
args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5","T0"]])
|
178 |
+
args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt","opt","bloom"]])
|
179 |
+
if args.is_seq2seq_model:
|
180 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
|
181 |
+
elif args.is_decoder_only_model:
|
182 |
+
if args.load_fp16:
|
183 |
+
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto')
|
184 |
+
else:
|
185 |
+
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
|
186 |
+
else:
|
187 |
+
raise ValueError(f"Unknown model type: {args.model_name_or_path}")
|
188 |
+
|
189 |
+
if args.use_gpu:
|
190 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
191 |
+
if args.load_fp16:
|
192 |
+
pass
|
193 |
+
else:
|
194 |
+
model = model.to(device)
|
195 |
+
else:
|
196 |
+
device = "cpu"
|
197 |
+
model.eval()
|
198 |
+
|
199 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
200 |
+
|
201 |
+
return model, tokenizer, device
|
202 |
+
|
203 |
+
def generate(prompt, args, model=None, device=None, tokenizer=None):
|
204 |
+
"""Instatiate the WatermarkLogitsProcessor according to the watermark parameters
|
205 |
+
and generate watermarked text by passing it to the generate method of the model
|
206 |
+
as a logits processor. """
|
207 |
+
|
208 |
+
print(f"Generating with {args}")
|
209 |
+
|
210 |
+
watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
|
211 |
+
gamma=args.gamma,
|
212 |
+
delta=args.delta,
|
213 |
+
seeding_scheme=args.seeding_scheme,
|
214 |
+
select_green_tokens=args.select_green_tokens)
|
215 |
+
|
216 |
+
gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
|
217 |
+
|
218 |
+
if args.use_sampling:
|
219 |
+
gen_kwargs.update(dict(
|
220 |
+
do_sample=True,
|
221 |
+
top_k=0,
|
222 |
+
temperature=args.sampling_temp
|
223 |
+
))
|
224 |
+
else:
|
225 |
+
gen_kwargs.update(dict(
|
226 |
+
num_beams=args.n_beams
|
227 |
+
))
|
228 |
+
|
229 |
+
generate_without_watermark = partial(
|
230 |
+
model.generate,
|
231 |
+
**gen_kwargs
|
232 |
+
)
|
233 |
+
generate_with_watermark = partial(
|
234 |
+
model.generate,
|
235 |
+
logits_processor=LogitsProcessorList([watermark_processor]),
|
236 |
+
**gen_kwargs
|
237 |
+
)
|
238 |
+
if args.prompt_max_length:
|
239 |
+
pass
|
240 |
+
elif hasattr(model.config,"max_position_embedding"):
|
241 |
+
args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens
|
242 |
+
else:
|
243 |
+
args.prompt_max_length = 2048-args.max_new_tokens
|
244 |
+
|
245 |
+
tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
|
246 |
+
truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
|
247 |
+
redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
|
248 |
+
|
249 |
+
torch.manual_seed(args.generation_seed)
|
250 |
+
output_without_watermark = generate_without_watermark(**tokd_input)
|
251 |
+
|
252 |
+
# optional to seed before second generation, but will not be the same again generally, unless delta==0.0, no-op watermark
|
253 |
+
if args.seed_separately:
|
254 |
+
torch.manual_seed(args.generation_seed)
|
255 |
+
output_with_watermark = generate_with_watermark(**tokd_input)
|
256 |
+
|
257 |
+
if args.is_decoder_only_model:
|
258 |
+
# need to isolate the newly generated tokens
|
259 |
+
output_without_watermark = output_without_watermark[:,tokd_input["input_ids"].shape[-1]:]
|
260 |
+
output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:]
|
261 |
+
|
262 |
+
decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
|
263 |
+
decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
|
264 |
+
|
265 |
+
return (redecoded_input,
|
266 |
+
int(truncation_warning),
|
267 |
+
decoded_output_without_watermark,
|
268 |
+
decoded_output_with_watermark,
|
269 |
+
args)
|
270 |
+
# decoded_output_with_watermark)
|
271 |
+
|
272 |
+
def format_names(s):
|
273 |
+
"""Format names for the gradio demo interface"""
|
274 |
+
s=s.replace("num_tokens_scored","Tokens Counted (T)")
|
275 |
+
s=s.replace("num_green_tokens","# Tokens in Greenlist")
|
276 |
+
s=s.replace("green_fraction","Fraction of T in Greenlist")
|
277 |
+
s=s.replace("z_score","z-score")
|
278 |
+
s=s.replace("p_value","p value")
|
279 |
+
s=s.replace("prediction","Prediction")
|
280 |
+
s=s.replace("confidence","Confidence")
|
281 |
+
return s
|
282 |
+
|
283 |
+
def list_format_scores(score_dict, detection_threshold):
|
284 |
+
"""Format the detection metrics into a gradio dataframe input format"""
|
285 |
+
lst_2d = []
|
286 |
+
# lst_2d.append(["z-score threshold", f"{detection_threshold}"])
|
287 |
+
for k,v in score_dict.items():
|
288 |
+
if k=='green_fraction':
|
289 |
+
lst_2d.append([format_names(k), f"{v:.1%}"])
|
290 |
+
elif k=='confidence':
|
291 |
+
lst_2d.append([format_names(k), f"{v:.3%}"])
|
292 |
+
elif isinstance(v, float):
|
293 |
+
lst_2d.append([format_names(k), f"{v:.3g}"])
|
294 |
+
elif isinstance(v, bool):
|
295 |
+
lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")])
|
296 |
+
else:
|
297 |
+
lst_2d.append([format_names(k), f"{v}"])
|
298 |
+
if "confidence" in score_dict:
|
299 |
+
lst_2d.insert(-2,["z-score Threshold", f"{detection_threshold}"])
|
300 |
+
else:
|
301 |
+
lst_2d.insert(-1,["z-score Threshold", f"{detection_threshold}"])
|
302 |
+
return lst_2d
|
303 |
+
|
304 |
+
def detect(input_text, args, device=None, tokenizer=None):
|
305 |
+
"""Instantiate the WatermarkDetection object and call detect on
|
306 |
+
the input text returning the scores and outcome of the test"""
|
307 |
+
watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
|
308 |
+
gamma=args.gamma,
|
309 |
+
seeding_scheme=args.seeding_scheme,
|
310 |
+
device=device,
|
311 |
+
tokenizer=tokenizer,
|
312 |
+
z_threshold=args.detection_z_threshold,
|
313 |
+
normalizers=args.normalizers,
|
314 |
+
ignore_repeated_bigrams=args.ignore_repeated_bigrams,
|
315 |
+
select_green_tokens=args.select_green_tokens)
|
316 |
+
if len(input_text)-1 > watermark_detector.min_prefix_len:
|
317 |
+
score_dict = watermark_detector.detect(input_text)
|
318 |
+
# output = str_format_scores(score_dict, watermark_detector.z_threshold)
|
319 |
+
output = list_format_scores(score_dict, watermark_detector.z_threshold)
|
320 |
+
else:
|
321 |
+
# output = (f"Error: string not long enough to compute watermark presence.")
|
322 |
+
output = [["Error","string too short to compute metrics"]]
|
323 |
+
output += [["",""] for _ in range(6)]
|
324 |
+
return output, args
|
325 |
+
|
326 |
+
def run_gradio(args, model=None, device=None, tokenizer=None):
|
327 |
+
"""Define and launch the gradio demo interface"""
|
328 |
+
generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
|
329 |
+
detect_partial = partial(detect, device=device, tokenizer=tokenizer)
|
330 |
+
|
331 |
+
with gr.Blocks() as demo:
|
332 |
+
# Top section, greeting and instructions
|
333 |
+
with gr.Row():
|
334 |
+
with gr.Column(scale=9):
|
335 |
+
gr.Markdown(
|
336 |
+
"""
|
337 |
+
## 💧 [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) 🔍
|
338 |
+
"""
|
339 |
+
)
|
340 |
+
with gr.Column(scale=1):
|
341 |
+
gr.Markdown(
|
342 |
+
"""
|
343 |
+
[![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking)
|
344 |
+
"""
|
345 |
+
)
|
346 |
+
# with gr.Column(scale=2):
|
347 |
+
# pass
|
348 |
+
# ![visitor badge](https://visitor-badge.glitch.me/badge?page_id=tomg-group-umd_lm-watermarking) # buggy
|
349 |
+
|
350 |
+
with gr.Accordion("Understanding the output metrics",open=False):
|
351 |
+
gr.Markdown(
|
352 |
+
"""
|
353 |
+
- `z-score threshold` : The cuttoff for the hypothesis test
|
354 |
+
- `Tokens Counted (T)` : The number of tokens in the output that were counted by the detection algorithm.
|
355 |
+
The first token is ommitted in the simple, single token seeding scheme since there is no way to generate
|
356 |
+
a greenlist for it as it has no prefix token(s). Under the "Ignore Bigram Repeats" detection algorithm,
|
357 |
+
described in the bottom panel, this can be much less than the total number of tokens generated if there is a lot of repetition.
|
358 |
+
- `# Tokens in Greenlist` : The number of tokens that were observed to fall in their respective greenlist
|
359 |
+
- `Fraction of T in Greenlist` : The `# Tokens in Greenlist` / `T`. This is expected to be approximately `gamma` for human/unwatermarked text.
|
360 |
+
- `z-score` : The test statistic for the detection hypothesis test. If larger than the `z-score threshold`
|
361 |
+
we "reject the null hypothesis" that the text is human/unwatermarked, and conclude it is watermarked
|
362 |
+
- `p value` : The likelihood of observing the computed `z-score` under the null hypothesis. This is the likelihood of
|
363 |
+
observing the `Fraction of T in Greenlist` given that the text was generated without knowledge of the watermark procedure/greenlists.
|
364 |
+
If this is extremely _small_ we are confident that this many green tokens was not chosen by random chance.
|
365 |
+
- `prediction` : The outcome of the hypothesis test - whether the observed `z-score` was higher than the `z-score threshold`
|
366 |
+
- `confidence` : If we reject the null hypothesis, and the `prediction` is "Watermarked", then we report 1-`p value` to represent
|
367 |
+
the confidence of the detection based on the unlikeliness of this `z-score` observation.
|
368 |
+
"""
|
369 |
+
)
|
370 |
+
|
371 |
+
with gr.Accordion("A note on model capability",open=True):
|
372 |
+
gr.Markdown(
|
373 |
+
"""
|
374 |
+
This demo uses open-source language models that fit on a single GPU. These models are less powerful than proprietary commercial tools like ChatGPT, Claude, or Bard.
|
375 |
+
|
376 |
+
Importantly, we use a language model that is designed to "complete" your prompt, and not a model this is fine-tuned to follow instructions.
|
377 |
+
For best results, prompt the model with a few sentences that form the beginning of a paragraph, and then allow it to "continue" your paragraph.
|
378 |
+
Some examples include the opening paragraph of a wikipedia article, or the first few sentences of a story.
|
379 |
+
Longer prompts that end mid-sentence will result in more fluent generations.
|
380 |
+
"""
|
381 |
+
)
|
382 |
+
gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")
|
383 |
+
|
384 |
+
# Construct state for parameters, define updates and toggles
|
385 |
+
default_prompt = args.__dict__.pop("default_prompt")
|
386 |
+
session_args = gr.State(value=args)
|
387 |
+
|
388 |
+
with gr.Tab("Generate and Detect"):
|
389 |
+
|
390 |
+
with gr.Row():
|
391 |
+
prompt = gr.Textbox(label=f"Prompt", interactive=True,lines=10,max_lines=10, value=default_prompt)
|
392 |
+
with gr.Row():
|
393 |
+
generate_btn = gr.Button("Generate")
|
394 |
+
with gr.Row():
|
395 |
+
with gr.Column(scale=2):
|
396 |
+
output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False,lines=14,max_lines=14)
|
397 |
+
with gr.Column(scale=1):
|
398 |
+
# without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
|
399 |
+
without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
|
400 |
+
with gr.Row():
|
401 |
+
with gr.Column(scale=2):
|
402 |
+
output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False,lines=14,max_lines=14)
|
403 |
+
with gr.Column(scale=1):
|
404 |
+
# with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
|
405 |
+
with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2)
|
406 |
+
|
407 |
+
redecoded_input = gr.Textbox(visible=False)
|
408 |
+
truncation_warning = gr.Number(visible=False)
|
409 |
+
def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
|
410 |
+
if truncation_warning:
|
411 |
+
return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
|
412 |
+
else:
|
413 |
+
return orig_prompt, args
|
414 |
+
|
415 |
+
with gr.Tab("Detector Only"):
|
416 |
+
with gr.Row():
|
417 |
+
with gr.Column(scale=2):
|
418 |
+
detection_input = gr.Textbox(label="Text to Analyze", interactive=True,lines=14,max_lines=14)
|
419 |
+
with gr.Column(scale=1):
|
420 |
+
# detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
|
421 |
+
detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
|
422 |
+
with gr.Row():
|
423 |
+
detect_btn = gr.Button("Detect")
|
424 |
+
|
425 |
+
# Parameter selection group
|
426 |
+
with gr.Accordion("Advanced Settings",open=False):
|
427 |
+
with gr.Row():
|
428 |
+
with gr.Column(scale=1):
|
429 |
+
gr.Markdown(f"#### Generation Parameters")
|
430 |
+
with gr.Row():
|
431 |
+
decoding = gr.Radio(label="Decoding Method",choices=["multinomial", "greedy"], value=("multinomial" if args.use_sampling else "greedy"))
|
432 |
+
with gr.Row():
|
433 |
+
sampling_temp = gr.Slider(label="Sampling Temperature", minimum=0.1, maximum=1.0, step=0.1, value=args.sampling_temp, visible=True)
|
434 |
+
with gr.Row():
|
435 |
+
generation_seed = gr.Number(label="Generation Seed",value=args.generation_seed, interactive=True)
|
436 |
+
with gr.Row():
|
437 |
+
n_beams = gr.Dropdown(label="Number of Beams",choices=list(range(1,11,1)), value=args.n_beams, visible=(not args.use_sampling))
|
438 |
+
with gr.Row():
|
439 |
+
max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10, value=args.max_new_tokens)
|
440 |
+
|
441 |
+
with gr.Column(scale=1):
|
442 |
+
gr.Markdown(f"#### Watermark Parameters")
|
443 |
+
with gr.Row():
|
444 |
+
gamma = gr.Slider(label="gamma",minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
|
445 |
+
with gr.Row():
|
446 |
+
delta = gr.Slider(label="delta",minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
|
447 |
+
gr.Markdown(f"#### Detector Parameters")
|
448 |
+
with gr.Row():
|
449 |
+
detection_z_threshold = gr.Slider(label="z-score threshold",minimum=0.0, maximum=10.0, step=0.1, value=args.detection_z_threshold)
|
450 |
+
with gr.Row():
|
451 |
+
ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats")
|
452 |
+
with gr.Row():
|
453 |
+
normalizers = gr.CheckboxGroup(label="Normalizations", choices=["unicode", "homoglyphs", "truecase"], value=args.normalizers)
|
454 |
+
# with gr.Accordion("Actual submitted parameters:",open=False):
|
455 |
+
with gr.Row():
|
456 |
+
gr.Markdown(f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help. Window below shows the current settings._")
|
457 |
+
with gr.Row():
|
458 |
+
current_parameters = gr.Textbox(label="Current Parameters", value=args)
|
459 |
+
with gr.Accordion("Legacy Settings",open=False):
|
460 |
+
with gr.Row():
|
461 |
+
with gr.Column(scale=1):
|
462 |
+
seed_separately = gr.Checkbox(label="Seed both generations separately", value=args.seed_separately)
|
463 |
+
with gr.Column(scale=1):
|
464 |
+
select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition", value=args.select_green_tokens)
|
465 |
+
|
466 |
+
with gr.Accordion("Understanding the settings",open=False):
|
467 |
+
gr.Markdown(
|
468 |
+
"""
|
469 |
+
#### Generation Parameters:
|
470 |
+
|
471 |
+
- Decoding Method : We can generate tokens from the model using either multinomial sampling or we can generate using greedy decoding.
|
472 |
+
- Sampling Temperature : If using multinomial sampling we can set the temperature of the sampling distribution.
|
473 |
+
0.0 is equivalent to greedy decoding, and 1.0 is the maximum amount of variability/entropy in the next token distribution.
|
474 |
+
0.7 strikes a nice balance between faithfulness to the model's estimate of top candidates while adding variety. Does not apply for greedy decoding.
|
475 |
+
- Generation Seed : The integer to pass to the torch random number generator before running generation. Makes the multinomial sampling strategy
|
476 |
+
outputs reproducible. Does not apply for greedy decoding.
|
477 |
+
- Number of Beams : When using greedy decoding, we can also set the number of beams to > 1 to enable beam search.
|
478 |
+
This is not implemented/excluded from paper for multinomial sampling but may be added in future.
|
479 |
+
- Max Generated Tokens : The `max_new_tokens` parameter passed to the generation method to stop the output at a certain number of new tokens.
|
480 |
+
Note that the model is free to generate fewer tokens depending on the prompt.
|
481 |
+
Implicitly this sets the maximum number of prompt tokens possible as the model's maximum input length minus `max_new_tokens`,
|
482 |
+
and inputs will be truncated accordingly.
|
483 |
+
|
484 |
+
#### Watermark Parameters:
|
485 |
+
|
486 |
+
- gamma : The fraction of the vocabulary to be partitioned into the greenlist at each generation step.
|
487 |
+
Smaller gamma values create a stronger watermark by enabling the watermarked model to achieve
|
488 |
+
a greater differentiation from human/unwatermarked text because it is preferentially sampling
|
489 |
+
from a smaller green set making those tokens less likely to occur by chance.
|
490 |
+
- delta : The amount of positive bias to add to the logits of every token in the greenlist
|
491 |
+
at each generation step before sampling/choosing the next token. Higher delta values
|
492 |
+
mean that the greenlist tokens are more heavily preferred by the watermarked model
|
493 |
+
and as the bias becomes very large the watermark transitions from "soft" to "hard".
|
494 |
+
For a hard watermark, nearly all tokens are green, but this can have a detrimental effect on
|
495 |
+
generation quality, especially when there is not a lot of flexibility in the distribution.
|
496 |
+
|
497 |
+
#### Detector Parameters:
|
498 |
+
|
499 |
+
- z-score threshold : the z-score cuttoff for the hypothesis test. Higher thresholds (such as 4.0) make
|
500 |
+
_false positives_ (predicting that human/unwatermarked text is watermarked) very unlikely
|
501 |
+
as a genuine human text with a significant number of tokens will almost never achieve
|
502 |
+
that high of a z-score. Lower thresholds will capture more _true positives_ as some watermarked
|
503 |
+
texts will contain less green tokens and achive a lower z-score, but still pass the lower bar and
|
504 |
+
be flagged as "watermarked". However, a lowere threshold will increase the chance that human text
|
505 |
+
that contains a slightly higher than average number of green tokens is erroneously flagged.
|
506 |
+
4.0-5.0 offers extremely low false positive rates while still accurately catching most watermarked text.
|
507 |
+
- Ignore Bigram Repeats : This alternate detection algorithm only considers the unique bigrams in the text during detection,
|
508 |
+
computing the greenlists based on the first in each pair and checking whether the second falls within the list.
|
509 |
+
This means that `T` is now the unique number of bigrams in the text, which becomes less than the total
|
510 |
+
number of tokens generated if the text contains a lot of repetition. See the paper for a more detailed discussion.
|
511 |
+
- Normalizations : we implement a few basic normaliations to defend against various adversarial perturbations of the
|
512 |
+
text analyzed during detection. Currently we support converting all chracters to unicode,
|
513 |
+
replacing homoglyphs with a canonical form, and standardizing the capitalization.
|
514 |
+
See the paper for a detailed discussion of input normalization.
|
515 |
+
"""
|
516 |
+
)
|
517 |
+
|
518 |
+
gr.HTML("""
|
519 |
+
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
|
520 |
+
Follow the github link at the top and host the demo on your own GPU hardware to test out larger models.
|
521 |
+
<br/>
|
522 |
+
<a href="https://huggingface.co/spaces/tomg-group-umd/lm-watermarking?duplicate=true">
|
523 |
+
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
524 |
+
<p/>
|
525 |
+
""")
|
526 |
+
|
527 |
+
# Register main generation tab click, outputing generations as well as a the encoded+redecoded+potentially truncated prompt and flag
|
528 |
+
generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
|
529 |
+
# Show truncated version of prompt if truncation occurred
|
530 |
+
redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
|
531 |
+
# Call detection when the outputs (of the generate function) are updated
|
532 |
+
output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
|
533 |
+
output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
|
534 |
+
# Register main detection tab click
|
535 |
+
detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
|
536 |
+
|
537 |
+
# State management logic
|
538 |
+
# update callbacks that change the state dict
|
539 |
+
def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
|
540 |
+
def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
|
541 |
+
def update_gamma(session_state, value): session_state.gamma = float(value); return session_state
|
542 |
+
def update_delta(session_state, value): session_state.delta = float(value); return session_state
|
543 |
+
def update_detection_z_threshold(session_state, value): session_state.detection_z_threshold = float(value); return session_state
|
544 |
+
def update_decoding(session_state, value):
|
545 |
+
if value == "multinomial":
|
546 |
+
session_state.use_sampling = True
|
547 |
+
elif value == "greedy":
|
548 |
+
session_state.use_sampling = False
|
549 |
+
return session_state
|
550 |
+
def toggle_sampling_vis(value):
|
551 |
+
if value == "multinomial":
|
552 |
+
return gr.update(visible=True)
|
553 |
+
elif value == "greedy":
|
554 |
+
return gr.update(visible=False)
|
555 |
+
def toggle_sampling_vis_inv(value):
|
556 |
+
if value == "multinomial":
|
557 |
+
return gr.update(visible=False)
|
558 |
+
elif value == "greedy":
|
559 |
+
return gr.update(visible=True)
|
560 |
+
def update_n_beams(session_state, value): session_state.n_beams = value; return session_state
|
561 |
+
def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state
|
562 |
+
def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state
|
563 |
+
def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
|
564 |
+
def update_seed_separately(session_state, value): session_state.seed_separately = value; return session_state
|
565 |
+
def update_select_green_tokens(session_state, value): session_state.select_green_tokens = value; return session_state
|
566 |
+
# registering callbacks for toggling the visibilty of certain parameters
|
567 |
+
decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp])
|
568 |
+
decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed])
|
569 |
+
decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams])
|
570 |
+
# registering all state update callbacks
|
571 |
+
decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args])
|
572 |
+
sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
|
573 |
+
generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
|
574 |
+
n_beams.change(update_n_beams,inputs=[session_args, n_beams], outputs=[session_args])
|
575 |
+
max_new_tokens.change(update_max_new_tokens,inputs=[session_args, max_new_tokens], outputs=[session_args])
|
576 |
+
gamma.change(update_gamma,inputs=[session_args, gamma], outputs=[session_args])
|
577 |
+
delta.change(update_delta,inputs=[session_args, delta], outputs=[session_args])
|
578 |
+
detection_z_threshold.change(update_detection_z_threshold,inputs=[session_args, detection_z_threshold], outputs=[session_args])
|
579 |
+
ignore_repeated_bigrams.change(update_ignore_repeated_bigrams,inputs=[session_args, ignore_repeated_bigrams], outputs=[session_args])
|
580 |
+
normalizers.change(update_normalizers,inputs=[session_args, normalizers], outputs=[session_args])
|
581 |
+
seed_separately.change(update_seed_separately,inputs=[session_args, seed_separately], outputs=[session_args])
|
582 |
+
select_green_tokens.change(update_select_green_tokens,inputs=[session_args, select_green_tokens], outputs=[session_args])
|
583 |
+
# register additional callback on button clicks that updates the shown parameters window
|
584 |
+
generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
585 |
+
detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
586 |
+
# When the parameters change, display the update and fire detection, since some detection params dont change the model output.
|
587 |
+
gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
588 |
+
gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
|
589 |
+
gamma.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
|
590 |
+
gamma.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
|
591 |
+
detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
592 |
+
detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
|
593 |
+
detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
|
594 |
+
detection_z_threshold.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
|
595 |
+
ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
596 |
+
ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
|
597 |
+
ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
|
598 |
+
ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
|
599 |
+
normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
600 |
+
normalizers.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
|
601 |
+
normalizers.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
|
602 |
+
normalizers.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
|
603 |
+
select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
604 |
+
select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
|
605 |
+
select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
|
606 |
+
select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
|
607 |
+
|
608 |
+
|
609 |
+
demo.queue(concurrency_count=3)
|
610 |
+
|
611 |
+
if args.demo_public:
|
612 |
+
demo.launch(share=True) # exposes app to the internet via randomly generated link
|
613 |
+
else:
|
614 |
+
demo.launch()
|
615 |
+
|
616 |
+
def main(args):
|
617 |
+
"""Run a command line version of the generation and detection operations
|
618 |
+
and optionally launch and serve the gradio demo"""
|
619 |
+
# Initial arg processing and log
|
620 |
+
args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
|
621 |
+
print(args)
|
622 |
+
|
623 |
+
if not args.skip_model_load:
|
624 |
+
model, tokenizer, device = load_model(args)
|
625 |
+
else:
|
626 |
+
model, tokenizer, device = None, None, None
|
627 |
+
|
628 |
+
# Generate and detect, report to stdout
|
629 |
+
if not args.skip_model_load:
|
630 |
+
input_text = (
|
631 |
+
"The diamondback terrapin or simply terrapin (Malaclemys terrapin) is a "
|
632 |
+
"species of turtle native to the brackish coastal tidal marshes of the "
|
633 |
+
"Northeastern and southern United States, and in Bermuda.[6] It belongs "
|
634 |
+
"to the monotypic genus Malaclemys. It has one of the largest ranges of "
|
635 |
+
"all turtles in North America, stretching as far south as the Florida Keys "
|
636 |
+
"and as far north as Cape Cod.[7] The name 'terrapin' is derived from the "
|
637 |
+
"Algonquian word torope.[8] It applies to Malaclemys terrapin in both "
|
638 |
+
"British English and American English. The name originally was used by "
|
639 |
+
"early European settlers in North America to describe these brackish-water "
|
640 |
+
"turtles that inhabited neither freshwater habitats nor the sea. It retains "
|
641 |
+
"this primary meaning in American English.[8] In British English, however, "
|
642 |
+
"other semi-aquatic turtle species, such as the red-eared slider, might "
|
643 |
+
"also be called terrapins. The common name refers to the diamond pattern "
|
644 |
+
"on top of its shell (carapace), but the overall pattern and coloration "
|
645 |
+
"vary greatly. The shell is usually wider at the back than in the front, "
|
646 |
+
"and from above it appears wedge-shaped. The shell coloring can vary "
|
647 |
+
"from brown to grey, and its body color can be grey, brown, yellow, "
|
648 |
+
"or white. All have a unique pattern of wiggly, black markings or spots "
|
649 |
+
"on their body and head. The diamondback terrapin has large webbed "
|
650 |
+
"feet.[9] The species is"
|
651 |
+
)
|
652 |
+
|
653 |
+
args.default_prompt = input_text
|
654 |
+
|
655 |
+
term_width = 80
|
656 |
+
print("#"*term_width)
|
657 |
+
print("Prompt:")
|
658 |
+
print(input_text)
|
659 |
+
|
660 |
+
_, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(input_text,
|
661 |
+
args,
|
662 |
+
model=model,
|
663 |
+
device=device,
|
664 |
+
tokenizer=tokenizer)
|
665 |
+
without_watermark_detection_result = detect(decoded_output_without_watermark,
|
666 |
+
args,
|
667 |
+
device=device,
|
668 |
+
tokenizer=tokenizer)
|
669 |
+
with_watermark_detection_result = detect(decoded_output_with_watermark,
|
670 |
+
args,
|
671 |
+
device=device,
|
672 |
+
tokenizer=tokenizer)
|
673 |
+
|
674 |
+
print("#"*term_width)
|
675 |
+
print("Output without watermark:")
|
676 |
+
print(decoded_output_without_watermark)
|
677 |
+
print("-"*term_width)
|
678 |
+
print(f"Detection result @ {args.detection_z_threshold}:")
|
679 |
+
pprint(without_watermark_detection_result)
|
680 |
+
print("-"*term_width)
|
681 |
+
|
682 |
+
print("#"*term_width)
|
683 |
+
print("Output with watermark:")
|
684 |
+
print(decoded_output_with_watermark)
|
685 |
+
print("-"*term_width)
|
686 |
+
print(f"Detection result @ {args.detection_z_threshold}:")
|
687 |
+
pprint(with_watermark_detection_result)
|
688 |
+
print("-"*term_width)
|
689 |
+
|
690 |
+
|
691 |
+
# Launch the app to generate and detect interactively (implements the hf space demo)
|
692 |
+
if args.run_gradio:
|
693 |
+
run_gradio(args, model=model, tokenizer=tokenizer, device=device)
|
694 |
+
|
695 |
+
return
|
696 |
+
|
697 |
+
if __name__ == "__main__":
|
698 |
+
|
699 |
+
args = parse_args()
|
700 |
+
print(args)
|
701 |
+
|
702 |
+
main(args)
|
lm-watermarking-main/experiments/README.md
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## [Stale/Deprecated] Experimental Pipeline Code
|
2 |
+
|
3 |
+
This subdirectory contains reproducibility artifacts for the experiments described in the paper. All code here is deprecated in favor of the implementation and demo in the root of the repository.
|
4 |
+
|
5 |
+
In effect, the file `/watermark_processor.py` in the root of the repo, is a clean, user friendly reimplementation of the watermarking and detection logic from `watermark.py`. We suggest using the official release version over any code found in the `experiments` directory.
|
6 |
+
|
7 |
+
## Overview
|
8 |
+
|
9 |
+
Unless stated, all files discussed here are in the `experiments` directory. The `bl` naming convention across many variables and function definition refers to "blacklist". Black/white was the original language used in the development of the paper and was updated to green/red based on feed back from the community.
|
10 |
+
|
11 |
+
The implementation for the main experiments in the paper have two high level steps:
|
12 |
+
- **(1) generate watermarked samples**
|
13 |
+
- **(2) compute metrics**
|
14 |
+
|
15 |
+
The code provided here implements these steps in the following files: `run_watermarking.py` and `process_rows.py`, where the core logic is implemented in `watermark.py` a single file library.
|
16 |
+
|
17 |
+
Generally speaking, the code implementing the watermark itself is a series of classes and functions based on the `LogitsProcessor` abstraction from [huggingface/transformers](https://github.com/huggingface/transformers) and the code that turns it into a workflow is based on the `dataset.map` functionality from [huggingface/datasets](https://github.com/huggingface/datasets).
|
18 |
+
|
19 |
+
The files `io_utils.py`, `submitit_utils.py` and `launch.py` contain utilites for file operations (mostly `jsonl`) and for hyperparameter sweeping via jobs launched on our compute cluster (managed using [SLURM](https://slurm.schedmd.com/documentation.html)). The [`submitit`](https://github.com/facebookincubator/submitit) workflow tool is an extra dependency only required if using `launch.py`.
|
20 |
+
|
21 |
+
## Generation (`run_watermarking.py`)
|
22 |
+
|
23 |
+
`run_watermarking.py` is a command line script that:
|
24 |
+
|
25 |
+
1. loads a huggingface `dataset` that will be used to create text prompts for the language model
|
26 |
+
2. loads a huggingface language model that can perform text generation via `model.generate`, and prepares to call the generation method with a special `LogitsProcessor` that implements watermarking at the current hyperparameter values
|
27 |
+
3. composes a series of functions that are applied to the dataset via `map` that preprocess and tokenize the prompt data, and generate completions to it via the model
|
28 |
+
4. loads a second huggingface language model to be used as perplexity "oracle" for evaluating the quality of the texts generated by the watermarked model
|
29 |
+
5. Computes the teacher-forced loss (and perplexity) of the oracle model on the generated outputs
|
30 |
+
|
31 |
+
Here is an example of the argument set required to implement a single (representative) hyperparameter combination from the paper:
|
32 |
+
|
33 |
+
```
|
34 |
+
python run_watermarking.py \
|
35 |
+
--model_name facebook/opt-1.3b \
|
36 |
+
--dataset_name c4 \
|
37 |
+
--dataset_config_name realnewslike \
|
38 |
+
--max_new_tokens 200 \
|
39 |
+
--min_prompt_tokens 50 \
|
40 |
+
--limit_indices 500 \
|
41 |
+
--input_truncation_strategy completion_length \
|
42 |
+
--input_filtering_strategy prompt_and_completion_length \
|
43 |
+
--output_filtering_strategy max_new_tokens \
|
44 |
+
--dynamic_seed markov_1 \
|
45 |
+
--bl_proportion 0.5 \
|
46 |
+
--bl_logit_bias 2.0 \
|
47 |
+
--bl_type soft \
|
48 |
+
--store_spike_ents True \
|
49 |
+
--num_beams 1 \
|
50 |
+
--use_sampling True \
|
51 |
+
--sampling_temp 0.7
|
52 |
+
--oracle_model_name facebook/opt-2.7b \
|
53 |
+
--run_name example_run \
|
54 |
+
--output_dir ./all_runs \
|
55 |
+
```
|
56 |
+
|
57 |
+
The result of each run is a directory with three files in it:
|
58 |
+
- `gen_table_meta.json` (hyperparameters passed from cmdline)
|
59 |
+
- `gen_table.jsonl`
|
60 |
+
- `gen_table_w_metrics.jsonl`
|
61 |
+
|
62 |
+
`gen_table_w_metrics`="generation table with metrics" meaning that it is the same as the first `jsonl` file in the lines/row dimension, but contains more columns/features, such as perplexity.
|
63 |
+
|
64 |
+
If you run multiple hyperparameter combinations, we suggest storing each of the run directories with those output files within one enclosing directory such as `all_runs` to facilitate the next step.
|
65 |
+
|
66 |
+
## Computing Metrics (`process_rows.py`)
|
67 |
+
.. and merging hyperparameter runs by concatenation.
|
68 |
+
|
69 |
+
After running a few combinations of hyperparameters (individual runs of the `run_watermarking.py` script), the result is a bunch of directories, each containing a file full of model outputs (`gen_table_w_metrics.jsonl`).
|
70 |
+
|
71 |
+
To prepare to analyze the performance of the watermark, we enrich each one of these generation sets with more metrics and derived features. The script that accomplishes this is `process_rows.py` - each prompt, output pair is considered a "row".
|
72 |
+
|
73 |
+
The script isn't fully command line parameterized, but inside you can see that the main method looks into a directory (such as the `all_runs` suggested above) and collects all of the sub dirs that contain `gen_table_w_metrics.jsonl` files. Each set of generations is reloaded from `jsonl` into a huggingface `Dataset` object so that a metric computation function `compute_bl_metrics` can be applied to it.
|
74 |
+
|
75 |
+
This adds the critical fields like `w_bl_whitelist_fraction` which represent the raw measurement of the watermark presence. In the final analysis step, this is used compute a z-score and perform the detection hypothesis test.
|
76 |
+
|
77 |
+
**_Note_**: to clarify explicitly, `compute_bl_metrics` is therefore the old "detection" step of the pipeline. In this earlier version, there was no dedicated sub/class structure to share the logic of the watermark between a generation object and a detector object. It was just located within the `score_sequence` function of the `watermark.py` file.
|
78 |
+
|
79 |
+
The final step in `process_rows.py` is a concatenation of these results. Each `gen_table_w_metrics.jsonl` from a hyperparameter run (within an `all_runs`) is transformed into a new dataset with the watermark detection measurement, and then all of these dataset objects are concatenated in the row dimension, forming one large dataset that has the generations and metrics from all of the different hyperparameter settings that were run.
|
80 |
+
|
81 |
+
This object is shaped like (rows,columns) where samples=rows, and features=columns, and for the paper it had a size ~ (3e4,25) since there were about 30 to 40 hyperparameter settings and between 500-1000 generations per setting. Huggingface datasets conveniently implements a `dataset.to_pandas()` function and this allows us to treat this result as a dataframe and slice and dice it however we like during the analysis phase.
|
82 |
+
|
83 |
+
|
84 |
+
## Analysis
|
85 |
+
|
86 |
+
The result of the above steps is a somewhat standard "datascience" format, a `pandas.DataFrame` and we suggest that you analyze it in whatever way you see fit. Since this part was very interactive and exploratory, there isn't a stable script version of this stage.
|
87 |
+
|
88 |
+
That said, the analysis code is in a notebook called `watermarking_analysis.ipynb`. Unfortunately, this notebook is monolithic. Pointers have been indicated as to which parts produce which figures. However, at this time, there is not a way to click once/run all and generate every chart and table from the paper.
|
89 |
+
|
90 |
+
A second notebook `watermarking_example_finding.ipynb` is solely for extracting some actual text prompts and outputs for tabulation in the paper.
|
91 |
+
|
lm-watermarking-main/experiments/io_utils.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
from typing import Any, Mapping, Iterable, Union, List, Callable, Optional
|
6 |
+
|
7 |
+
from tqdm.auto import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
def resolve_globs(glob_paths: Union[str, Iterable[str]]):
|
11 |
+
"""Returns filepaths corresponding to input filepath pattern(s)."""
|
12 |
+
filepaths = []
|
13 |
+
if isinstance(glob_paths, str):
|
14 |
+
glob_paths = [glob_paths]
|
15 |
+
|
16 |
+
for path in glob_paths:
|
17 |
+
filepaths.extend(glob.glob(path))
|
18 |
+
|
19 |
+
return filepaths
|
20 |
+
|
21 |
+
|
22 |
+
def read_jsonlines(filename: str) -> Iterable[Mapping[str, Any]]:
|
23 |
+
"""Yields an iterable of Python dicts after reading jsonlines from the input file."""
|
24 |
+
file_size = os.path.getsize(filename)
|
25 |
+
with open(filename) as fp:
|
26 |
+
for line in tqdm(fp.readlines(), desc=f'Reading JSON lines from {filename}', unit='lines'):
|
27 |
+
try:
|
28 |
+
example = json.loads(line)
|
29 |
+
yield example
|
30 |
+
except json.JSONDecodeError as ex:
|
31 |
+
logging.error(f'Input text: "{line}"')
|
32 |
+
logging.error(ex.args)
|
33 |
+
raise ex
|
34 |
+
|
35 |
+
|
36 |
+
def hf_read_jsonlines(filename: str,
|
37 |
+
n: Optional[int]=None,
|
38 |
+
minimal_questions: Optional[bool]=False,
|
39 |
+
unique_questions: Optional[bool] = False) -> Iterable[Mapping[str, Any]]:
|
40 |
+
"""Yields an iterable of Python dicts after reading jsonlines from the input file.
|
41 |
+
Optionally reads only first n lines from file."""
|
42 |
+
file_size = os.path.getsize(filename)
|
43 |
+
# O(n) but no memory
|
44 |
+
with open(filename) as f:
|
45 |
+
num_lines= sum(1 for _ in f)
|
46 |
+
if n is None:
|
47 |
+
n = num_lines
|
48 |
+
|
49 |
+
# returning a generator with the scope stmt seemed to be the issue, but I am not 100% sure
|
50 |
+
# I also don't know if there's a side effect, but I can't see how the scope wouldn't have
|
51 |
+
# remained upen in the first place with the original version...
|
52 |
+
# with open(filename) as fp:
|
53 |
+
def line_generator():
|
54 |
+
unique_qc_ids = set()
|
55 |
+
# note, I am p sure that readlines is not lazy, returns a list, thus really only the
|
56 |
+
# object conversion is lazy
|
57 |
+
for i, line in tqdm(enumerate(open(filename).readlines()[:n]), desc=f'Reading JSON lines from {filename}', unit='lines'):
|
58 |
+
try:
|
59 |
+
full_example = json.loads(line)
|
60 |
+
|
61 |
+
if unique_questions:
|
62 |
+
qc_id = full_example["object"]["qc_id"]
|
63 |
+
if qc_id in unique_qc_ids:
|
64 |
+
continue
|
65 |
+
else:
|
66 |
+
unique_qc_ids.add(qc_id)
|
67 |
+
|
68 |
+
if not minimal_questions:
|
69 |
+
example = full_example
|
70 |
+
else:
|
71 |
+
full_example = full_example
|
72 |
+
q_object = full_example["object"]
|
73 |
+
q_object.pop("question_info")
|
74 |
+
example= {}
|
75 |
+
example["object"] = {
|
76 |
+
"answer":q_object["answer"],
|
77 |
+
"clue_spans":q_object["clue_spans"],
|
78 |
+
"qc_id":q_object["qc_id"],
|
79 |
+
"question_text":q_object["question_text"],
|
80 |
+
}
|
81 |
+
yield example
|
82 |
+
|
83 |
+
except json.JSONDecodeError as ex:
|
84 |
+
logging.error(f'Input text: "{line}"')
|
85 |
+
logging.error(ex.args)
|
86 |
+
raise ex
|
87 |
+
return line_generator
|
88 |
+
|
89 |
+
|
90 |
+
def load_jsonlines(filename: str) -> List[Mapping[str, Any]]:
|
91 |
+
"""Returns a list of Python dicts after reading jsonlines from the input file."""
|
92 |
+
return list(read_jsonlines(filename))
|
93 |
+
|
94 |
+
|
95 |
+
def write_jsonlines(objs: Iterable[Mapping[str, Any]], filename: str, to_dict: Callable = lambda x: x):
|
96 |
+
"""Writes a list of Python Mappings as jsonlines at the input file."""
|
97 |
+
with open(filename, 'w') as fp:
|
98 |
+
for obj in tqdm(objs, desc=f'Writing JSON lines at {filename}'):
|
99 |
+
fp.write(json.dumps(to_dict(obj)))
|
100 |
+
fp.write('\n')
|
101 |
+
|
102 |
+
|
103 |
+
def read_json(filename: str) -> Mapping[str, Any]:
|
104 |
+
"""Returns a Python dict representation of JSON object at input file."""
|
105 |
+
with open(filename) as fp:
|
106 |
+
return json.load(fp)
|
107 |
+
|
108 |
+
|
109 |
+
def write_json(obj: Mapping[str, Any], filename: str, indent:int=None):
|
110 |
+
"""Writes a Python Mapping at the input file in JSON format."""
|
111 |
+
with open(filename, 'w') as fp:
|
112 |
+
json.dump(obj, fp, indent=indent)
|
113 |
+
|
114 |
+
|
115 |
+
def print_json(d, indent=4):
|
116 |
+
print(json.dumps(d, indent=indent))
|
lm-watermarking-main/experiments/launch.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from submitit import AutoExecutor
|
2 |
+
from submitit.helpers import CommandFunction
|
3 |
+
from itertools import chain
|
4 |
+
import os
|
5 |
+
from submitit_utils import ParameterGrid
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
# a debug/dry-run command
|
9 |
+
dummy_func = CommandFunction(["echo"], verbose=True)
|
10 |
+
|
11 |
+
###############################################################################
|
12 |
+
# Experiment specific command and parameter setup
|
13 |
+
# (the structure is general, but the values are not)
|
14 |
+
###############################################################################
|
15 |
+
|
16 |
+
base_run_name = None
|
17 |
+
|
18 |
+
ROOT_DIR = f'{os.getenv("ROOT_DIR")}'
|
19 |
+
# OUTPUT_DIR = f'{os.getenv("OUTPUT_DIR")}'
|
20 |
+
# OUTPUT_DIR = f'{os.getenv("OUTPUT_DIR")}_large_sweep'
|
21 |
+
# OUTPUT_DIR = f'{os.getenv("OUTPUT_DIR")}_large_sweep_downsize'
|
22 |
+
# OUTPUT_DIR = f'{os.getenv("OUTPUT_DIR")}_greedy_redo'
|
23 |
+
OUTPUT_DIR = f'{os.getenv("OUTPUT_DIR")}_greedy_more_gammas'
|
24 |
+
|
25 |
+
# starting command/program to which we will append arguments
|
26 |
+
cmdline_function = CommandFunction(["python"], verbose=True)
|
27 |
+
|
28 |
+
# script name
|
29 |
+
script_name = "run_watermarking.py"
|
30 |
+
|
31 |
+
# base args
|
32 |
+
base_script_args = {
|
33 |
+
# "model_name" :"facebook/opt-2.7b",
|
34 |
+
"model_name" :"facebook/opt-1.3b",
|
35 |
+
"dataset_name" :"c4",
|
36 |
+
"dataset_config_name":"realnewslike",
|
37 |
+
# "dataset_config_name":"en",
|
38 |
+
# "dataset_name": "cml_pile",
|
39 |
+
# "dataset_config_name": "all_train_00",
|
40 |
+
# "shuffle_dataset" :"True", # NOTE
|
41 |
+
"dynamic_seed" :"markov_1",
|
42 |
+
"store_spike_ents" :"True",
|
43 |
+
# "oracle_model_name" :"EleutherAI/gpt-j-6B",
|
44 |
+
"oracle_model_name" :"facebook/opt-2.7b",
|
45 |
+
"no_wandb" :"False",
|
46 |
+
}
|
47 |
+
|
48 |
+
# dynamic/hparam args
|
49 |
+
# i.e. the parameters we would like to cross and sweep over
|
50 |
+
hparam_sets = [
|
51 |
+
# # main sampling sweep, central data
|
52 |
+
# {
|
53 |
+
# "min_prompt_tokens": [50],
|
54 |
+
# "max_new_tokens": [200],
|
55 |
+
# "input_truncation_strategy": ["completion_length"],
|
56 |
+
# "input_filtering_strategy": ["prompt_and_completion_length"],
|
57 |
+
# "output_filtering_strategy": ["max_new_tokens"],
|
58 |
+
# "limit_indices": [500],
|
59 |
+
# "bl_logit_bias": [0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0],
|
60 |
+
# "bl_proportion": [0.1, 0.25, 0.5, 0.75, 0.9],
|
61 |
+
# "bl_type": ["soft"],
|
62 |
+
# "num_beams": [1],
|
63 |
+
# "use_sampling": [True],
|
64 |
+
# "sampling_temp": [0.7],
|
65 |
+
# },
|
66 |
+
# greedy and beams secondary demos
|
67 |
+
# {
|
68 |
+
# "min_sample_tokens":[0],
|
69 |
+
# "min_prompt_tokens": [200],
|
70 |
+
# "max_new_tokens": [500],
|
71 |
+
# "all_gas_no_eos": [True],
|
72 |
+
# "input_truncation_strategy": ["prompt_length"],
|
73 |
+
# "input_filtering_strategy": ["prompt_and_completion_length"],
|
74 |
+
# "output_filtering_strategy": ["no_filter"],
|
75 |
+
# "limit_indices": [500],
|
76 |
+
# "bl_logit_bias": [0.1, 0.5, 1.0, 2.0, 5.0, 10.0],
|
77 |
+
# "bl_proportion": [0.5],
|
78 |
+
# "bl_type": ["soft"],
|
79 |
+
# "num_beams": [1],
|
80 |
+
# "use_sampling": [False],
|
81 |
+
# "sampling_temp": [0.0],
|
82 |
+
# },
|
83 |
+
# {
|
84 |
+
# "min_sample_tokens":[0],
|
85 |
+
# "min_prompt_tokens": [200],
|
86 |
+
# "max_new_tokens": [500],
|
87 |
+
# "all_gas_no_eos": [True],
|
88 |
+
# "no_repeat_ngram_size": [0],
|
89 |
+
# "input_truncation_strategy": ["prompt_length"],
|
90 |
+
# "input_filtering_strategy": ["prompt_and_completion_length"],
|
91 |
+
# "output_filtering_strategy": ["no_filter"],
|
92 |
+
# "limit_indices": [500],
|
93 |
+
# "bl_logit_bias": [0.1, 0.5, 1.0, 2.0, 5.0, 10.0],
|
94 |
+
# "bl_proportion": [0.5],
|
95 |
+
# "bl_type": ["soft"],
|
96 |
+
# "num_beams": [4],
|
97 |
+
# "use_sampling": [False],
|
98 |
+
# "sampling_temp": [0.0],
|
99 |
+
# },
|
100 |
+
{
|
101 |
+
"min_sample_tokens":[0],
|
102 |
+
"min_prompt_tokens": [200],
|
103 |
+
"max_new_tokens": [500],
|
104 |
+
"all_gas_no_eos": [True],
|
105 |
+
"no_repeat_ngram_size": [0],
|
106 |
+
"input_truncation_strategy": ["prompt_length"],
|
107 |
+
"input_filtering_strategy": ["prompt_and_completion_length"],
|
108 |
+
"output_filtering_strategy": ["no_filter"],
|
109 |
+
"limit_indices": [500],
|
110 |
+
"bl_logit_bias": [0.1, 0.5, 1.0, 2.0, 5.0, 10.0],
|
111 |
+
# "bl_logit_bias": [2.0, 5.0, 10.0],
|
112 |
+
# "bl_proportion": [0.5],
|
113 |
+
# "bl_proportion": [0.75],
|
114 |
+
"bl_proportion": [0.9],
|
115 |
+
"bl_type": ["soft"],
|
116 |
+
"num_beams": [8],
|
117 |
+
"use_sampling": [False],
|
118 |
+
"sampling_temp": [0.0],
|
119 |
+
},
|
120 |
+
############
|
121 |
+
]
|
122 |
+
|
123 |
+
# logic to set derived arguments based on existing arguments in the sweep sets
|
124 |
+
# the unique run name is the canonical example
|
125 |
+
def add_conditional_params(param_dict):
|
126 |
+
|
127 |
+
# unique_name = f'{base_run_name+"_" if base_run_name else ""}{param_dict.get("model_name")}_{param_dict.get("dataset_name")}_{param_dict.get("dataset_config_name")}'
|
128 |
+
unique_name_keys = ["model_name",
|
129 |
+
"bl_type",
|
130 |
+
"dynamic_seed",
|
131 |
+
"bl_proportion",
|
132 |
+
"bl_logit_bias",
|
133 |
+
"num_beams",
|
134 |
+
"use_sampling",
|
135 |
+
"sampling_temp",
|
136 |
+
"dataset_name",
|
137 |
+
"dataset_config_name",
|
138 |
+
"min_prompt_tokens",
|
139 |
+
"max_new_tokens",
|
140 |
+
"input_truncation_strategy",
|
141 |
+
"input_filtering_strategy",
|
142 |
+
"output_filtering_strategy",
|
143 |
+
"limit_indices",
|
144 |
+
"oracle_model_name"]
|
145 |
+
|
146 |
+
unique_name = f'{base_run_name+"_" if base_run_name else ""}{"_".join([str(param_dict.get(k)) for k in unique_name_keys])}'
|
147 |
+
unique_name = unique_name.replace("/", "-").replace(".","-")
|
148 |
+
param_dict.update({"run_name": unique_name})
|
149 |
+
param_dict.update({"output_dir": f'{OUTPUT_DIR}/{param_dict["run_name"]}'})
|
150 |
+
|
151 |
+
# Queue up all the arguments
|
152 |
+
def add_params(param_dicts):
|
153 |
+
new_dicts = []
|
154 |
+
for i, param_dict in enumerate(param_dicts):
|
155 |
+
new_dict = {}
|
156 |
+
|
157 |
+
new_dict.update({script_name : ""}) # This requires parse block change in submitit.core.utils.py L320
|
158 |
+
new_dict.update(base_script_args)
|
159 |
+
|
160 |
+
new_dict.update(param_dict)
|
161 |
+
add_conditional_params(new_dict)
|
162 |
+
|
163 |
+
new_dicts.append(new_dict)
|
164 |
+
return new_dicts
|
165 |
+
|
166 |
+
###############################################################################
|
167 |
+
# Generic submitit and slurm workflow
|
168 |
+
###############################################################################
|
169 |
+
|
170 |
+
# set up the executor and sbatch settings
|
171 |
+
# executor = AutoExecutor(cluster='slurm', folder=f'{ROOT_DIR}/logs/')
|
172 |
+
# executor = AutoExecutor(cluster='slurm', folder=f'{ROOT_DIR}/logs_large_sweep/')
|
173 |
+
# executor = AutoExecutor(cluster='slurm', folder=f'{ROOT_DIR}/logs_large_sweep_downsize/')
|
174 |
+
# executor = AutoExecutor(cluster='slurm', folder=f'{ROOT_DIR}/logs_greedy_redo/')
|
175 |
+
executor = AutoExecutor(cluster='slurm', folder=f'{ROOT_DIR}/logs_greedy_more_gammas/')
|
176 |
+
|
177 |
+
executor.update_parameters(
|
178 |
+
stderr_to_stdout=True,
|
179 |
+
slurm_name='water',
|
180 |
+
# slurm_account='tomg',
|
181 |
+
# slurm_qos='very_high',
|
182 |
+
# slurm_qos='high',
|
183 |
+
slurm_mem= '52gb',
|
184 |
+
slurm_gres='gpu:rtxa6000:1',
|
185 |
+
slurm_time='14:00:00',
|
186 |
+
slurm_account='scavenger',
|
187 |
+
slurm_partition='scavenger',
|
188 |
+
slurm_qos='scavenger',
|
189 |
+
# slurm_mem= '32gb',
|
190 |
+
# slurm_cpus_per_task=4,
|
191 |
+
# slurm_gres='gpu:rtxa5000:1',
|
192 |
+
# slurm_time='12:00:00',
|
193 |
+
)
|
194 |
+
|
195 |
+
# cross and line up parameter combinations
|
196 |
+
arg_dicts = list(chain(*(ParameterGrid(p_set) for p_set in hparam_sets)))
|
197 |
+
|
198 |
+
# set params and apply any extra param logic
|
199 |
+
arg_dicts = add_params(arg_dicts)
|
200 |
+
|
201 |
+
if __name__ == "__main__":
|
202 |
+
|
203 |
+
parser = argparse.ArgumentParser()
|
204 |
+
parser.add_argument(
|
205 |
+
"-d", "--dry_run",
|
206 |
+
action="store_true",
|
207 |
+
help="just echo the commands to be run",
|
208 |
+
)
|
209 |
+
args = parser.parse_args()
|
210 |
+
|
211 |
+
# context to make this loop/list comp execute an array job
|
212 |
+
# rather than individual jobs
|
213 |
+
with executor.batch():
|
214 |
+
|
215 |
+
if args.dry_run:
|
216 |
+
fn = dummy_func
|
217 |
+
else:
|
218 |
+
fn = cmdline_function
|
219 |
+
jobs = [executor.submit(fn, **arg_dict) for arg_dict in arg_dicts]
|
220 |
+
|
221 |
+
for job,args in zip(jobs, arg_dicts):
|
222 |
+
print(f"Job={job} | uid={args['run_name']}")
|
lm-watermarking-main/experiments/process_rows.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Basic imports
|
2 |
+
import os
|
3 |
+
from functools import partial
|
4 |
+
from argparse import Namespace
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
# HF classses
|
9 |
+
from transformers import AutoTokenizer
|
10 |
+
|
11 |
+
from datasets import Dataset, concatenate_datasets
|
12 |
+
|
13 |
+
|
14 |
+
# watermarking micro lib
|
15 |
+
from watermark import (BlacklistLogitsProcessor,
|
16 |
+
compute_bl_metrics)
|
17 |
+
|
18 |
+
# some file i/o helpers
|
19 |
+
from io_utils import read_jsonlines, read_json
|
20 |
+
|
21 |
+
|
22 |
+
from watermark import compute_bl_metrics, BlacklistLogitsProcessor
|
23 |
+
|
24 |
+
|
25 |
+
###########################################################################
|
26 |
+
# Compute E[wl] for each example
|
27 |
+
###########################################################################
|
28 |
+
|
29 |
+
def expected_whitelist(example,
|
30 |
+
idx,
|
31 |
+
exp_wl_coef: float == None,
|
32 |
+
drop_spike_entropies: bool = False):
|
33 |
+
assert "spike_entropies" in example, "Need to construct bl processor with store_spike_ents=True to compute them in post"
|
34 |
+
|
35 |
+
num_toks_gend = example["w_bl_num_tokens_generated"]
|
36 |
+
avg_spike_ent = np.mean(example["spike_entropies"])
|
37 |
+
|
38 |
+
example.update({"avg_spike_entropy":avg_spike_ent})
|
39 |
+
if drop_spike_entropies: del example["spike_entropies"]
|
40 |
+
|
41 |
+
exp_num_wl = (exp_wl_coef*num_toks_gend)*avg_spike_ent
|
42 |
+
var_num_wl = num_toks_gend*exp_wl_coef*avg_spike_ent*(1-(exp_wl_coef*avg_spike_ent))
|
43 |
+
|
44 |
+
example.update({"w_bl_exp_num_wl_tokens":exp_num_wl})
|
45 |
+
example.update({"w_bl_var_num_wl_tokens":var_num_wl})
|
46 |
+
|
47 |
+
example.update({"exp_wl_coef":exp_wl_coef})
|
48 |
+
|
49 |
+
if num_toks_gend > 0:
|
50 |
+
example.update({"w_bl_exp_whitelist_fraction":exp_num_wl/num_toks_gend,
|
51 |
+
"w_bl_var_whitelist_fraction":var_num_wl/num_toks_gend})
|
52 |
+
else:
|
53 |
+
example.update({"w_bl_exp_whitelist_fraction":-1,
|
54 |
+
"w_bl_var_whitelist_fraction":-1})
|
55 |
+
return example
|
56 |
+
|
57 |
+
|
58 |
+
from typing import Callable
|
59 |
+
|
60 |
+
def add_metadata(ex, meta_table=None):
|
61 |
+
ex.update(meta_table)
|
62 |
+
return ex
|
63 |
+
|
64 |
+
|
65 |
+
def str_replace_bug_check(example,idx):
|
66 |
+
baseline_before = example["baseline_completion"]
|
67 |
+
example["baseline_completion"] = baseline_before.replace(example["truncated_input"][:-1],"")
|
68 |
+
if example["baseline_completion"] != baseline_before:
|
69 |
+
print("baseline input replacement bug occurred, skipping row!")
|
70 |
+
return False
|
71 |
+
else:
|
72 |
+
return True
|
73 |
+
|
74 |
+
|
75 |
+
def load_all_datasets(run_names: list[str]=None,
|
76 |
+
base_run_dir: str=None,
|
77 |
+
meta_name: str=None,
|
78 |
+
gen_name: str=None,
|
79 |
+
apply_metric_func: bool=False,
|
80 |
+
convert_to_pandas: bool = False,
|
81 |
+
drop_buggy_rows: bool = False,
|
82 |
+
limit_output_tokens: int = 0,
|
83 |
+
save_ds: bool = True,
|
84 |
+
save_dir: str=None):
|
85 |
+
|
86 |
+
print(f"Loading {len(run_names)} datasets from {base_run_dir}...")
|
87 |
+
|
88 |
+
if not isinstance(gen_name, Callable):
|
89 |
+
file_check = lambda name: os.path.exists(f"{base_run_dir}/{name}/{gen_name}")
|
90 |
+
assert all([file_check(name) for name in run_names]), f"Make sure all the run dirs contain the required data files: {meta_name} and {gen_name}"
|
91 |
+
|
92 |
+
all_datasets = []
|
93 |
+
for i,run_name in enumerate(run_names):
|
94 |
+
|
95 |
+
print(f"[{i}] Loading dataset")
|
96 |
+
|
97 |
+
run_base_dir = f"{base_run_dir}/{run_name}"
|
98 |
+
gen_table_meta_path = f"{run_base_dir}/{meta_name}"
|
99 |
+
|
100 |
+
if isinstance(gen_name, Callable):
|
101 |
+
gen_table_path = f"{run_base_dir}/{gen_name(run_name)}"
|
102 |
+
else:
|
103 |
+
gen_table_path = f"{run_base_dir}/{gen_name}"
|
104 |
+
|
105 |
+
# load the raw files
|
106 |
+
gen_table_meta = read_json(gen_table_meta_path)
|
107 |
+
gen_table_lst = [ex for ex in read_jsonlines(gen_table_path)]
|
108 |
+
gen_table_ds = Dataset.from_list(gen_table_lst)
|
109 |
+
|
110 |
+
print(f"Original dataset length={len(gen_table_ds)}")
|
111 |
+
|
112 |
+
# drop the rows where the string replace thing happens
|
113 |
+
if drop_buggy_rows:
|
114 |
+
gen_table_ds_filtered = gen_table_ds.filter(str_replace_bug_check,batched=False,with_indices=True)
|
115 |
+
else:
|
116 |
+
gen_table_ds_filtered = gen_table_ds
|
117 |
+
|
118 |
+
# enrich all rows with the run metadata
|
119 |
+
add_meta = partial(
|
120 |
+
add_metadata,
|
121 |
+
meta_table=gen_table_meta
|
122 |
+
)
|
123 |
+
gen_table_w_meta = gen_table_ds_filtered.map(add_meta, batched=False)
|
124 |
+
|
125 |
+
# optionally, apply the metric function(s) - somewhat expensive
|
126 |
+
# want to do this here rather than at end because you need each run's tokenizer
|
127 |
+
# though tbh it would be odd if they're not the same, but you can check that at the end
|
128 |
+
if apply_metric_func:
|
129 |
+
|
130 |
+
tokenizer = AutoTokenizer.from_pretrained(gen_table_meta["model_name"])
|
131 |
+
|
132 |
+
comp_bl_metrics = partial(
|
133 |
+
compute_bl_metrics,
|
134 |
+
tokenizer=tokenizer,
|
135 |
+
hf_model_name=gen_table_meta["model_name"],
|
136 |
+
initial_seed=gen_table_meta["initial_seed"],
|
137 |
+
dynamic_seed=gen_table_meta["dynamic_seed"],
|
138 |
+
bl_proportion=gen_table_meta["bl_proportion"],
|
139 |
+
use_cuda=True, # this is obvi critical to match the pseudorandomness
|
140 |
+
record_hits=True,
|
141 |
+
limit_output_tokens=limit_output_tokens,
|
142 |
+
)
|
143 |
+
gen_table_w_bl_metrics = gen_table_w_meta.map(comp_bl_metrics, batched=False, with_indices=True)
|
144 |
+
|
145 |
+
|
146 |
+
# Construct the blacklist processor so you can get the expectation coef
|
147 |
+
all_token_ids = list(tokenizer.get_vocab().values())
|
148 |
+
vocab_size = len(all_token_ids)
|
149 |
+
args = Namespace()
|
150 |
+
args.__dict__.update(gen_table_meta)
|
151 |
+
|
152 |
+
bl_processor = BlacklistLogitsProcessor(bad_words_ids=None,
|
153 |
+
store_bl_ids=False,
|
154 |
+
store_spike_ents=True,
|
155 |
+
eos_token_id=tokenizer.eos_token_id,
|
156 |
+
vocab=all_token_ids,
|
157 |
+
vocab_size=vocab_size,
|
158 |
+
bl_proportion=args.bl_proportion,
|
159 |
+
bl_logit_bias=args.bl_logit_bias,
|
160 |
+
bl_type=args.bl_type,
|
161 |
+
initial_seed= args.initial_seed,
|
162 |
+
dynamic_seed=args.dynamic_seed)
|
163 |
+
|
164 |
+
if "spike_entropies" in gen_table_w_bl_metrics.column_names:
|
165 |
+
comp_exp_num_wl = partial(
|
166 |
+
expected_whitelist,
|
167 |
+
exp_wl_coef=bl_processor.expected_wl_coef,
|
168 |
+
drop_spike_entropies=False,
|
169 |
+
# drop_spike_entropies=True,
|
170 |
+
)
|
171 |
+
gen_table_w_spike_ents = gen_table_w_bl_metrics.map(comp_exp_num_wl, batched=False, with_indices=True)
|
172 |
+
final_single_run_ds = gen_table_w_spike_ents
|
173 |
+
else:
|
174 |
+
final_single_run_ds = gen_table_w_bl_metrics
|
175 |
+
else:
|
176 |
+
final_single_run_ds = gen_table_w_meta
|
177 |
+
|
178 |
+
all_datasets.append(final_single_run_ds)
|
179 |
+
|
180 |
+
ds = concatenate_datasets(all_datasets)
|
181 |
+
|
182 |
+
if save_ds:
|
183 |
+
ds.save_to_disk(save_dir)
|
184 |
+
|
185 |
+
if convert_to_pandas:
|
186 |
+
df = ds.to_pandas()
|
187 |
+
return df
|
188 |
+
else:
|
189 |
+
return ds
|
190 |
+
|
191 |
+
|
192 |
+
output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_large_sweep"
|
193 |
+
# output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_large_sweep_downsize"
|
194 |
+
|
195 |
+
# output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_large_sweep_downsize"
|
196 |
+
# output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_greedy_redo"
|
197 |
+
# output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_greedy_gamma_0-25"
|
198 |
+
|
199 |
+
run_names = list(filter(lambda name: os.path.exists(f"{output_dir}/{name}/gen_table_w_metrics.jsonl"), sorted(os.listdir(output_dir))))
|
200 |
+
run_names = list(filter(lambda name: "realnewslike" in name, run_names))
|
201 |
+
# run_names = list(filter(lambda name: "pile" in name, run_names))
|
202 |
+
# run_names = list(filter(lambda name: "c4_en" in name, run_names))
|
203 |
+
|
204 |
+
|
205 |
+
# output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_attacked_greedy_updated"
|
206 |
+
# # output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_attacked_new"
|
207 |
+
# run_names = list(filter(lambda name: os.path.exists(f"{output_dir}/{name}/gen_table_w{('_'+name) if 't5' in name else ''}_attack_metrics.jsonl"), sorted(os.listdir(output_dir))))
|
208 |
+
# run_names = list(filter(lambda name: os.path.exists(f"{output_dir}/{name}/gen_table_w_attack_metrics.jsonl"), sorted(os.listdir(output_dir))))
|
209 |
+
|
210 |
+
runs_to_load = run_names
|
211 |
+
|
212 |
+
|
213 |
+
print(len(run_names))
|
214 |
+
for name in run_names: print(name)
|
215 |
+
|
216 |
+
runs_ready = [os.path.exists(f"{output_dir}/{name}/gen_table_w_metrics.jsonl") for name in runs_to_load]
|
217 |
+
# runs_ready = [os.path.exists(f"{output_dir}/{name}/gen_table_w_attack_metrics.jsonl") for name in runs_to_load]
|
218 |
+
print(f"all runs ready? {all(runs_ready)}\n{runs_ready}")
|
219 |
+
|
220 |
+
|
221 |
+
# save_name = "analysis_ds_1-21_greedy_redo"
|
222 |
+
# save_name = "analysis_ds_1-21_greedy_redo_truncated"
|
223 |
+
# save_name = "analysis_ds_1-21_greedy_redo_truncated_sanity_check"
|
224 |
+
# save_name = "analysis_ds_1-19_realnews_1-3_v2_hitlist_check"
|
225 |
+
# save_name = "analysis_ds_1-20_more_attack"
|
226 |
+
|
227 |
+
# save_name = "analysis_ds_1-23_greedy_gamma_0-25_truncated"
|
228 |
+
# save_name = "analysis_ds_1-21_greedy_attacked_updated_truncated"
|
229 |
+
|
230 |
+
# save_name = "analysis_ds_1-23_pile_1-3"
|
231 |
+
# save_name = "analysis_ds_1-23_en_1-3"
|
232 |
+
|
233 |
+
save_name = "analysis_ds_1-30_realnews_2-7"
|
234 |
+
|
235 |
+
save_dir = f"input/{save_name}"
|
236 |
+
|
237 |
+
raw_data = load_all_datasets(run_names=runs_to_load,
|
238 |
+
base_run_dir=output_dir,
|
239 |
+
meta_name="gen_table_meta.json",
|
240 |
+
gen_name="gen_table_w_metrics.jsonl",
|
241 |
+
# gen_name="gen_table_w_attack_metrics.jsonl",
|
242 |
+
apply_metric_func=True,
|
243 |
+
# drop_buggy_rows=True,
|
244 |
+
drop_buggy_rows=False,
|
245 |
+
# limit_output_tokens=200,
|
246 |
+
convert_to_pandas=False,
|
247 |
+
save_ds=True,
|
248 |
+
save_dir=save_dir)
|
249 |
+
|
250 |
+
print(f"All finished with {save_dir}!!")
|
lm-watermarking-main/experiments/run_watermarking.py
ADDED
@@ -0,0 +1,705 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Basic imports
|
2 |
+
import sys
|
3 |
+
import os
|
4 |
+
import argparse
|
5 |
+
from typing import List, Iterable, Optional
|
6 |
+
from functools import partial
|
7 |
+
import time
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
import random
|
11 |
+
import math
|
12 |
+
from statistics import mean
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
from torch import Tensor
|
17 |
+
from tokenizers import Tokenizer
|
18 |
+
|
19 |
+
import wandb
|
20 |
+
import matplotlib.pyplot as plt
|
21 |
+
|
22 |
+
# cache path before HF imports just for kicks
|
23 |
+
# bc I don't really know when this is pulled by the library
|
24 |
+
# TODO change to passing as an arg to the model load fn
|
25 |
+
USER = "jkirchen"
|
26 |
+
# Huggingface cache
|
27 |
+
HF_HOME=f"/cmlscratch/{USER}/.cache/huggingface"
|
28 |
+
# HF_HOME=f"/scratch0/{USER}/.cache/huggingface"
|
29 |
+
# HF_HOME=f"/scratch1/{USER}/.cache/huggingface"
|
30 |
+
os.environ["HF_HOME"] = HF_HOME
|
31 |
+
|
32 |
+
print(os.environ["HF_HOME"])
|
33 |
+
|
34 |
+
# HF classses
|
35 |
+
from transformers import (AutoTokenizer,
|
36 |
+
AutoModelForSeq2SeqLM,
|
37 |
+
AutoModelForCausalLM,
|
38 |
+
LogitsProcessorList)
|
39 |
+
|
40 |
+
from datasets import load_dataset, Dataset
|
41 |
+
|
42 |
+
# watermarking micro lib
|
43 |
+
from watermark import (BlacklistLogitsProcessor,
|
44 |
+
add_idx,
|
45 |
+
check_input_lengths,
|
46 |
+
check_output_lengths,
|
47 |
+
tokenize_for_generation,
|
48 |
+
generate_completions,
|
49 |
+
evaluate_generation_fluency)
|
50 |
+
|
51 |
+
# better bool flag type for argparse
|
52 |
+
from submitit_utils import str2bool
|
53 |
+
|
54 |
+
# some file i/o helpers
|
55 |
+
from io_utils import write_jsonlines, write_json, read_jsonlines, read_json
|
56 |
+
|
57 |
+
def main(args):
|
58 |
+
|
59 |
+
###########################################################################
|
60 |
+
# Start logging
|
61 |
+
###########################################################################
|
62 |
+
if not args.no_wandb:
|
63 |
+
|
64 |
+
# storing slurm info to be sent to wandb to allow auditing logfiles later
|
65 |
+
args.SLURM_JOB_ID = os.getenv("SLURM_JOB_ID")
|
66 |
+
args.SLURM_ARRAY_JOB_ID = os.getenv("SLURM_ARRAY_JOB_ID")
|
67 |
+
args.SLURM_ARRAY_TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
|
68 |
+
|
69 |
+
# start a new wandb run to track this experiment, will send data to it later
|
70 |
+
run = wandb.init(
|
71 |
+
# set the wandb project where this run will be logged
|
72 |
+
project=args.wandb_project,
|
73 |
+
entity=args.wandb_entity,
|
74 |
+
name=args.run_name,
|
75 |
+
|
76 |
+
# track hyperparameters and run metadata
|
77 |
+
config=args
|
78 |
+
)
|
79 |
+
|
80 |
+
print(f"Output dir for this run: {args.output_dir}")
|
81 |
+
# notify if exists
|
82 |
+
if os.path.exists(args.output_dir):
|
83 |
+
print(f"Output dir for this run already exists!")
|
84 |
+
print(f"Contents: {sorted(os.listdir(args.output_dir))}")
|
85 |
+
else:
|
86 |
+
# create the output dir where run artifacts are stored
|
87 |
+
os.makedirs(args.output_dir)
|
88 |
+
|
89 |
+
###########################################################################
|
90 |
+
# Instantiate model and tokenizer
|
91 |
+
###########################################################################
|
92 |
+
hf_model_name = args.model_name
|
93 |
+
|
94 |
+
if "t5" in hf_model_name or "T0" in hf_model_name:
|
95 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(hf_model_name)
|
96 |
+
else:
|
97 |
+
model = AutoModelForCausalLM.from_pretrained(hf_model_name)
|
98 |
+
|
99 |
+
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
|
100 |
+
|
101 |
+
# defaults to device 0
|
102 |
+
# will need to use 'parallelize' for multi-gpu sharding
|
103 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
104 |
+
model = model.to(device)
|
105 |
+
model.eval()
|
106 |
+
|
107 |
+
###########################################################################
|
108 |
+
# Load the dataset
|
109 |
+
###########################################################################
|
110 |
+
|
111 |
+
dataset_name, dataset_config_name = args.dataset_name, args.dataset_config_name
|
112 |
+
|
113 |
+
if dataset_name == "cml_pile":
|
114 |
+
subsets = [dataset_config_name]
|
115 |
+
dataset = load_dataset("input/cml_pile.py",
|
116 |
+
subsets=subsets,
|
117 |
+
streaming=True,
|
118 |
+
split=None,
|
119 |
+
ignore_verifications=True)["train"]
|
120 |
+
else:
|
121 |
+
dataset = load_dataset(dataset_name, dataset_config_name, split="train", streaming=True)
|
122 |
+
|
123 |
+
# log an example
|
124 |
+
ds_iterator = iter(dataset)
|
125 |
+
idx = 75 # if this is c4, it's the schumacher example lol
|
126 |
+
i = 0
|
127 |
+
while i < idx:
|
128 |
+
next(ds_iterator)
|
129 |
+
i += 1
|
130 |
+
|
131 |
+
example = next(ds_iterator)
|
132 |
+
print(example)
|
133 |
+
|
134 |
+
###########################################################################
|
135 |
+
# Construct the blacklist processor/sampler
|
136 |
+
###########################################################################
|
137 |
+
|
138 |
+
all_token_ids = list(tokenizer.get_vocab().values())
|
139 |
+
vocab_size = len(all_token_ids)
|
140 |
+
print(f"Vocabulary size: {vocab_size}")
|
141 |
+
|
142 |
+
max_new_tokens = args.max_new_tokens
|
143 |
+
min_prompt_tokens = args.min_prompt_tokens
|
144 |
+
|
145 |
+
init_seed = args.initial_seed
|
146 |
+
dyna_seed=args.dynamic_seed # type not value
|
147 |
+
bl_proportion = args.bl_proportion
|
148 |
+
bl_logit_bias = args.bl_logit_bias
|
149 |
+
bl_type = args.bl_type
|
150 |
+
n_beams = args.num_beams
|
151 |
+
early_stopping = args.early_stopping
|
152 |
+
no_repeat_ngram_size = args.no_repeat_ngram_size
|
153 |
+
store_bl_ids = args.store_bl_ids
|
154 |
+
store_spike_ents = args.store_spike_ents
|
155 |
+
|
156 |
+
bl_processor = BlacklistLogitsProcessor(bad_words_ids=None,
|
157 |
+
store_bl_ids=store_bl_ids,
|
158 |
+
store_spike_ents=store_spike_ents,
|
159 |
+
eos_token_id=tokenizer.eos_token_id,
|
160 |
+
vocab=all_token_ids,
|
161 |
+
vocab_size=vocab_size,
|
162 |
+
bl_proportion=bl_proportion,
|
163 |
+
bl_logit_bias=bl_logit_bias,
|
164 |
+
bl_type=bl_type,
|
165 |
+
initial_seed=init_seed,
|
166 |
+
dynamic_seed=dyna_seed)
|
167 |
+
|
168 |
+
logit_processor_lst = LogitsProcessorList([bl_processor])
|
169 |
+
|
170 |
+
# Greedy and basic beam search, default
|
171 |
+
gen_kwargs = dict(
|
172 |
+
max_new_tokens=max_new_tokens,
|
173 |
+
num_beams=n_beams,
|
174 |
+
)
|
175 |
+
if n_beams > 1:
|
176 |
+
# these are only for beam search repetition correction
|
177 |
+
if no_repeat_ngram_size > 0:
|
178 |
+
gen_kwargs.update(dict(no_repeat_ngram_size=no_repeat_ngram_size))
|
179 |
+
gen_kwargs.update(dict(early_stopping=early_stopping))
|
180 |
+
|
181 |
+
if args.use_sampling:
|
182 |
+
gen_kwargs.update(dict(do_sample=True,
|
183 |
+
top_k=0,
|
184 |
+
temperature=args.sampling_temp))
|
185 |
+
if args.all_gas_no_eos:
|
186 |
+
gen_kwargs.update(dict(suppress_tokens=[tokenizer.eos_token_id]))
|
187 |
+
|
188 |
+
generate_without_blacklist = partial(
|
189 |
+
model.generate,
|
190 |
+
**gen_kwargs
|
191 |
+
)
|
192 |
+
generate_with_blacklist = partial(
|
193 |
+
model.generate,
|
194 |
+
logits_processor=logit_processor_lst,
|
195 |
+
**gen_kwargs
|
196 |
+
)
|
197 |
+
|
198 |
+
###########################################################################
|
199 |
+
# Construct the generation and measurement pipeline (lazy)
|
200 |
+
# that pulls from the streaming dataset, applies the generations map funcs
|
201 |
+
###########################################################################
|
202 |
+
|
203 |
+
# Set up the pipeline functions
|
204 |
+
if "c4" in dataset_name:
|
205 |
+
columns_to_remove = ["text","timestamp","url"]
|
206 |
+
else:
|
207 |
+
columns_to_remove = []
|
208 |
+
|
209 |
+
# Construct the data filtering/sampling scheme partials
|
210 |
+
token_kwargs = dict(
|
211 |
+
hf_model_name=hf_model_name,
|
212 |
+
tokenizer=tokenizer,
|
213 |
+
model=model,
|
214 |
+
)
|
215 |
+
if args.input_truncation_strategy == "prompt_length":
|
216 |
+
token_kwargs.update(dict(min_prompt_tokens=min_prompt_tokens))
|
217 |
+
elif args.input_truncation_strategy == "completion_length":
|
218 |
+
token_kwargs.update(dict(max_new_tokens=max_new_tokens))
|
219 |
+
else:
|
220 |
+
ValueError(f"Unknown input truncation strategy {args.input_truncation_strategy}")
|
221 |
+
tokenize_prompts = partial(
|
222 |
+
tokenize_for_generation,
|
223 |
+
**token_kwargs
|
224 |
+
)
|
225 |
+
|
226 |
+
input_check_kwargs = dict(
|
227 |
+
# min_sample_len = min_prompt_tokens + max_new_tokens,
|
228 |
+
min_sample_len = args.min_sample_tokens, # first line is a bug sometimes with large amounts
|
229 |
+
)
|
230 |
+
if args.input_filtering_strategy == "prompt_length":
|
231 |
+
input_check_kwargs.update(dict(min_prompt_len = min_prompt_tokens,
|
232 |
+
min_completion_len = 0))
|
233 |
+
elif args.input_filtering_strategy == "completion_length":
|
234 |
+
input_check_kwargs.update(dict(min_prompt_len = 0,
|
235 |
+
min_completion_len = max_new_tokens))
|
236 |
+
elif args.input_filtering_strategy == "prompt_and_completion_length":
|
237 |
+
input_check_kwargs.update(dict(min_prompt_len = min_prompt_tokens,
|
238 |
+
min_completion_len = max_new_tokens))
|
239 |
+
else:
|
240 |
+
ValueError(f"Unknown input filtering strategy {args.input_filtering_strategy}")
|
241 |
+
input_check = partial(
|
242 |
+
check_input_lengths,
|
243 |
+
**input_check_kwargs
|
244 |
+
)
|
245 |
+
|
246 |
+
if args.output_filtering_strategy == "max_new_tokens":
|
247 |
+
output_kwargs = dict(min_output_len = max_new_tokens)
|
248 |
+
elif args.output_filtering_strategy == "no_filter":
|
249 |
+
output_kwargs = dict(min_output_len = 0)
|
250 |
+
else:
|
251 |
+
ValueError(f"Unknown output filtering strategy {args.output_filtering_strategy}")
|
252 |
+
output_check = partial(
|
253 |
+
check_output_lengths,
|
254 |
+
**output_kwargs
|
255 |
+
)
|
256 |
+
|
257 |
+
gen_completions = partial(
|
258 |
+
generate_completions,
|
259 |
+
max_new_tokens=max_new_tokens,
|
260 |
+
hf_model_name=hf_model_name,
|
261 |
+
tokenizer=tokenizer,
|
262 |
+
model=model,
|
263 |
+
no_bl_partial=generate_without_blacklist,
|
264 |
+
w_bl_partial=generate_with_blacklist,
|
265 |
+
bl_processor_list=logit_processor_lst,
|
266 |
+
)
|
267 |
+
|
268 |
+
###########################################################################
|
269 |
+
# Compose/apply the pipeline steps
|
270 |
+
###########################################################################
|
271 |
+
|
272 |
+
# Apply the pipeline operations to the dataset
|
273 |
+
indexed_dataset = dataset.map(add_idx, batched=False, with_indices=True)
|
274 |
+
|
275 |
+
# shuffled the first shuffle_buffer_size rows of the (streaming) dataset
|
276 |
+
if args.shuffle_dataset:
|
277 |
+
shuffled_dataset = indexed_dataset.shuffle(seed=args.shuffle_seed,
|
278 |
+
buffer_size=args.shuffle_buffer_size)
|
279 |
+
else:
|
280 |
+
shuffled_dataset = indexed_dataset
|
281 |
+
|
282 |
+
# tokenize and truncate the row inputs to create prompts according to the strategy spec'd above
|
283 |
+
tokenized_and_truncated_dataset = shuffled_dataset.map(tokenize_prompts,
|
284 |
+
batched=False,
|
285 |
+
with_indices=True)
|
286 |
+
|
287 |
+
# filter the rows of the dataset based on length checks for the tokenized prompts and baseline completions
|
288 |
+
input_length_filtered_dataset = tokenized_and_truncated_dataset.filter(input_check,
|
289 |
+
batched=False,
|
290 |
+
with_indices=True)
|
291 |
+
|
292 |
+
# perform generation by calling the models
|
293 |
+
columns_to_remove += ["inputs", "untruncated_inputs"] # these are now materialized and must be dropped externally
|
294 |
+
generations_dataset = input_length_filtered_dataset.map(gen_completions,
|
295 |
+
batched=False,
|
296 |
+
with_indices=True,
|
297 |
+
remove_columns=columns_to_remove)
|
298 |
+
|
299 |
+
# # filter the dataset a last time based on the lengths of the outputs of the model
|
300 |
+
# output_length_filtered_dataset = generations_dataset.filter(output_check,
|
301 |
+
# batched=False,
|
302 |
+
# with_indices=True)
|
303 |
+
|
304 |
+
###########################################################################
|
305 |
+
# Main loop - actually executes the generation pipeline.
|
306 |
+
# and accumulates the result rows in a list, assumes list is "small"-ish
|
307 |
+
# and we aren't accumulating any tensors or other memory hogging artifacts
|
308 |
+
###########################################################################
|
309 |
+
if not args.load_prev_generations:
|
310 |
+
|
311 |
+
processed_examples = []
|
312 |
+
ds_iterator = iter(generations_dataset)
|
313 |
+
i = 0
|
314 |
+
while i < args.limit_indices:
|
315 |
+
|
316 |
+
ex = next(ds_iterator)
|
317 |
+
|
318 |
+
# log basics to stdout
|
319 |
+
print(f"#"*80)
|
320 |
+
print(f"dataset index: {ex['idx']}")
|
321 |
+
print(f"orig_sample_length: {ex['orig_sample_length']}")
|
322 |
+
print(f"prompt_length: {ex['prompt_length']}")
|
323 |
+
print(f"real_completion_length: {ex['real_completion_length']}")
|
324 |
+
print(f"no_bl_num_tokens_generated: {ex['no_bl_num_tokens_generated']}")
|
325 |
+
print(f"w_bl_num_tokens_generated: {ex['w_bl_num_tokens_generated']}")
|
326 |
+
|
327 |
+
print(f"\ntruncated_input: ")
|
328 |
+
print(ex["truncated_input"])
|
329 |
+
print(f"\nbaseline_completion: ")
|
330 |
+
print(ex["baseline_completion"])
|
331 |
+
print(f"\nno_bl_output: ")
|
332 |
+
print(ex["no_bl_output"])
|
333 |
+
print(f"\nw_bl_output: ")
|
334 |
+
print(ex["w_bl_output"])
|
335 |
+
print(f"\nno_bl_gen_time: ")
|
336 |
+
print(ex["no_bl_gen_time"])
|
337 |
+
print(f"\nno_bl_sec_per_tok: ")
|
338 |
+
print(ex["no_bl_sec_per_tok"])
|
339 |
+
print(f"\nno_bl_tok_per_sec: ")
|
340 |
+
print(ex["no_bl_tok_per_sec"])
|
341 |
+
print(f"\nw_bl_gen_time: ")
|
342 |
+
print(ex["w_bl_gen_time"])
|
343 |
+
print(f"\nw_bl_sec_per_tok: ")
|
344 |
+
print(ex["w_bl_sec_per_tok"])
|
345 |
+
print(f"\nw_bl_tok_per_sec: ")
|
346 |
+
print(ex["w_bl_tok_per_sec"])
|
347 |
+
|
348 |
+
processed_examples.append(ex)
|
349 |
+
if output_check(ex) == True:
|
350 |
+
i += 1
|
351 |
+
else:
|
352 |
+
print(f"\nGeneration too short, saving outputs, but not incrementing counter...\n",
|
353 |
+
f"{i} of {len(processed_examples)} rows were satisfactory so far",
|
354 |
+
f"current generation overhead ratio: {round(len(processed_examples)/(i+1), 3)}",
|
355 |
+
f"completed {round(i/args.limit_indices, 2)} of total")
|
356 |
+
|
357 |
+
print(f"#"*80,
|
358 |
+
f"\nGeneration output length check overhead was num rows processed={len(processed_examples)}",
|
359 |
+
f"for {args.limit_indices} samples. Ratio: {round(len(processed_examples)/args.limit_indices, 3)}")
|
360 |
+
|
361 |
+
###########################################################################
|
362 |
+
# Generation jsonl dumping/loading
|
363 |
+
###########################################################################
|
364 |
+
|
365 |
+
gen_table_meta_path = f"{args.output_dir}/gen_table_meta.json"
|
366 |
+
gen_table_path = f"{args.output_dir}/gen_table.jsonl"
|
367 |
+
safe_gen_table_path = f"{args.output_dir}/gen_table_safe.jsonl"
|
368 |
+
|
369 |
+
args.gen_table_already_existed = False
|
370 |
+
|
371 |
+
if not args.load_prev_generations:
|
372 |
+
|
373 |
+
if os.path.exists(gen_table_path):
|
374 |
+
print(f"Found existing generation files at this output dir: {args.output_dir}")
|
375 |
+
print(f"Writing generations at alternate, safe path and exiting. Note! this only works once. "
|
376 |
+
f"Safe version will get overwritten next time ... ")
|
377 |
+
gen_table_path = f"{args.output_dir}/gen_table_safe.jsonl"
|
378 |
+
args.gen_table_already_existed = True
|
379 |
+
|
380 |
+
gen_table_meta = args.__dict__
|
381 |
+
gen_table = processed_examples
|
382 |
+
|
383 |
+
write_jsonlines(gen_table, gen_table_path)
|
384 |
+
write_json(gen_table_meta,gen_table_meta_path,indent=4)
|
385 |
+
|
386 |
+
if args.gen_table_already_existed:
|
387 |
+
# finish the wandb run
|
388 |
+
if not args.no_wandb: run.finish()
|
389 |
+
return # from main, for safety
|
390 |
+
else:
|
391 |
+
print(f"Loading previously generated outputs for evaluation via oracle model and metrics...")
|
392 |
+
|
393 |
+
assert os.path.exists(gen_table_meta_path), f"failed file check for prev generations metadata json file: {gen_table_meta_path}"
|
394 |
+
assert os.path.exists(gen_table_path), f"failed file check for prev generations jsonl file: {gen_table_path}"
|
395 |
+
|
396 |
+
curr_gen_table_meta = args.__dict__.copy()
|
397 |
+
prev_gen_table_meta = read_json(gen_table_meta_path)
|
398 |
+
|
399 |
+
assert not prev_gen_table_meta["gen_table_already_existed"], f"failed for safety bc 'gen_table_already_existed' was true in the metadata file in this dir, indicating a possible issue"
|
400 |
+
assert not os.path.exists(safe_gen_table_path), f"failed for safety bc there is a secondary 'safe' marked file in this dir indicating a possible issue"
|
401 |
+
|
402 |
+
params_to_ignore = ["load_prev_generations","SLURM_JOB_ID","SLURM_ARRAY_JOB_ID","SLURM_ARRAY_TASK_ID"]
|
403 |
+
for k in params_to_ignore:
|
404 |
+
del curr_gen_table_meta[k]
|
405 |
+
del prev_gen_table_meta[k]
|
406 |
+
assert curr_gen_table_meta == prev_gen_table_meta, "failed safety check that current script params equal the params for the prev generations being loaded"
|
407 |
+
|
408 |
+
# gen_table_meta = argparse.Namespace(**args.__dict__)
|
409 |
+
gen_table_meta = args
|
410 |
+
gen_table = [ex for ex in read_jsonlines(gen_table_path)]
|
411 |
+
|
412 |
+
if args.generate_only:
|
413 |
+
# finish the wandb run
|
414 |
+
if not args.no_wandb: run.finish()
|
415 |
+
return # early exit, will reload later for ppl scoring
|
416 |
+
|
417 |
+
# Create a new dataset object either from the loop over examples
|
418 |
+
# or from the reloaded json lines
|
419 |
+
|
420 |
+
# gen_table_ds = Dataset.from_generator(ex for ex in gen_table) # hack since from_list is newer, and had 2.4.0
|
421 |
+
gen_table_ds = Dataset.from_list(gen_table)
|
422 |
+
|
423 |
+
###########################################################################
|
424 |
+
# Perplexity (PPL) evaluation
|
425 |
+
# which is a separate step partially bc it requires a different model on gpu
|
426 |
+
###########################################################################
|
427 |
+
|
428 |
+
# Load the oracle model for PPL measurement
|
429 |
+
# Assume on single GPU and need to free orig model memory for oracle model
|
430 |
+
if model is not None:
|
431 |
+
model = model.to(torch.device("cpu"))
|
432 |
+
del model
|
433 |
+
|
434 |
+
oracle_model_name = args.oracle_model_name
|
435 |
+
print(f"Loading oracle model: {oracle_model_name}")
|
436 |
+
|
437 |
+
oracle_tokenizer = AutoTokenizer.from_pretrained(oracle_model_name)
|
438 |
+
oracle_model = AutoModelForCausalLM.from_pretrained(oracle_model_name).to(device)
|
439 |
+
oracle_model.eval()
|
440 |
+
|
441 |
+
# construct fluency/ppl partial
|
442 |
+
eval_gen_metrics = partial(
|
443 |
+
evaluate_generation_fluency,
|
444 |
+
oracle_model_name=oracle_model_name,
|
445 |
+
oracle_model=oracle_model,
|
446 |
+
oracle_tokenizer=oracle_tokenizer
|
447 |
+
)
|
448 |
+
|
449 |
+
print(f"Computing metrics on model generations: {gen_table_ds}")
|
450 |
+
|
451 |
+
gen_table_w_metrics_ds = gen_table_ds.map(eval_gen_metrics, batched=False, with_indices=True)
|
452 |
+
|
453 |
+
|
454 |
+
print(f"#"*80)
|
455 |
+
print(f"baseline avg PPL: {mean(gen_table_w_metrics_ds['baseline_ppl'])}")
|
456 |
+
print(f"baseline avg loss: {mean(gen_table_w_metrics_ds['baseline_loss'])}")
|
457 |
+
print(f"no_bl avg PPL: {mean(gen_table_w_metrics_ds['no_bl_ppl'])}")
|
458 |
+
print(f"no_bl avg loss: {mean(gen_table_w_metrics_ds['no_bl_loss'])}")
|
459 |
+
print(f"w_bl avg PPL: {mean(gen_table_w_metrics_ds['w_bl_ppl'])}")
|
460 |
+
print(f"w_bl avg loss: {mean(gen_table_w_metrics_ds['w_bl_loss'])}")
|
461 |
+
|
462 |
+
# clear the model just for fun
|
463 |
+
oracle_model = oracle_model.to(torch.device("cpu"))
|
464 |
+
del oracle_model
|
465 |
+
|
466 |
+
gen_table_w_metrics_path = f"{args.output_dir}/gen_table_w_metrics.jsonl"
|
467 |
+
if os.path.exists(gen_table_w_metrics_path):
|
468 |
+
print(f"Found existing generation files with metrics added at this output dir. Overwriting anyway :\ -> {args.output_dir}")
|
469 |
+
|
470 |
+
gen_table_w_metrics_lst = [ex for ex in gen_table_w_metrics_ds]
|
471 |
+
write_jsonlines(gen_table_w_metrics_lst, gen_table_w_metrics_path)
|
472 |
+
|
473 |
+
# finish the wandb run
|
474 |
+
run.finish()
|
475 |
+
|
476 |
+
return
|
477 |
+
|
478 |
+
|
479 |
+
if __name__ == "__main__":
|
480 |
+
|
481 |
+
parser = argparse.ArgumentParser(description="Run watermarked huggingface LM generation pipeline")
|
482 |
+
parser.add_argument(
|
483 |
+
"--model_name",
|
484 |
+
type=str,
|
485 |
+
default="facebook/opt-2.7b",
|
486 |
+
help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
|
487 |
+
)
|
488 |
+
parser.add_argument(
|
489 |
+
"--dataset_name",
|
490 |
+
type=str,
|
491 |
+
default="c4",
|
492 |
+
help="The name of the dataset to use (via the datasets library).",
|
493 |
+
)
|
494 |
+
parser.add_argument(
|
495 |
+
"--dataset_config_name",
|
496 |
+
type=str,
|
497 |
+
default="realnewslike",
|
498 |
+
help="The configuration name of the dataset to use (via the datasets library).",
|
499 |
+
)
|
500 |
+
parser.add_argument(
|
501 |
+
"--shuffle_dataset",
|
502 |
+
type=str2bool,
|
503 |
+
default=False,
|
504 |
+
help="Whether to shuffle the dataset before sampling.",
|
505 |
+
)
|
506 |
+
parser.add_argument(
|
507 |
+
"--shuffle_seed",
|
508 |
+
type=int,
|
509 |
+
default=1234,
|
510 |
+
help="The seed to use for dataset shuffle op.",
|
511 |
+
)
|
512 |
+
parser.add_argument(
|
513 |
+
"--shuffle_buffer_size",
|
514 |
+
type=int,
|
515 |
+
default=10_000,
|
516 |
+
help="The buffer size to use for dataset shuffle op - takes n rows first, then shuffles those indices",
|
517 |
+
)
|
518 |
+
parser.add_argument(
|
519 |
+
"--max_new_tokens",
|
520 |
+
type=int,
|
521 |
+
default=100,
|
522 |
+
help="The number of tokens to generate using the model, and the num tokens removed from real text sample",
|
523 |
+
)
|
524 |
+
parser.add_argument(
|
525 |
+
"--min_prompt_tokens",
|
526 |
+
type=int,
|
527 |
+
default=50, # 500
|
528 |
+
help="The number of examples (first N) to process from the dataset.",
|
529 |
+
)
|
530 |
+
parser.add_argument(
|
531 |
+
"--min_sample_tokens",
|
532 |
+
type=int,
|
533 |
+
default=0,
|
534 |
+
help="The the minimum length of raw prompt samples to consider.",
|
535 |
+
)
|
536 |
+
parser.add_argument(
|
537 |
+
"--limit_indices",
|
538 |
+
type=int,
|
539 |
+
default=5, # 500
|
540 |
+
help="The number of examples (first N) to process from the dataset.",
|
541 |
+
)
|
542 |
+
parser.add_argument(
|
543 |
+
"--input_truncation_strategy",
|
544 |
+
type=str,
|
545 |
+
default="completion_length",
|
546 |
+
choices=["completion_length", "prompt_length"],
|
547 |
+
help="The strategy to use when tokenizing and truncating raw inputs to make prompts.",
|
548 |
+
)
|
549 |
+
parser.add_argument(
|
550 |
+
"--input_filtering_strategy",
|
551 |
+
type=str,
|
552 |
+
default="completion_length",
|
553 |
+
choices=["completion_length", "prompt_length", "prompt_and_completion_length"],
|
554 |
+
help="The strategy to use when tokenizing and truncating raw inputs to make prompts.",
|
555 |
+
)
|
556 |
+
parser.add_argument(
|
557 |
+
"--output_filtering_strategy",
|
558 |
+
type=str,
|
559 |
+
default="no_filter",
|
560 |
+
choices=["no_filter", "max_new_tokens"],
|
561 |
+
help=(f"The strategy to use when filtering/skipping rows if the model didn't ",
|
562 |
+
f"generate enough tokens to facilitate analysis.")
|
563 |
+
)
|
564 |
+
parser.add_argument(
|
565 |
+
"--initial_seed",
|
566 |
+
type=int,
|
567 |
+
default=1234,
|
568 |
+
help=("The initial seed to use in the blacklist randomization process.",
|
569 |
+
"Is unused if the process is markov generally. Can be None."),
|
570 |
+
)
|
571 |
+
parser.add_argument(
|
572 |
+
"--dynamic_seed",
|
573 |
+
type=str,
|
574 |
+
default="markov_1",
|
575 |
+
choices=[None, "initial", "markov_1"],
|
576 |
+
help="The seeding procedure to use when sampling the blacklist at each step.",
|
577 |
+
)
|
578 |
+
parser.add_argument(
|
579 |
+
"--bl_proportion",
|
580 |
+
type=float,
|
581 |
+
default=0.5,
|
582 |
+
help="The ratio of blacklist to whitelist tokens when splitting the vocabulary",
|
583 |
+
)
|
584 |
+
parser.add_argument(
|
585 |
+
"--bl_logit_bias",
|
586 |
+
type=float,
|
587 |
+
default=1.0,
|
588 |
+
help="The amount of bias (absolute) to add to the logits in the whitelist half of the vocabulary at every step",
|
589 |
+
)
|
590 |
+
parser.add_argument(
|
591 |
+
"--bl_type",
|
592 |
+
type=str,
|
593 |
+
default="soft",
|
594 |
+
choices=["soft", "hard"],
|
595 |
+
help="The type of blacklisting being performed.",
|
596 |
+
)
|
597 |
+
parser.add_argument(
|
598 |
+
"--num_beams",
|
599 |
+
type=int,
|
600 |
+
default=1,
|
601 |
+
help="The number of beams to use where '1' is no beam search.",
|
602 |
+
)
|
603 |
+
parser.add_argument(
|
604 |
+
"--no_repeat_ngram_size",
|
605 |
+
type=int,
|
606 |
+
default=0,
|
607 |
+
# default=8,
|
608 |
+
help="ngram size to force the model not to generate, can't be too small or model is handicapped, too large and blows up in complexity.",
|
609 |
+
)
|
610 |
+
parser.add_argument(
|
611 |
+
"--early_stopping",
|
612 |
+
type=str2bool,
|
613 |
+
default=False,
|
614 |
+
help="Whether to use early stopping, only for beam search.",
|
615 |
+
)
|
616 |
+
# parser.add_argument(
|
617 |
+
# "--hard_min_length",
|
618 |
+
# type=str2bool,
|
619 |
+
# default=False,
|
620 |
+
# help="Whether to use the min length logits processor to force the generations to be max_new_tokens.",
|
621 |
+
# )
|
622 |
+
parser.add_argument(
|
623 |
+
"--oracle_model_name",
|
624 |
+
type=str,
|
625 |
+
default="EleutherAI/gpt-j-6B",
|
626 |
+
help="PPL scoring, or oracle model, path to pretrained model or model identifier from huggingface.co/models.",
|
627 |
+
)
|
628 |
+
parser.add_argument(
|
629 |
+
"--no_wandb",
|
630 |
+
type=str2bool,
|
631 |
+
default=False,
|
632 |
+
help="Whether to log to wandb.",
|
633 |
+
)
|
634 |
+
parser.add_argument(
|
635 |
+
"--wandb_project",
|
636 |
+
type=str,
|
637 |
+
default="lm-blacklisting",
|
638 |
+
help="The name of the wandb project.",
|
639 |
+
)
|
640 |
+
parser.add_argument(
|
641 |
+
"--wandb_entity",
|
642 |
+
type=str,
|
643 |
+
default="jwkirchenbauer",
|
644 |
+
help="The wandb entity/user for the project.",
|
645 |
+
)
|
646 |
+
parser.add_argument(
|
647 |
+
"--run_name",
|
648 |
+
type=str,
|
649 |
+
default=None,
|
650 |
+
help="The unique name for the run.",
|
651 |
+
)
|
652 |
+
parser.add_argument(
|
653 |
+
"--output_dir",
|
654 |
+
type=str,
|
655 |
+
default="./output",
|
656 |
+
help="The unique name for the run.",
|
657 |
+
)
|
658 |
+
parser.add_argument(
|
659 |
+
"--load_prev_generations",
|
660 |
+
type=str2bool,
|
661 |
+
default=False,
|
662 |
+
help=("Whether to run generations or load from a json lines in the output_dir. "
|
663 |
+
"If True, this file must exist and meta/args must match"),
|
664 |
+
)
|
665 |
+
parser.add_argument(
|
666 |
+
"--store_bl_ids",
|
667 |
+
type=str2bool,
|
668 |
+
default=False,
|
669 |
+
help=("Whether to store all the blacklists while generating with bl processor. "),
|
670 |
+
)
|
671 |
+
parser.add_argument(
|
672 |
+
"--store_spike_ents",
|
673 |
+
type=str2bool,
|
674 |
+
default=False,
|
675 |
+
help=("Whether to store the spike entropies while generating with bl processor. "),
|
676 |
+
)
|
677 |
+
parser.add_argument(
|
678 |
+
"--use_sampling",
|
679 |
+
type=str2bool,
|
680 |
+
default=False,
|
681 |
+
help=("Whether to perform sampling during generation. (non-greedy decoding)"),
|
682 |
+
)
|
683 |
+
parser.add_argument(
|
684 |
+
"--sampling_temp",
|
685 |
+
type=float,
|
686 |
+
default=0.7,
|
687 |
+
help="The temperature to use when generating using multinom sampling",
|
688 |
+
)
|
689 |
+
parser.add_argument(
|
690 |
+
"--generate_only",
|
691 |
+
type=str2bool,
|
692 |
+
default=False,
|
693 |
+
help=("Whether to only produce outputs and not evaluate anything like ppl"),
|
694 |
+
)
|
695 |
+
parser.add_argument(
|
696 |
+
"--all_gas_no_eos",
|
697 |
+
type=str2bool,
|
698 |
+
default=False,
|
699 |
+
help=("Whether to weight the EOS token as -inf"),
|
700 |
+
)
|
701 |
+
|
702 |
+
args = parser.parse_args()
|
703 |
+
|
704 |
+
main(args)
|
705 |
+
|
lm-watermarking-main/experiments/submitit_utils.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# stuff specifically for the sklearn logic
|
2 |
+
from typing import Mapping
|
3 |
+
from functools import partial, reduce
|
4 |
+
import operator
|
5 |
+
|
6 |
+
from itertools import product
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
###############################################################################
|
10 |
+
# A grid search convenience class
|
11 |
+
###############################################################################
|
12 |
+
class ParameterGrid:
|
13 |
+
"""logic YOINKED from sklearn <3
|
14 |
+
def worth just using the lib itself, or something fancier in future for
|
15 |
+
efficient sampling etc. It's implemented as an iterator interface but thats
|
16 |
+
probs not necessary"""
|
17 |
+
|
18 |
+
def __init__(self, params):
|
19 |
+
# we may want to product a few sets of parameters
|
20 |
+
# independently of eachother, so expects a List[Mapping]
|
21 |
+
if isinstance(params, Mapping):
|
22 |
+
self.params = [params]
|
23 |
+
else:
|
24 |
+
self.params = params
|
25 |
+
# removed all checking code soooo make sure your
|
26 |
+
# param dict is already nice and conforming
|
27 |
+
|
28 |
+
def __iter__(self):
|
29 |
+
"""Iterate over the points in the grid.
|
30 |
+
Returns
|
31 |
+
-------
|
32 |
+
params : iterator over dict of str to any
|
33 |
+
Yields dictionaries mapping each estimator parameter to one of its
|
34 |
+
allowed values.
|
35 |
+
"""
|
36 |
+
for p in self.params:
|
37 |
+
# Always sort the keys of a dictionary, for reproducibility
|
38 |
+
items = sorted(p.items())
|
39 |
+
if not items:
|
40 |
+
yield {}
|
41 |
+
else:
|
42 |
+
keys, values = zip(*items)
|
43 |
+
for v in product(*values):
|
44 |
+
params = dict(zip(keys, v))
|
45 |
+
yield params
|
46 |
+
|
47 |
+
def __len__(self):
|
48 |
+
"""Number of points on the grid."""
|
49 |
+
# Product function that can handle iterables (np.product can't).
|
50 |
+
product = partial(reduce, operator.mul)
|
51 |
+
return sum(
|
52 |
+
product(len(v) for v in p.values()) if p else 1 for p in self.params
|
53 |
+
)
|
54 |
+
|
55 |
+
###############################################################################
|
56 |
+
# little "oneliner" reduce thingy that turns your shallow dict into
|
57 |
+
# the list [k1, v1, k2, v2, k3, v3 ...]
|
58 |
+
# and optionally "k1 v1 k2 v2 k3 v3"
|
59 |
+
|
60 |
+
def flatten_dict(dict, to_string=False, sep=" "):
|
61 |
+
flat_dict = reduce(operator.iconcat,dict.items() , [])
|
62 |
+
if to_string:
|
63 |
+
try:
|
64 |
+
return sep.join([str(elm) for elm in flat_dict])
|
65 |
+
except:
|
66 |
+
raise ValueError(f'Error converting dict={flat_dict} to whitespace joined string')
|
67 |
+
else:
|
68 |
+
return flat_dict
|
69 |
+
|
70 |
+
|
71 |
+
def str2bool(v):
|
72 |
+
if isinstance(v, bool):
|
73 |
+
return v
|
74 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
75 |
+
return True
|
76 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
77 |
+
return False
|
78 |
+
else:
|
79 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
lm-watermarking-main/experiments/watermark.py
ADDED
@@ -0,0 +1,820 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# micro lib to implement the watermarking extensions to LM generation
|
2 |
+
# as well as utils for eval/validaiton
|
3 |
+
|
4 |
+
from typing import List, Optional, Callable
|
5 |
+
|
6 |
+
import time
|
7 |
+
import random
|
8 |
+
import math
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from torch import Tensor
|
13 |
+
from tokenizers import Tokenizer
|
14 |
+
from transformers import LogitsProcessor, LogitsProcessorList, set_seed
|
15 |
+
|
16 |
+
|
17 |
+
def tokenize_and_truncate(example: dict,
|
18 |
+
completion_length: int = None,
|
19 |
+
prompt_length: int = None,
|
20 |
+
hf_model_name: str = None,
|
21 |
+
tokenizer = None,
|
22 |
+
model_max_seq_len: int = 4096):
|
23 |
+
"""take hf dataset entry and preprocess it for completion by a model"""
|
24 |
+
assert hf_model_name is not None, "need model name to know whether to adjust wrt special tokens"
|
25 |
+
assert "text" in example, "expects 'text' field to be present"
|
26 |
+
# tokenize
|
27 |
+
inputs = tokenizer.encode(example["text"], return_tensors="pt", truncation=True, max_length=model_max_seq_len)
|
28 |
+
example.update({"untruncated_inputs": inputs})
|
29 |
+
|
30 |
+
if (completion_length is not None) and (prompt_length is None):
|
31 |
+
# leave at least one token as prefix # FIXME I think plus 1 since 0 is start tok
|
32 |
+
slice_length = min(inputs.shape[1]-1, completion_length)
|
33 |
+
elif (prompt_length is not None) and (completion_length is None):
|
34 |
+
desired_comp_len = (inputs.shape[1]-1) - prompt_length
|
35 |
+
slice_length = desired_comp_len if desired_comp_len > 0 else 0
|
36 |
+
else:
|
37 |
+
raise ValueError((f"Can only tokenize and truncate based on either the desired prompt length or desired completion length,",
|
38 |
+
f" but got completion_length:{completion_length},prompt_length:{prompt_length}"))
|
39 |
+
|
40 |
+
# truncate
|
41 |
+
inputs = inputs[:,:inputs.shape[1]-slice_length]
|
42 |
+
# logic depending on special tokens for the model
|
43 |
+
if "t5" in hf_model_name or "T0" in hf_model_name:
|
44 |
+
inputs[0,-1] = 1
|
45 |
+
# else: pass
|
46 |
+
example.update({"inputs": inputs})
|
47 |
+
return example
|
48 |
+
|
49 |
+
|
50 |
+
class BlacklistLogitsProcessor(LogitsProcessor):
|
51 |
+
"""
|
52 |
+
[`LogitsProcessor`] that enforces that specified sequences will never be sampled.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
bad_words_ids (`List[List[int]]`):
|
56 |
+
List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
|
57 |
+
that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
|
58 |
+
add_special_tokens=False).input_ids`.
|
59 |
+
eos_token_id (`int`):
|
60 |
+
The id of the *end-of-sequence* token.
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(self,
|
64 |
+
bad_words_ids: List[List[int]],
|
65 |
+
eos_token_id: int,
|
66 |
+
vocab: list[int],
|
67 |
+
vocab_size: int,
|
68 |
+
bl_proportion: float=0.5,
|
69 |
+
bl_logit_bias: float=1.0,
|
70 |
+
bl_type: str = "hard", # "soft"
|
71 |
+
initial_seed: int=None,
|
72 |
+
dynamic_seed: str=None, # "initial", "markov_1", None
|
73 |
+
store_bl_ids: bool=True,
|
74 |
+
store_spike_ents: bool = False,
|
75 |
+
noop_blacklist: bool = False,
|
76 |
+
):
|
77 |
+
|
78 |
+
self.vocab = vocab
|
79 |
+
self.vocab_size = vocab_size
|
80 |
+
self.bl_proportion = bl_proportion
|
81 |
+
self.bl_logit_bias = bl_logit_bias
|
82 |
+
self.bl_type = bl_type
|
83 |
+
|
84 |
+
if initial_seed is None:
|
85 |
+
self.initial_seed = None
|
86 |
+
assert dynamic_seed != "initial"
|
87 |
+
else:
|
88 |
+
random.seed(initial_seed)
|
89 |
+
self.initial_seed = initial_seed
|
90 |
+
|
91 |
+
self.dynamic_seed = dynamic_seed
|
92 |
+
|
93 |
+
self.eos_token_id = eos_token_id
|
94 |
+
# self.bad_words_id_length_1 = self._prepare_bad_words(bad_words_ids)
|
95 |
+
|
96 |
+
self.bad_words_mask: Optional[torch.LongTensor] = None
|
97 |
+
|
98 |
+
self.store_bl_ids = store_bl_ids
|
99 |
+
self.bl_ids = None
|
100 |
+
|
101 |
+
self.store_spike_ents = store_spike_ents
|
102 |
+
self.spike_entropies = None
|
103 |
+
|
104 |
+
# hack to replace this with an approximation of infinite bias
|
105 |
+
# so that the expectation coefficient will come out to 1.0
|
106 |
+
if self.bl_type == "hard":
|
107 |
+
self.bl_logit_bias = 10000 # FIXME to a value that is actually close to the largest soft watermark used
|
108 |
+
|
109 |
+
alpha = torch.exp(torch.tensor(self.bl_logit_bias)).item()
|
110 |
+
# gamma = self.bl_proportion
|
111 |
+
gamma = 1.0-self.bl_proportion
|
112 |
+
self.alpha = alpha
|
113 |
+
self.gamma = gamma
|
114 |
+
|
115 |
+
self.z_value = ((1-gamma)*(alpha-1))/(1-gamma+(alpha*gamma))
|
116 |
+
self.expected_wl_coef = (gamma*alpha)/(1-gamma+(alpha*gamma))
|
117 |
+
|
118 |
+
# catch for overflow when bias is "infinite"
|
119 |
+
if self.alpha == torch.inf:
|
120 |
+
self.z_value = 1.0
|
121 |
+
self.expected_wl_coef = 1.0
|
122 |
+
|
123 |
+
self.noop_blacklist = noop_blacklist
|
124 |
+
if self.noop_blacklist: print(f"Blacklist processor for accounting only, no rescoring of logits")
|
125 |
+
|
126 |
+
self.g_cuda = None
|
127 |
+
self.large_prime = 15485863
|
128 |
+
|
129 |
+
@property
|
130 |
+
def blacklisted_ids(self):
|
131 |
+
assert self.store_bl_ids, "Need to instantiate processor with `store_bl_ids` to be able to retrieve them later"
|
132 |
+
# flatten the each indexes blacklist
|
133 |
+
l_of_bl_ids = [[] for _ in range(len(self.bl_ids))]
|
134 |
+
for b_idx, batch in enumerate(self.bl_ids):
|
135 |
+
for l_of_l, seed in batch:
|
136 |
+
bl_ids = [l[0] for l in l_of_l] # this was the main line, maybe unnecessary now?
|
137 |
+
l_of_bl_ids[b_idx].append((bl_ids,seed))
|
138 |
+
return l_of_bl_ids
|
139 |
+
|
140 |
+
def get_and_clear_stored_bl_ids(self):
|
141 |
+
old_bl_ids = self.bl_ids
|
142 |
+
self.bl_ids = None
|
143 |
+
return old_bl_ids
|
144 |
+
|
145 |
+
def get_spike_entropies(self):
|
146 |
+
spike_ents = [[] for _ in range(len(self.spike_entropies))]
|
147 |
+
for b_idx, ent_tensor_list in enumerate(self.spike_entropies):
|
148 |
+
for ent_tensor in ent_tensor_list:
|
149 |
+
spike_ents[b_idx].append(ent_tensor.item())
|
150 |
+
return spike_ents
|
151 |
+
|
152 |
+
def get_and_clear_stored_spike_ents(self):
|
153 |
+
spike_ents = self.get_spike_entropies()
|
154 |
+
self.spike_entropies = None
|
155 |
+
return spike_ents
|
156 |
+
|
157 |
+
def compute_spike_entropy(self, scores):
|
158 |
+
# precomputed z value in init
|
159 |
+
probs = scores.softmax(dim=-1)
|
160 |
+
denoms = 1+(self.z_value*probs)
|
161 |
+
renormed_probs = probs / denoms
|
162 |
+
sum_renormed_probs = renormed_probs.sum()
|
163 |
+
return sum_renormed_probs
|
164 |
+
|
165 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
166 |
+
|
167 |
+
self.bad_words_id_length_1 = [None for _ in range(input_ids.shape[0])]
|
168 |
+
|
169 |
+
if self.g_cuda is None:
|
170 |
+
self.g_cuda = torch.Generator(device=input_ids.device)
|
171 |
+
|
172 |
+
for b_idx in range(input_ids.shape[0]):
|
173 |
+
if self.dynamic_seed == "initial":
|
174 |
+
self.g_cuda.manual_seed(self.large_prime*self.initial_seed)
|
175 |
+
elif self.dynamic_seed == "markov_1":
|
176 |
+
self.g_cuda.manual_seed(self.large_prime*input_ids[b_idx][-1].item())
|
177 |
+
elif self.dynamic_seed is None:
|
178 |
+
# let the rng evolve naturally - this is not a realistic setting
|
179 |
+
pass
|
180 |
+
|
181 |
+
bl_ct = int(self.vocab_size*self.bl_proportion)
|
182 |
+
blacklist_ids = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.g_cuda)[:bl_ct] # ty Yuxin :]
|
183 |
+
|
184 |
+
if self.store_bl_ids:
|
185 |
+
if self.bl_ids is None: self.bl_ids = [[] for _ in range(input_ids.shape[0])]
|
186 |
+
self.bl_ids[b_idx].append((blacklist_ids,input_ids.tolist()[b_idx][-1]))
|
187 |
+
|
188 |
+
if self.store_spike_ents:
|
189 |
+
if self.spike_entropies is None: self.spike_entropies = [[] for _ in range(input_ids.shape[0])]
|
190 |
+
self.spike_entropies[b_idx].append(self.compute_spike_entropy(scores[b_idx]))
|
191 |
+
|
192 |
+
# self.bad_words_id_length_1[b_idx] = self._prepare_bad_words(blacklist_ids)
|
193 |
+
# this logic may not really be necessary for our usecase
|
194 |
+
self.bad_words_id_length_1[b_idx] = blacklist_ids
|
195 |
+
|
196 |
+
if not self.noop_blacklist:
|
197 |
+
self.bad_words_mask = self._calc_curr_bad_word_mask(scores)
|
198 |
+
scores = self._set_scores_to_inf_for_banned_tokens(scores)
|
199 |
+
|
200 |
+
return scores
|
201 |
+
|
202 |
+
def _prepare_bad_words(self, bad_words_ids: List[List[int]]) -> list[int]:
|
203 |
+
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [self.eos_token_id], bad_words_ids))
|
204 |
+
return bad_words_ids
|
205 |
+
# used to have more logic, not used now
|
206 |
+
|
207 |
+
def _calc_curr_bad_word_mask(self, scores: torch.FloatTensor) -> torch.BoolTensor:
|
208 |
+
bad_words_mask = torch.zeros_like(scores)
|
209 |
+
for b_idx in range(len(self.bad_words_id_length_1)):
|
210 |
+
bad_words_mask[b_idx][self.bad_words_id_length_1[b_idx]] = 1
|
211 |
+
final_mask = bad_words_mask.bool()
|
212 |
+
return final_mask
|
213 |
+
|
214 |
+
def _set_scores_to_inf_for_banned_tokens(
|
215 |
+
self, scores: torch.Tensor
|
216 |
+
) -> torch.Tensor:
|
217 |
+
"""
|
218 |
+
Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
|
219 |
+
list of list of banned tokens to ban in the format [[batch index, vocabulary position],...
|
220 |
+
|
221 |
+
Args:
|
222 |
+
scores: logits distribution of shape (batch size, vocabulary size)
|
223 |
+
banned_tokens: list of list of tokens to ban of length (batch_size)
|
224 |
+
# NOTE^^ Omitted logic for dynamic mask based on multi-token ban words
|
225 |
+
"""
|
226 |
+
if self.bl_type == "hard":
|
227 |
+
scores = scores.masked_fill(self.bad_words_mask, -float("inf"))
|
228 |
+
elif self.bl_type == "soft":
|
229 |
+
whitelist_mask = torch.logical_not(self.bad_words_mask)
|
230 |
+
blacklist_mask = self.bad_words_mask
|
231 |
+
scores[whitelist_mask] = scores[whitelist_mask] + self.bl_logit_bias
|
232 |
+
# scores[blacklist_mask] = scores[blacklist_mask] - self.bl_logit_bias # additive only
|
233 |
+
else:
|
234 |
+
raise NotImplementedError(f"unrecognized bl type {self.bl_type}!")
|
235 |
+
return scores
|
236 |
+
|
237 |
+
|
238 |
+
def score_sequence(inputs: Tensor = None,
|
239 |
+
outputs: Tensor = None,
|
240 |
+
tokenizer: Tokenizer = None,
|
241 |
+
# logits_processor: LogitsProcessor = None,
|
242 |
+
initial_seed: int = None,
|
243 |
+
dynamic_seed: str = None,
|
244 |
+
bl_proportion: float = None,
|
245 |
+
use_cuda: bool = True,
|
246 |
+
record_hits: bool = False,
|
247 |
+
debug: bool = True,
|
248 |
+
# trim_tokens: int = 2,
|
249 |
+
):
|
250 |
+
|
251 |
+
assert (inputs is not None) and \
|
252 |
+
(outputs is not None) and \
|
253 |
+
(tokenizer is not None),"output tensor, tokenizer, and bl params req'd"
|
254 |
+
# (logits_processor is not None),
|
255 |
+
|
256 |
+
vocabulary = list(tokenizer.get_vocab().values())
|
257 |
+
vocab_size = len(vocabulary)
|
258 |
+
|
259 |
+
model_generations = outputs.tolist()[0] # these are tensors unpack once for speed
|
260 |
+
# toks_generated = model_generations[num_orig_input_tokens:]
|
261 |
+
toks_generated = model_generations
|
262 |
+
num_toks_generated = len(toks_generated)
|
263 |
+
|
264 |
+
# num_toks_to_trim = trim_tokens*2
|
265 |
+
# if (num_toks_generated-num_toks_to_trim > 0) == False:
|
266 |
+
# return -1, -1
|
267 |
+
|
268 |
+
# assert num_toks_generated > num_toks_to_trim, f"Need more than {num_toks_to_trim} toks total since we trim start and end a bit."
|
269 |
+
|
270 |
+
# toks_generated = toks_generated[trim_tokens:-trim_tokens]
|
271 |
+
|
272 |
+
if initial_seed is not None:
|
273 |
+
random.seed(initial_seed)
|
274 |
+
|
275 |
+
device = (torch.device("cuda") if use_cuda else torch.device("cpu"))
|
276 |
+
g_cuda = torch.Generator(device=device)
|
277 |
+
large_prime = 15485863
|
278 |
+
|
279 |
+
bl_hits, hit_list = 0, []
|
280 |
+
|
281 |
+
prev_token = inputs[0][-1].item()
|
282 |
+
# prev_token = toks_generated[0] # haven't decided whether this edge effect matters
|
283 |
+
|
284 |
+
# for idx,tok_gend in enumerate(toks_generated[1:]):
|
285 |
+
for idx,tok_gend in enumerate(toks_generated):
|
286 |
+
|
287 |
+
# prev_token = model_generations[num_orig_input_tokens+idx-1]
|
288 |
+
|
289 |
+
if dynamic_seed == "initial":
|
290 |
+
g_cuda.manual_seed(large_prime*initial_seed)
|
291 |
+
elif dynamic_seed == "markov_1":
|
292 |
+
g_cuda.manual_seed(large_prime*prev_token)
|
293 |
+
elif dynamic_seed is None:
|
294 |
+
# let the rng evolve naturally - this is not a realistic setting
|
295 |
+
pass
|
296 |
+
|
297 |
+
bl_ct = int(vocab_size*bl_proportion)
|
298 |
+
posthoc_blacklist = torch.randperm(vocab_size, device=device, generator=g_cuda)[:bl_ct] # ty Yuxin :]
|
299 |
+
|
300 |
+
tok_in_ph_bl = tok_gend in posthoc_blacklist
|
301 |
+
if tok_in_ph_bl:
|
302 |
+
bl_hits += 1
|
303 |
+
hit_list.append(True)
|
304 |
+
else:
|
305 |
+
hit_list.append(False)
|
306 |
+
|
307 |
+
if debug:
|
308 |
+
decoded_token = tokenizer.decode(tok_gend, skip_special_tokens=True)
|
309 |
+
print(f"Token generated: '{decoded_token}' was in the blacklist {tok_in_ph_bl}")
|
310 |
+
|
311 |
+
prev_token = tok_gend
|
312 |
+
|
313 |
+
if debug:
|
314 |
+
print(f"wl hits / num tokens : {num_toks_generated-bl_hits}/{num_toks_generated} = {(num_toks_generated-bl_hits)/num_toks_generated:.02f}")
|
315 |
+
print(f"bl hits / num tokens : {bl_hits}/{num_toks_generated} = {bl_hits/num_toks_generated:.02f}")
|
316 |
+
|
317 |
+
if record_hits:
|
318 |
+
return bl_hits, num_toks_generated, hit_list
|
319 |
+
# bl_fraction = bl_hits/num_toks_generated
|
320 |
+
return bl_hits, num_toks_generated
|
321 |
+
|
322 |
+
|
323 |
+
def tokenize_for_generation(example: dict,
|
324 |
+
idx: int,
|
325 |
+
max_new_tokens: int=None,
|
326 |
+
min_prompt_tokens: int=None,
|
327 |
+
hf_model_name : str=None,
|
328 |
+
tokenizer: Tokenizer=None,
|
329 |
+
model: torch.nn.Module=None):
|
330 |
+
|
331 |
+
# preprocessing, generation & scoring
|
332 |
+
assert isinstance(example, dict), "Expect no batch dimension currently!"
|
333 |
+
|
334 |
+
# preprocess for model generation/completion
|
335 |
+
example = tokenize_and_truncate(example,
|
336 |
+
completion_length=max_new_tokens,
|
337 |
+
prompt_length=min_prompt_tokens,
|
338 |
+
hf_model_name=hf_model_name,
|
339 |
+
tokenizer=tokenizer,
|
340 |
+
# model_max_seq_len=model.config.max_position_embeddings)
|
341 |
+
model_max_seq_len=None)
|
342 |
+
inputs = example["inputs"]
|
343 |
+
# for calculating the baseline violation rate across the "gold" completion
|
344 |
+
untruncated_inputs = example["untruncated_inputs"]
|
345 |
+
|
346 |
+
# decode the preprocessed input to store for audit
|
347 |
+
re_decoded_input = tokenizer.batch_decode(inputs, skip_special_tokens=True)[0]
|
348 |
+
example.update({"truncated_input":re_decoded_input})
|
349 |
+
|
350 |
+
# also decode the original suffix of the input for audit as the baseline
|
351 |
+
decoded_untruncated_input = tokenizer.batch_decode(untruncated_inputs, skip_special_tokens=True)[0]
|
352 |
+
example.update({"baseline_completion":decoded_untruncated_input.replace(re_decoded_input,"")})
|
353 |
+
|
354 |
+
example.update({
|
355 |
+
"orig_sample_length" : untruncated_inputs.shape[1],
|
356 |
+
"prompt_length" : inputs.shape[1],
|
357 |
+
"real_completion_length" : untruncated_inputs.shape[1] - inputs.shape[1],
|
358 |
+
})
|
359 |
+
return example
|
360 |
+
|
361 |
+
|
362 |
+
def generate_completions(example: dict,
|
363 |
+
idx: int,
|
364 |
+
max_new_tokens: int=None,
|
365 |
+
hf_model_name : str=None,
|
366 |
+
tokenizer: Tokenizer=None,
|
367 |
+
model: torch.nn.Module=None,
|
368 |
+
no_bl_partial: Callable=None,
|
369 |
+
w_bl_partial: Callable=None,
|
370 |
+
# return_logits: bool=False,
|
371 |
+
bl_processor_list: LogitsProcessorList=None):
|
372 |
+
|
373 |
+
# preprocessing, generation & scoring
|
374 |
+
assert isinstance(example, dict), "Expect no batch dimension currently!"
|
375 |
+
|
376 |
+
# # preprocess for model generation/completion
|
377 |
+
# example = tokenize_and_truncate(example,
|
378 |
+
# completion_length=max_new_tokens,
|
379 |
+
# hf_model_name=hf_model_name,
|
380 |
+
# tokenizer=tokenizer,
|
381 |
+
# model_max_seq_len=model.config.max_position_embeddings)
|
382 |
+
# inputs = example["inputs"]
|
383 |
+
# # for calculating the baseline violation rate across the "gold" completion
|
384 |
+
# untruncated_inputs = example["untruncated_inputs"]
|
385 |
+
|
386 |
+
# # decode the preprocessed input to store for audit
|
387 |
+
# re_decoded_input = tokenizer.batch_decode(inputs, skip_special_tokens=True)[0]
|
388 |
+
# example.update({"truncated_input":re_decoded_input})
|
389 |
+
|
390 |
+
# # also decode the original suffix of the input for audit as the baseline
|
391 |
+
# decoded_untruncated_input = tokenizer.batch_decode(untruncated_inputs, skip_special_tokens=True)[0]
|
392 |
+
# example.update({"baseline_completion":decoded_untruncated_input.replace(re_decoded_input,"")})
|
393 |
+
|
394 |
+
inputs = example["inputs"]
|
395 |
+
re_decoded_input = example["truncated_input"]
|
396 |
+
|
397 |
+
# call the vanilla and watermarked generation function wrappers with the preprocessed inputs
|
398 |
+
with torch.no_grad():
|
399 |
+
|
400 |
+
samples_taken = 0
|
401 |
+
max_retries = 10
|
402 |
+
success = False
|
403 |
+
while (success is False) and (samples_taken < max_retries):
|
404 |
+
samples_taken += 1
|
405 |
+
|
406 |
+
# set_seed(1234) # debugging the error when using sampling # leaving this off for now
|
407 |
+
|
408 |
+
start_generation = time.time()
|
409 |
+
outputs_no_bl = no_bl_partial(inputs.to(model.device))
|
410 |
+
example["no_bl_gen_time"] = time.time() - start_generation
|
411 |
+
|
412 |
+
# set_seed(1234) # debugging the error when using sampling
|
413 |
+
|
414 |
+
start_generation = time.time()
|
415 |
+
outputs_w_bl = w_bl_partial(inputs.to(model.device))
|
416 |
+
example["w_bl_gen_time"] = time.time() - start_generation
|
417 |
+
|
418 |
+
# if return_logits:
|
419 |
+
# output_no_bl_dict = outputs_no_bl
|
420 |
+
# logits_no_bl = output_no_bl_dict.scores
|
421 |
+
# outputs_no_bl = output_no_bl_dict.sequences
|
422 |
+
# example["logits_no_bl"] = logits_no_bl
|
423 |
+
|
424 |
+
# output_w_bl_dict = outputs_w_bl
|
425 |
+
# logits_w_bl = output_w_bl_dict.scores
|
426 |
+
# outputs_w_bl = output_w_bl_dict.sequences
|
427 |
+
# example["logits_w_bl"] = logits_w_bl
|
428 |
+
|
429 |
+
|
430 |
+
if bl_processor_list:
|
431 |
+
if bl_processor_list[0].bl_ids is not None:
|
432 |
+
example["bl_ids"] = bl_processor_list[0].get_and_clear_stored_bl_ids()
|
433 |
+
if bl_processor_list[0].spike_entropies is not None:
|
434 |
+
example["spike_entropies"] = bl_processor_list[0].get_and_clear_stored_spike_ents()
|
435 |
+
|
436 |
+
try:
|
437 |
+
# decode and store the new generations for auditing
|
438 |
+
no_bl_decoded_output = tokenizer.batch_decode(outputs_no_bl, skip_special_tokens=True)[0]
|
439 |
+
example.update({"no_bl_output":no_bl_decoded_output.replace(re_decoded_input,"")})
|
440 |
+
|
441 |
+
w_bl_decoded_output = tokenizer.batch_decode(outputs_w_bl, skip_special_tokens=True)[0]
|
442 |
+
example.update({"w_bl_output":w_bl_decoded_output.replace(re_decoded_input,"")})
|
443 |
+
|
444 |
+
success = True
|
445 |
+
|
446 |
+
except:
|
447 |
+
# log what happened
|
448 |
+
print(f"Error while trying to decode the outputs of the model...")
|
449 |
+
if samples_taken == 1:
|
450 |
+
print(f"truncated_input: {inputs.tolist()}")
|
451 |
+
print(f"Result of attempt {samples_taken}")
|
452 |
+
print(f"shape outputs_no_bl: {outputs_no_bl.shape}")
|
453 |
+
no_bl_toks = outputs_no_bl.tolist()[0]
|
454 |
+
print(f"outputs_no_bl: {no_bl_toks}")
|
455 |
+
print(f"outputs_no_bl min: {min(no_bl_toks)}")
|
456 |
+
print(f"outputs_no_bl max: {max(no_bl_toks)}")
|
457 |
+
|
458 |
+
print(f"shape outputs_w_bl: {outputs_w_bl.shape}")
|
459 |
+
w_bl_toks = outputs_w_bl.tolist()[0]
|
460 |
+
print(f"outputs_w_bl: {w_bl_toks}")
|
461 |
+
print(f"outputs_w_bl min: {min(w_bl_toks)}")
|
462 |
+
print(f"outputs_w_bl max: {max(w_bl_toks)}")
|
463 |
+
|
464 |
+
if success is False:
|
465 |
+
print(f"Unable to get both a no_bl and w_bl output that were decodeable after {samples_taken} tries, returning empty strings.")
|
466 |
+
example.update({"no_bl_output":""})
|
467 |
+
example.update({"w_bl_output":""})
|
468 |
+
if bl_processor_list:
|
469 |
+
if bl_processor_list[0].bl_ids is not None:
|
470 |
+
example["bl_ids"] = []
|
471 |
+
if bl_processor_list[0].spike_entropies is not None:
|
472 |
+
example["spike_entropies"] = []
|
473 |
+
|
474 |
+
# Able to get lengths in here by checking
|
475 |
+
# truncated input shape versus the output shape
|
476 |
+
|
477 |
+
example.update({
|
478 |
+
# "baseline_num_tokens_generated" : untruncated_inputs.shape[1] - inputs.shape[1], # want this earlier now
|
479 |
+
"no_bl_num_tokens_generated" : outputs_no_bl.shape[1] - inputs.shape[1],
|
480 |
+
"w_bl_num_tokens_generated" : outputs_w_bl.shape[1] - inputs.shape[1]
|
481 |
+
})
|
482 |
+
example.update({
|
483 |
+
"no_bl_sec_per_tok" : example["no_bl_gen_time"]/example["no_bl_num_tokens_generated"],
|
484 |
+
"no_bl_tok_per_sec" : example["no_bl_num_tokens_generated"]/example["no_bl_gen_time"],
|
485 |
+
"w_bl_sec_per_tok" : example["w_bl_gen_time"]/example["w_bl_num_tokens_generated"],
|
486 |
+
"w_bl_tok_per_sec" : example["w_bl_num_tokens_generated"]/example["w_bl_gen_time"],
|
487 |
+
})
|
488 |
+
|
489 |
+
# now done externally because these persist outside this func
|
490 |
+
# # remove any fields we don't need to keep
|
491 |
+
# del example["inputs"]
|
492 |
+
# del example["untruncated_inputs"]
|
493 |
+
|
494 |
+
return example
|
495 |
+
|
496 |
+
|
497 |
+
def compute_bl_metrics(example: dict,
|
498 |
+
idx: int,
|
499 |
+
hf_model_name: str = None,
|
500 |
+
tokenizer: Tokenizer=None,
|
501 |
+
initial_seed: int = None,
|
502 |
+
dynamic_seed: str = None,
|
503 |
+
bl_proportion: float = None,
|
504 |
+
use_cuda: bool = None,
|
505 |
+
record_hits: bool = False,
|
506 |
+
limit_output_tokens: int = 0):
|
507 |
+
|
508 |
+
# if example["idx"] == 3: breakpoint()
|
509 |
+
# okay need to catch an odd bug here and fix things
|
510 |
+
baseline_before = example["baseline_completion"]
|
511 |
+
example["baseline_completion"] = baseline_before.replace(example["truncated_input"][:-1],"")
|
512 |
+
if example["baseline_completion"] != baseline_before:
|
513 |
+
print("baseline input replacement bug occurred!")
|
514 |
+
|
515 |
+
no_bl_before = example["no_bl_output"]
|
516 |
+
example["no_bl_output"] = no_bl_before.replace(example["truncated_input"][:-1],"")
|
517 |
+
if example["no_bl_output"] != no_bl_before:
|
518 |
+
print("no_bl_output input replacement bug occurred!")
|
519 |
+
|
520 |
+
w_bl_before = example["w_bl_output"]
|
521 |
+
example["w_bl_output"] = w_bl_before.replace(example["truncated_input"][:-1],"")
|
522 |
+
if example["w_bl_output"] != w_bl_before:
|
523 |
+
print("w_bl_output input replacement bug occurred!")
|
524 |
+
|
525 |
+
if ("w_bl_output_attacked" in example):
|
526 |
+
w_bl_attacked_before = example["w_bl_output_attacked"]
|
527 |
+
example["w_bl_output_attacked"] = w_bl_attacked_before.replace(example["truncated_input"][:-1],"")
|
528 |
+
if example["w_bl_output_attacked"] != w_bl_attacked_before:
|
529 |
+
print("w_bl_output_attacked input replacement bug occurred!")
|
530 |
+
|
531 |
+
##########
|
532 |
+
|
533 |
+
# preprocess for model generation/completion
|
534 |
+
inputs = tokenize_and_truncate({"text":example["truncated_input"]},
|
535 |
+
completion_length=0,
|
536 |
+
hf_model_name=hf_model_name,
|
537 |
+
tokenizer=tokenizer)["inputs"]
|
538 |
+
|
539 |
+
baseline_outputs = tokenize_and_truncate({"text":example["baseline_completion"]},
|
540 |
+
completion_length=0,
|
541 |
+
hf_model_name=hf_model_name,
|
542 |
+
tokenizer=tokenizer)["inputs"][:,1:]
|
543 |
+
|
544 |
+
no_bl_outputs = tokenize_and_truncate({"text":example["no_bl_output"]},
|
545 |
+
completion_length=0,
|
546 |
+
hf_model_name=hf_model_name,
|
547 |
+
tokenizer=tokenizer)["inputs"][:,1:]
|
548 |
+
|
549 |
+
w_bl_outputs = tokenize_and_truncate({"text":example["w_bl_output"]},
|
550 |
+
completion_length=0,
|
551 |
+
hf_model_name=hf_model_name,
|
552 |
+
tokenizer=tokenizer)["inputs"][:,1:]
|
553 |
+
if "w_bl_output_attacked" in example:
|
554 |
+
w_bl_attacked_outputs = tokenize_and_truncate({"text":example["w_bl_output_attacked"]},
|
555 |
+
completion_length=0,
|
556 |
+
hf_model_name=hf_model_name,
|
557 |
+
tokenizer=tokenizer)["inputs"][:,1:]
|
558 |
+
else:
|
559 |
+
w_bl_attacked_outputs = None
|
560 |
+
|
561 |
+
if limit_output_tokens > 0:
|
562 |
+
example["orig_baseline_completion"] = example["baseline_completion"]
|
563 |
+
example["orig_real_completion_length"] = example["real_completion_length"]
|
564 |
+
baseline_outputs = baseline_outputs[:,:limit_output_tokens]
|
565 |
+
example["real_completion_length"] = baseline_outputs.shape[1]
|
566 |
+
example["baseline_completion"] = tokenizer.batch_decode(baseline_outputs, skip_special_tokens=True)[0]
|
567 |
+
|
568 |
+
example["orig_no_bl_output"] = example["no_bl_output"]
|
569 |
+
example["orig_no_bl_num_tokens_generated"] = example["no_bl_num_tokens_generated"]
|
570 |
+
no_bl_outputs = no_bl_outputs[:,:limit_output_tokens]
|
571 |
+
example["no_bl_num_tokens_generated"] = no_bl_outputs.shape[1]
|
572 |
+
example["no_bl_output"] = tokenizer.batch_decode(no_bl_outputs, skip_special_tokens=True)[0]
|
573 |
+
|
574 |
+
example["orig_w_bl_output"] = example["w_bl_output"]
|
575 |
+
example["orig_w_bl_num_tokens_generated"] = example["w_bl_num_tokens_generated"]
|
576 |
+
w_bl_outputs = w_bl_outputs[:,:limit_output_tokens]
|
577 |
+
example["w_bl_num_tokens_generated"] = w_bl_outputs.shape[1]
|
578 |
+
example["w_bl_output"] = tokenizer.batch_decode(w_bl_outputs, skip_special_tokens=True)[0]
|
579 |
+
|
580 |
+
example["orig_spike_entropies"] = example["spike_entropies"]
|
581 |
+
example["spike_entropies"] = [example["spike_entropies"][0][:limit_output_tokens]]
|
582 |
+
|
583 |
+
if "w_bl_output_attacked" in example:
|
584 |
+
# raise NotImplementedError("Havent thought what to do yet for this")
|
585 |
+
example["orig_w_bl_output_attacked"] = example["w_bl_output_attacked"]
|
586 |
+
# example["orig_w_bl_attacked_num_tokens_generated"] = example["w_bl_attacked_num_tokens_generated"]
|
587 |
+
w_bl_attacked_outputs = w_bl_attacked_outputs[:,:limit_output_tokens]
|
588 |
+
example["w_bl_attacked_num_tokens_generated"] = w_bl_attacked_outputs.shape[1]
|
589 |
+
example["w_bl_output_attacked"] = tokenizer.batch_decode(w_bl_attacked_outputs, skip_special_tokens=True)[0]
|
590 |
+
|
591 |
+
# score the 3 sequence completions/outputs wrt to bl hits
|
592 |
+
result = score_sequence(inputs=inputs,
|
593 |
+
outputs=baseline_outputs, # <-- real text completions
|
594 |
+
initial_seed=initial_seed,
|
595 |
+
dynamic_seed=dynamic_seed,
|
596 |
+
bl_proportion=bl_proportion,
|
597 |
+
tokenizer=tokenizer,
|
598 |
+
use_cuda=use_cuda,
|
599 |
+
record_hits=record_hits,
|
600 |
+
debug=False)
|
601 |
+
if record_hits:
|
602 |
+
bl_hits, num_toks_gend, hit_list = result
|
603 |
+
else:
|
604 |
+
bl_hits, num_toks_gend = result
|
605 |
+
example.update({"baseline_num_toks_gend_eq_0":(num_toks_gend == 0)})
|
606 |
+
# if num_toks_gend < 0.99*example["real_completion_length"]: breakpoint()
|
607 |
+
# if len(hit_list) < 0.99*example["real_completion_length"]: breakpoint()
|
608 |
+
|
609 |
+
if num_toks_gend == 0:
|
610 |
+
# print("No tokens generated, odd, avoiding div by zero and returning -1's")
|
611 |
+
wl_frac = -1
|
612 |
+
bl_frac = -1
|
613 |
+
else:
|
614 |
+
wl_frac = (num_toks_gend-bl_hits)/num_toks_gend
|
615 |
+
bl_frac = bl_hits/num_toks_gend
|
616 |
+
baseline_stats = {
|
617 |
+
"baseline_whitelist_fraction": wl_frac,
|
618 |
+
"baseline_blacklist_fraction": bl_frac
|
619 |
+
}
|
620 |
+
example.update(baseline_stats)
|
621 |
+
if record_hits: example.update({"baseline_hit_list":hit_list})
|
622 |
+
|
623 |
+
result = score_sequence(inputs=inputs,
|
624 |
+
outputs=no_bl_outputs, # <-- non-blacklisted version
|
625 |
+
initial_seed=initial_seed,
|
626 |
+
dynamic_seed=dynamic_seed,
|
627 |
+
bl_proportion=bl_proportion,
|
628 |
+
tokenizer=tokenizer,
|
629 |
+
record_hits=record_hits,
|
630 |
+
debug=False)
|
631 |
+
if record_hits:
|
632 |
+
bl_hits, num_toks_gend, hit_list = result
|
633 |
+
else:
|
634 |
+
bl_hits, num_toks_gend = result
|
635 |
+
example.update({"no_bl_num_toks_gend_eq_0":(num_toks_gend == 0)})
|
636 |
+
# if num_toks_gend < 0.99*example["no_bl_num_tokens_generated"]: breakpoint()
|
637 |
+
# if len(hit_list) < 0.99*example["no_bl_num_tokens_generated"]: breakpoint()
|
638 |
+
|
639 |
+
if num_toks_gend == 0:
|
640 |
+
# print("No tokens generated, odd, avoiding div by zero and returning -1's")
|
641 |
+
wl_frac = -1
|
642 |
+
bl_frac = -1
|
643 |
+
else:
|
644 |
+
wl_frac = (num_toks_gend-bl_hits)/num_toks_gend
|
645 |
+
bl_frac = bl_hits/num_toks_gend
|
646 |
+
no_bl_stats = {
|
647 |
+
"no_bl_whitelist_fraction": wl_frac,
|
648 |
+
"no_bl_blacklist_fraction": bl_frac
|
649 |
+
}
|
650 |
+
example.update(no_bl_stats)
|
651 |
+
if record_hits: example.update({"no_bl_hit_list":hit_list})
|
652 |
+
|
653 |
+
result = score_sequence(inputs=inputs,
|
654 |
+
outputs=w_bl_outputs, # <-- blacklisted version
|
655 |
+
initial_seed=initial_seed,
|
656 |
+
dynamic_seed=dynamic_seed,
|
657 |
+
bl_proportion=bl_proportion,
|
658 |
+
tokenizer=tokenizer,
|
659 |
+
record_hits=record_hits,
|
660 |
+
# breakpoint_on_hit=True, # banging head against wall
|
661 |
+
debug=False)
|
662 |
+
if record_hits:
|
663 |
+
bl_hits, num_toks_gend, hit_list = result
|
664 |
+
else:
|
665 |
+
bl_hits, num_toks_gend = result
|
666 |
+
example.update({"w_bl_num_toks_gend_eq_0":(num_toks_gend == 0)})
|
667 |
+
# if num_toks_gend < 0.99*example["w_bl_num_tokens_generated"]: breakpoint()
|
668 |
+
# if len(hit_list) < 0.99*example["w_bl_num_tokens_generated"]: breakpoint()
|
669 |
+
|
670 |
+
if num_toks_gend == 0:
|
671 |
+
# print("No tokens generated, odd, avoiding div by zero and returning -1's")
|
672 |
+
wl_frac = -1
|
673 |
+
bl_frac = -1
|
674 |
+
else:
|
675 |
+
wl_frac = (num_toks_gend-bl_hits)/num_toks_gend
|
676 |
+
bl_frac = bl_hits/num_toks_gend
|
677 |
+
w_bl_stats = {
|
678 |
+
"w_bl_whitelist_fraction": wl_frac,
|
679 |
+
"w_bl_blacklist_fraction": bl_frac
|
680 |
+
}
|
681 |
+
example.update(w_bl_stats)
|
682 |
+
if record_hits: example.update({"w_bl_hit_list":hit_list})
|
683 |
+
|
684 |
+
if w_bl_attacked_outputs is not None:
|
685 |
+
result = score_sequence(inputs=inputs,
|
686 |
+
outputs=w_bl_attacked_outputs, # <-- blacklisted but attacked version
|
687 |
+
initial_seed=initial_seed,
|
688 |
+
dynamic_seed=dynamic_seed,
|
689 |
+
bl_proportion=bl_proportion,
|
690 |
+
tokenizer=tokenizer,
|
691 |
+
record_hits=record_hits,
|
692 |
+
# breakpoint_on_hit=True, # banging head against wall
|
693 |
+
debug=False)
|
694 |
+
if record_hits:
|
695 |
+
bl_hits, num_toks_gend, hit_list = result
|
696 |
+
else:
|
697 |
+
bl_hits, num_toks_gend = result
|
698 |
+
example.update({"w_bl_attacked_num_toks_gend_eq_0":(num_toks_gend == 0)})
|
699 |
+
# if (num_toks_gend-bl_hits)/(num_toks_gend) < 1.0: breakpoint()
|
700 |
+
|
701 |
+
if num_toks_gend == 0:
|
702 |
+
# print("No tokens generated, odd, avoiding div by zero and returning -1's")
|
703 |
+
wl_frac = -1
|
704 |
+
bl_frac = -1
|
705 |
+
else:
|
706 |
+
wl_frac = (num_toks_gend-bl_hits)/num_toks_gend
|
707 |
+
bl_frac = bl_hits/num_toks_gend
|
708 |
+
w_bl_attacked_stats = {
|
709 |
+
"w_bl_attacked_num_tokens_generated": num_toks_gend,
|
710 |
+
"w_bl_attacked_whitelist_fraction": wl_frac,
|
711 |
+
"w_bl_attacked_blacklist_fraction": bl_frac
|
712 |
+
}
|
713 |
+
example.update(w_bl_attacked_stats)
|
714 |
+
if record_hits: example.update({"w_bl_attacked_hit_list":hit_list})
|
715 |
+
|
716 |
+
return example
|
717 |
+
|
718 |
+
|
719 |
+
|
720 |
+
def aggregate_bl_stats(example: dict, idx: int, stat_table: dict):
|
721 |
+
|
722 |
+
stat_table["baseline_stats"]["whitelist_fraction"] += example["baseline_stats"]["whitelist_fraction"]
|
723 |
+
stat_table["baseline_stats"]["blacklist_fraction"] += example["baseline_stats"]["blacklist_fraction"]
|
724 |
+
|
725 |
+
stat_table["w_bl_stats"]["whitelist_fraction"] += example["w_bl_stats"]["whitelist_fraction"]
|
726 |
+
stat_table["w_bl_stats"]["blacklist_fraction"] += example["w_bl_stats"]["blacklist_fraction"]
|
727 |
+
|
728 |
+
stat_table["no_bl_stats"]["whitelist_fraction"] += example["no_bl_stats"]["whitelist_fraction"]
|
729 |
+
stat_table["no_bl_stats"]["blacklist_fraction"] += example["no_bl_stats"]["blacklist_fraction"]
|
730 |
+
|
731 |
+
stat_table["num_examples"] += 1
|
732 |
+
|
733 |
+
return example
|
734 |
+
|
735 |
+
|
736 |
+
def compute_ppl_single(prefix_and_output_text = None,
|
737 |
+
output_text = None,
|
738 |
+
oracle_model_name = None,
|
739 |
+
oracle_model = None,
|
740 |
+
oracle_tokenizer = None):
|
741 |
+
|
742 |
+
with torch.no_grad():
|
743 |
+
tokd_prefix = tokenize_and_truncate({"text":prefix_and_output_text}, completion_length=0, hf_model_name=oracle_model_name, tokenizer=oracle_tokenizer, model_max_seq_len=oracle_model.config.max_position_embeddings)["inputs"]
|
744 |
+
tokd_inputs = tokd_prefix
|
745 |
+
# if only want to score the "generation" part we need the suffix tokenization length
|
746 |
+
tokd_suffix = tokenize_and_truncate({"text":output_text}, completion_length=0, hf_model_name=oracle_model_name, tokenizer=oracle_tokenizer, model_max_seq_len=oracle_model.config.max_position_embeddings)["inputs"]
|
747 |
+
|
748 |
+
tokd_inputs = tokd_inputs.to(oracle_model.device)
|
749 |
+
# make labels, mark if not including all positions
|
750 |
+
tokd_labels = tokd_inputs.clone().detach()
|
751 |
+
tokd_labels[:,:tokd_labels.shape[1]-tokd_suffix.shape[1]+1] = -100
|
752 |
+
|
753 |
+
outputs = oracle_model(input_ids=tokd_inputs, labels=tokd_labels)
|
754 |
+
loss = outputs.loss # avg CE loss all positions (except -100, TODO plz check that this is working correctly)
|
755 |
+
ppl = torch.tensor(math.exp(loss))
|
756 |
+
|
757 |
+
return loss.item(), ppl.item()
|
758 |
+
|
759 |
+
|
760 |
+
def evaluate_generation_fluency(example: dict,
|
761 |
+
idx: int,
|
762 |
+
oracle_model_name = None,
|
763 |
+
oracle_model = None,
|
764 |
+
oracle_tokenizer = None):
|
765 |
+
|
766 |
+
# pull out the required fields from the pipeline results
|
767 |
+
inputs_plus_baseline_output = f"{example['truncated_input']}{example['baseline_completion']}"
|
768 |
+
baseline_output = f"{example['baseline_completion']}"
|
769 |
+
|
770 |
+
inputs_plus_no_bl_output = f"{example['truncated_input']}{example['no_bl_output']}"
|
771 |
+
no_bl_output = f"{example['no_bl_output']}"
|
772 |
+
|
773 |
+
inputs_plus_w_bl_output = f"{example['truncated_input']}{example['w_bl_output']}"
|
774 |
+
w_bl_output = f"{example['w_bl_output']}"
|
775 |
+
|
776 |
+
# add metrics
|
777 |
+
loss, ppl = compute_ppl_single(inputs_plus_baseline_output, baseline_output, oracle_model_name, oracle_model, oracle_tokenizer)
|
778 |
+
example["baseline_loss"] = loss
|
779 |
+
example["baseline_ppl"] = ppl
|
780 |
+
loss, ppl = compute_ppl_single(inputs_plus_no_bl_output, no_bl_output, oracle_model_name, oracle_model, oracle_tokenizer)
|
781 |
+
example["no_bl_loss"] = loss
|
782 |
+
example["no_bl_ppl"] = ppl
|
783 |
+
loss, ppl = compute_ppl_single(inputs_plus_w_bl_output, w_bl_output, oracle_model_name, oracle_model, oracle_tokenizer)
|
784 |
+
example["w_bl_loss"] = loss
|
785 |
+
example["w_bl_ppl"] = ppl
|
786 |
+
|
787 |
+
# del any temp values
|
788 |
+
return example
|
789 |
+
|
790 |
+
|
791 |
+
def add_idx(example,idx):
|
792 |
+
example.update({"idx":idx})
|
793 |
+
return example
|
794 |
+
|
795 |
+
|
796 |
+
def check_input_lengths(example,idx, min_sample_len=0, min_prompt_len=0, min_completion_len=0):
|
797 |
+
orig_sample_length = example["orig_sample_length"]
|
798 |
+
prompt_length = example["prompt_length"]
|
799 |
+
real_completion_length = example["real_completion_length"]
|
800 |
+
|
801 |
+
# breakpoint()
|
802 |
+
|
803 |
+
conds = all([
|
804 |
+
orig_sample_length >= min_sample_len,
|
805 |
+
prompt_length >= min_prompt_len,
|
806 |
+
real_completion_length >= min_completion_len,
|
807 |
+
])
|
808 |
+
return conds
|
809 |
+
|
810 |
+
|
811 |
+
def check_output_lengths(example,min_output_len=0):
|
812 |
+
no_bl_output_len = example["no_bl_num_tokens_generated"]
|
813 |
+
w_bl_output_len = example["w_bl_num_tokens_generated"]
|
814 |
+
conds = all([
|
815 |
+
no_bl_output_len >= min_output_len,
|
816 |
+
w_bl_output_len >= min_output_len,
|
817 |
+
])
|
818 |
+
return conds
|
819 |
+
|
820 |
+
|
lm-watermarking-main/experiments/watermarking_analysis.ipynb
ADDED
@@ -0,0 +1,2049 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Watermark Analysis\n",
|
8 |
+
"\n",
|
9 |
+
"Notebook for performing analysis and visualization of the effects of watermarking schemes"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"# Basic imports\n",
|
19 |
+
"import os\n",
|
20 |
+
"\n",
|
21 |
+
"from tqdm import tqdm\n",
|
22 |
+
"from statistics import mean\n",
|
23 |
+
"\n",
|
24 |
+
"import numpy as np\n",
|
25 |
+
"import pandas as pd\n",
|
26 |
+
"import torch\n",
|
27 |
+
"\n",
|
28 |
+
"import matplotlib.pyplot as plt\n",
|
29 |
+
"\n",
|
30 |
+
"from matplotlib import rc\n",
|
31 |
+
"rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
|
32 |
+
"rc('text', usetex=True)\n",
|
33 |
+
"\n",
|
34 |
+
"import cmasher as cmr"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": null,
|
40 |
+
"metadata": {},
|
41 |
+
"outputs": [],
|
42 |
+
"source": [
|
43 |
+
"from datasets import load_from_disk"
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"attachments": {},
|
48 |
+
"cell_type": "markdown",
|
49 |
+
"metadata": {},
|
50 |
+
"source": [
|
51 |
+
"### Load the processed dataset/frame"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"cell_type": "code",
|
56 |
+
"execution_count": null,
|
57 |
+
"metadata": {},
|
58 |
+
"outputs": [],
|
59 |
+
"source": [
|
60 |
+
"# save_name = \"analysis_ds_1-19_realnews_1-3_v1\" # in figure\n",
|
61 |
+
"# save_name = \"analysis_ds_1-21_greedy_redo\" \n",
|
62 |
+
"# save_name = \"analysis_ds_1-21_greedy_redo_truncated\"\n",
|
63 |
+
"# save_name = \"analysis_ds_1-23_greedy_gamma_0-25_truncated\" \n",
|
64 |
+
"# save_name = \"analysis_ds_1-23_greedy_gamma_0-25_0-5_truncated\" # in figure (not 100% sure this is correct, check)\n",
|
65 |
+
"\n",
|
66 |
+
"# save_name = \"analysis_ds_1-20_more_attack\" # in figure\n",
|
67 |
+
"\n",
|
68 |
+
"# save_name = \"analysis_ds_1-19_realnews_1-3_v1\" # in figure\n",
|
69 |
+
"# save_name = \"analysis_ds_1-23_en_1-3\"\n",
|
70 |
+
"save_name = \"analysis_ds_1-23_pile_1-3\"\n",
|
71 |
+
"\n",
|
72 |
+
"save_dir = f\"input/{save_name}\""
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": null,
|
78 |
+
"metadata": {},
|
79 |
+
"outputs": [],
|
80 |
+
"source": [
|
81 |
+
"raw_data = load_from_disk(save_dir)"
|
82 |
+
]
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"cell_type": "markdown",
|
86 |
+
"metadata": {},
|
87 |
+
"source": [
|
88 |
+
"#### convert to pandas df"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": null,
|
94 |
+
"metadata": {},
|
95 |
+
"outputs": [],
|
96 |
+
"source": [
|
97 |
+
"df = raw_data.to_pandas()"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "code",
|
102 |
+
"execution_count": null,
|
103 |
+
"metadata": {},
|
104 |
+
"outputs": [],
|
105 |
+
"source": [
|
106 |
+
"print(f\"Orig number of rows: {len(df)}\")\n",
|
107 |
+
"df.tail()"
|
108 |
+
]
|
109 |
+
},
|
110 |
+
{
|
111 |
+
"cell_type": "code",
|
112 |
+
"execution_count": null,
|
113 |
+
"metadata": {},
|
114 |
+
"outputs": [],
|
115 |
+
"source": [
|
116 |
+
"df.columns"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"attachments": {},
|
121 |
+
"cell_type": "markdown",
|
122 |
+
"metadata": {},
|
123 |
+
"source": [
|
124 |
+
"### \"retokenization\" problem \n",
|
125 |
+
"\n",
|
126 |
+
"current hypo for what matches this criterion is based on the non 1-to-1 aspect of tokenization"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"cell_type": "code",
|
131 |
+
"execution_count": null,
|
132 |
+
"metadata": {},
|
133 |
+
"outputs": [],
|
134 |
+
"source": [
|
135 |
+
"retok_problematic_rows = df[(df['w_bl_whitelist_fraction'] != -1.0) & (df['w_bl_whitelist_fraction'] != 1.0) & (df['bl_type'] == 'hard')]\n",
|
136 |
+
"print(f\"Num rows that are hard-blacklisted, and measureable, but still have a non-100% WL fraction: {len(retok_problematic_rows)} out of {len(df[df['bl_type'] == 'hard'])}\")\n",
|
137 |
+
"# retok_problematic_rows"
|
138 |
+
]
|
139 |
+
},
|
140 |
+
{
|
141 |
+
"attachments": {},
|
142 |
+
"cell_type": "markdown",
|
143 |
+
"metadata": {},
|
144 |
+
"source": [
|
145 |
+
"#### Replace or drop the the specially marked -1 rows since these are unmeasureable due to short length"
|
146 |
+
]
|
147 |
+
},
|
148 |
+
{
|
149 |
+
"cell_type": "code",
|
150 |
+
"execution_count": null,
|
151 |
+
"metadata": {},
|
152 |
+
"outputs": [],
|
153 |
+
"source": [
|
154 |
+
"orig_len = len(df)\n",
|
155 |
+
"\n",
|
156 |
+
"# df['no_bl_whitelist_fraction'].mask(df['no_bl_whitelist_fraction'] == -1.0, pd.NA, inplace=True)\n",
|
157 |
+
"# df['w_bl_whitelist_fraction'].mask(df['w_bl_whitelist_fraction'] == -1.0, pd.NA, inplace=True)\n",
|
158 |
+
"\n",
|
159 |
+
"df = df[df[\"no_bl_whitelist_fraction\"] != -1.0]\n",
|
160 |
+
"df = df[df[\"w_bl_whitelist_fraction\"] != -1.0]\n",
|
161 |
+
"\n",
|
162 |
+
"print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
|
163 |
+
]
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "markdown",
|
167 |
+
"metadata": {},
|
168 |
+
"source": [
|
169 |
+
"#### Drop rows where there weren't enough tokens to measure ppl in one or both of the output cases"
|
170 |
+
]
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"cell_type": "code",
|
174 |
+
"execution_count": null,
|
175 |
+
"metadata": {},
|
176 |
+
"outputs": [],
|
177 |
+
"source": [
|
178 |
+
"orig_len = len(df)\n",
|
179 |
+
"# df = df[df[\"no_bl_ppl\"].isna()]\n",
|
180 |
+
"# df = df[df[\"w_bl_ppl\"].isna()]\n",
|
181 |
+
"df = df[~(df[\"no_bl_ppl\"].isna() | df[\"w_bl_ppl\"].isna())]\n",
|
182 |
+
"print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
|
183 |
+
]
|
184 |
+
},
|
185 |
+
{
|
186 |
+
"attachments": {},
|
187 |
+
"cell_type": "markdown",
|
188 |
+
"metadata": {},
|
189 |
+
"source": [
|
190 |
+
"#### drop rows with really large bias, as 100.0 is $\\simeq \\infty$"
|
191 |
+
]
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"cell_type": "code",
|
195 |
+
"execution_count": null,
|
196 |
+
"metadata": {},
|
197 |
+
"outputs": [],
|
198 |
+
"source": [
|
199 |
+
"orig_len = len(df)\n",
|
200 |
+
"\n",
|
201 |
+
"df = df[df[\"bl_logit_bias\"] <= 100.0]\n",
|
202 |
+
"\n",
|
203 |
+
"print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
|
204 |
+
]
|
205 |
+
},
|
206 |
+
{
|
207 |
+
"attachments": {},
|
208 |
+
"cell_type": "markdown",
|
209 |
+
"metadata": {},
|
210 |
+
"source": [
|
211 |
+
"#### drop rows where using sampling but also beam search, not considering at this time"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": null,
|
217 |
+
"metadata": {},
|
218 |
+
"outputs": [],
|
219 |
+
"source": [
|
220 |
+
"orig_len = len(df)\n",
|
221 |
+
"\n",
|
222 |
+
"# df = df[df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False and tup[2] != 1) or (tup[0] == True and tup[2] == 1) or (tup[0] == False))]\n",
|
223 |
+
"df = df[((df[\"use_sampling\"]==True) & (df[\"num_beams\"] == 1)) | (df[\"use_sampling\"]==False)]\n",
|
224 |
+
"\n",
|
225 |
+
"print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "markdown",
|
230 |
+
"metadata": {},
|
231 |
+
"source": [
|
232 |
+
"#### correct the sampling temp column"
|
233 |
+
]
|
234 |
+
},
|
235 |
+
{
|
236 |
+
"cell_type": "code",
|
237 |
+
"execution_count": null,
|
238 |
+
"metadata": {},
|
239 |
+
"outputs": [],
|
240 |
+
"source": [
|
241 |
+
"df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"].fillna(0.0)\n",
|
242 |
+
"df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"].fillna(1.0)"
|
243 |
+
]
|
244 |
+
},
|
245 |
+
{
|
246 |
+
"attachments": {},
|
247 |
+
"cell_type": "markdown",
|
248 |
+
"metadata": {},
|
249 |
+
"source": [
|
250 |
+
"#### marking the hard blacklist rows as having inf/very large bias\n",
|
251 |
+
"\n",
|
252 |
+
"(after the > 100.0 bias drop)"
|
253 |
+
]
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"cell_type": "code",
|
257 |
+
"execution_count": null,
|
258 |
+
"metadata": {},
|
259 |
+
"outputs": [],
|
260 |
+
"source": [
|
261 |
+
"df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = np.inf\n",
|
262 |
+
"# df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = 10000 # crosscheck with whats hardcoded in the bl processor"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"attachments": {},
|
267 |
+
"cell_type": "markdown",
|
268 |
+
"metadata": {},
|
269 |
+
"source": [
|
270 |
+
"#### Rename some parameters"
|
271 |
+
]
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"cell_type": "code",
|
275 |
+
"execution_count": null,
|
276 |
+
"metadata": {},
|
277 |
+
"outputs": [],
|
278 |
+
"source": [
|
279 |
+
"df[\"delta\"] = df[\"bl_logit_bias\"].values\n",
|
280 |
+
"df[\"gamma\"] = 1 - df[\"bl_proportion\"].values\n",
|
281 |
+
"df[\"gamma\"] = df[\"gamma\"].round(3)\n",
|
282 |
+
"\n",
|
283 |
+
"df[\"no_bl_act_num_wl_tokens\"] = np.round(df[\"no_bl_whitelist_fraction\"].values*df[\"no_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
|
284 |
+
"df[\"w_bl_act_num_wl_tokens\"] = np.round(df[\"w_bl_whitelist_fraction\"].values*df[\"w_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
|
285 |
+
"\n",
|
286 |
+
"df[\"w_bl_std_num_wl_tokens\"] = np.sqrt(df[\"w_bl_var_num_wl_tokens\"].values)\n",
|
287 |
+
"\n",
|
288 |
+
"if \"real_completion_length\":\n",
|
289 |
+
" df[\"baseline_num_tokens_generated\"] = df[\"real_completion_length\"].values\n",
|
290 |
+
"\n",
|
291 |
+
"if \"actual_attacked_ratio\" in df.columns:\n",
|
292 |
+
" df[\"actual_attacked_fraction\"] = df[\"actual_attacked_ratio\"].values*df[\"replace_ratio\"].values\n",
|
293 |
+
"\n",
|
294 |
+
"if \"meta\" in df.columns:\n",
|
295 |
+
" df[\"pile_set_name\"] = df[\"meta\"].apply(lambda dict: dict[\"pile_set_name\"])\n",
|
296 |
+
"\n",
|
297 |
+
"df[\"baseline_hit_list_length\"] = df[\"baseline_hit_list\"].apply(len)\n",
|
298 |
+
"df[\"no_bl_hit_list_length\"] = df[\"no_bl_hit_list\"].apply(len)\n",
|
299 |
+
"df[\"w_bl_hit_list_length\"] = df[\"w_bl_hit_list\"].apply(len)"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"cell_type": "code",
|
304 |
+
"execution_count": null,
|
305 |
+
"metadata": {},
|
306 |
+
"outputs": [],
|
307 |
+
"source": [
|
308 |
+
"# for pile outlier filtering\n",
|
309 |
+
"df[\"w_bl_space_count\"] = df[\"w_bl_output\"].apply(lambda string: string.count(\" \"))\n",
|
310 |
+
"df[\"no_bl_space_count\"] = df[\"no_bl_output\"].apply(lambda string: string.count(\" \"))\n",
|
311 |
+
"df[\"baseline_space_count\"] = df[\"baseline_completion\"].apply(lambda string: string.count(\" \"))\n",
|
312 |
+
"\n",
|
313 |
+
"df[\"w_bl_space_frac\"] = df[\"w_bl_space_count\"].values / df[\"w_bl_hit_list_length\"]\n",
|
314 |
+
"df[\"no_bl_space_frac\"] = df[\"no_bl_space_count\"].values / df[\"no_bl_hit_list_length\"]\n",
|
315 |
+
"df[\"baseline_space_frac\"] = df[\"baseline_space_count\"].values / df[\"baseline_hit_list_length\"]"
|
316 |
+
]
|
317 |
+
},
|
318 |
+
{
|
319 |
+
"attachments": {},
|
320 |
+
"cell_type": "markdown",
|
321 |
+
"metadata": {},
|
322 |
+
"source": [
|
323 |
+
"### Filter for the generation lengths we want to look at"
|
324 |
+
]
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"cell_type": "code",
|
328 |
+
"execution_count": null,
|
329 |
+
"metadata": {},
|
330 |
+
"outputs": [],
|
331 |
+
"source": [
|
332 |
+
"orig_len = len(df)\n",
|
333 |
+
"\n",
|
334 |
+
"# # main filters\n",
|
335 |
+
"# # df = df[(df[\"real_completion_length\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)]\n",
|
336 |
+
"# df = df[(df[\"gamma\"] == 0.1) | (df[\"gamma\"] == 0.25) | (df[\"gamma\"] == 0.5)]\n",
|
337 |
+
"# df = df[(df[\"delta\"] == 1.0) | (df[\"delta\"] == 2.0) | (df[\"delta\"] == 10.0)]\n",
|
338 |
+
"# df = df[(df[\"use_sampling\"] == True)]\n",
|
339 |
+
"# df = df[(df[\"bl_type\"] == \"soft\")]\n",
|
340 |
+
"\n",
|
341 |
+
"# df = df[(df[\"real_completion_length\"] == 200) & (df[\"no_bl_num_tokens_generated\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)] # now also applies to the truncated version\n",
|
342 |
+
"# df = df[(df[\"no_bl_num_tokens_generated\"] >= 500) & (df[\"w_bl_num_tokens_generated\"] >= 500)] # all gas noop\n",
|
343 |
+
"\n",
|
344 |
+
"# # # attack specific\n",
|
345 |
+
"# df = df[(df[\"real_completion_length\"] == 200) & (df[\"no_bl_num_tokens_generated\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)]\n",
|
346 |
+
"# df = df[(df[\"replace_ratio\"] <= 0.7)]\n",
|
347 |
+
"\n",
|
348 |
+
"# NOTE pile only\n",
|
349 |
+
"df = df[df[\"w_bl_space_frac\"] <= 0.9]\n",
|
350 |
+
"df = df[df[\"no_bl_space_frac\"] <= 0.9]\n",
|
351 |
+
"# df = df[df[\"pile_set_name\"] != \"Github\"]\n",
|
352 |
+
"\n",
|
353 |
+
"upper_T = 205\n",
|
354 |
+
"lower_T = 195\n",
|
355 |
+
"df = df[(df[\"baseline_hit_list_length\"] >= lower_T) & (df[\"no_bl_hit_list_length\"] >= lower_T) & (df[\"w_bl_hit_list_length\"] >= lower_T)] # now also applies to the truncated version\n",
|
356 |
+
"df = df[(df[\"baseline_hit_list_length\"] <= upper_T) & (df[\"no_bl_hit_list_length\"] <= upper_T) & (df[\"w_bl_hit_list_length\"] <= upper_T)] # now also applies to the truncated version\n",
|
357 |
+
"\n",
|
358 |
+
"\n",
|
359 |
+
"print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
|
360 |
+
]
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"attachments": {},
|
364 |
+
"cell_type": "markdown",
|
365 |
+
"metadata": {},
|
366 |
+
"source": [
|
367 |
+
"# Add z-scores (convert the raw watermark measurement, fraction, to a z-score )"
|
368 |
+
]
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"cell_type": "code",
|
372 |
+
"execution_count": null,
|
373 |
+
"metadata": {},
|
374 |
+
"outputs": [],
|
375 |
+
"source": [
|
376 |
+
"from math import sqrt\n",
|
377 |
+
"import scipy.stats\n",
|
378 |
+
"def compute_z_score(observed_wl_frac, T, gamma):\n",
|
379 |
+
" numer = observed_wl_frac - gamma\n",
|
380 |
+
" denom = sqrt(gamma*(1-gamma)/T)\n",
|
381 |
+
" z = numer/denom\n",
|
382 |
+
" return z\n",
|
383 |
+
"\n",
|
384 |
+
"def compute_wl_for_z(z, T, gamma):\n",
|
385 |
+
" denom = sqrt(gamma*(1-gamma)/T)\n",
|
386 |
+
" numer = ((z*denom)+gamma)*T\n",
|
387 |
+
" return numer\n",
|
388 |
+
"\n",
|
389 |
+
"def compute_p_value(z):\n",
|
390 |
+
" p_value = scipy.stats.norm.sf(abs(z))\n",
|
391 |
+
" return p_value\n",
|
392 |
+
"\n",
|
393 |
+
"df[\"baseline_z_score\"] = df[[\"baseline_whitelist_fraction\", \"baseline_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
|
394 |
+
"df[\"no_bl_z_score\"] = df[[\"no_bl_whitelist_fraction\", \"no_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
|
395 |
+
"df[\"w_bl_z_score\"] = df[[\"w_bl_whitelist_fraction\", \"w_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
|
396 |
+
"\n",
|
397 |
+
"if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
|
398 |
+
" df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)"
|
399 |
+
]
|
400 |
+
},
|
401 |
+
{
|
402 |
+
"cell_type": "code",
|
403 |
+
"execution_count": null,
|
404 |
+
"metadata": {},
|
405 |
+
"outputs": [],
|
406 |
+
"source": [
|
407 |
+
"# if attacked in df\n",
|
408 |
+
"if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
|
409 |
+
" df[\"w_bl_attacked_act_num_wl_tokens\"] = np.round(df[\"w_bl_attacked_whitelist_fraction\"].values*df[\"w_bl_attacked_num_tokens_generated\"],1) # round to 1 for sanity\n",
|
410 |
+
"\n",
|
411 |
+
" df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
|
412 |
+
"\n",
|
413 |
+
" df[[\"bl_proportion\",\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\",\"w_bl_attacked_act_num_wl_tokens\", \"w_bl_attacked_z_score\"]]"
|
414 |
+
]
|
415 |
+
},
|
416 |
+
{
|
417 |
+
"attachments": {},
|
418 |
+
"cell_type": "markdown",
|
419 |
+
"metadata": {},
|
420 |
+
"source": [
|
421 |
+
"# Prepare groupby (decide which hyperparameters to groups the rows by)"
|
422 |
+
]
|
423 |
+
},
|
424 |
+
{
|
425 |
+
"cell_type": "code",
|
426 |
+
"execution_count": null,
|
427 |
+
"metadata": {},
|
428 |
+
"outputs": [],
|
429 |
+
"source": [
|
430 |
+
"# groupby_fields = ['num_beams', 'max_new_tokens']\n",
|
431 |
+
"# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens']\n",
|
432 |
+
"# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens', 'bl_logit_bias']\n",
|
433 |
+
"# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias']\n",
|
434 |
+
"# groupby_fields = ['use_sampling','sampling_temp','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias']\n",
|
435 |
+
"# groupby_fields = ['use_sampling','sampling_temp','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias','bl_proportion']\n",
|
436 |
+
"# groupby_fields = ['use_sampling','num_beams','bl_type','bl_logit_bias','bl_proportion']\n",
|
437 |
+
"\n",
|
438 |
+
"if \"w_bl_attacked_whitelist_fraction\" in df.columns: \n",
|
439 |
+
" groupby_fields = ['use_sampling','num_beams','gamma','delta', 'replace_ratio'] # attack grouping\n",
|
440 |
+
"else:\n",
|
441 |
+
" groupby_fields = ['use_sampling','num_beams','delta','gamma'] # regular grouping\n",
|
442 |
+
" # groupby_fields = ['use_sampling','delta','gamma'] # regular grouping, but no beam variation\n",
|
443 |
+
" # groupby_fields = ['delta','gamma'] # regular grouping, but no beam variation, and all sampling"
|
444 |
+
]
|
445 |
+
},
|
446 |
+
{
|
447 |
+
"attachments": {},
|
448 |
+
"cell_type": "markdown",
|
449 |
+
"metadata": {},
|
450 |
+
"source": [
|
451 |
+
"### narrowing in on IQ range (not generally used)\n",
|
452 |
+
"\n",
|
453 |
+
"(removing outliers by subsetting to rows near the mean etc.)"
|
454 |
+
]
|
455 |
+
},
|
456 |
+
{
|
457 |
+
"cell_type": "code",
|
458 |
+
"execution_count": null,
|
459 |
+
"metadata": {},
|
460 |
+
"outputs": [],
|
461 |
+
"source": [
|
462 |
+
"# tmp_grped_25 = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].quantile(q=0.25).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_25th'})\n",
|
463 |
+
"# tmp_grped_50 = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].quantile(q=0.5).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_50th'})\n",
|
464 |
+
"# tmp_grped_75 = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].quantile(q=0.75).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_75th'})\n",
|
465 |
+
"# df = df.merge(tmp_grped_25, on = groupby_fields)\n",
|
466 |
+
"# df = df.merge(tmp_grped_50, on = groupby_fields)\n",
|
467 |
+
"# df = df.merge(tmp_grped_75, on = groupby_fields)\n",
|
468 |
+
"\n",
|
469 |
+
"# # tmp_grped_mean = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].mean().rename(columns={'avg_spike_entropy': 'avg_spike_entropy_mean'})\n",
|
470 |
+
"# # tmp_grped_median = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].median().rename(columns={'avg_spike_entropy': 'avg_spike_entropy_median'})\n",
|
471 |
+
"# # df = df.merge(tmp_grped_mean, on = groupby_fields)\n",
|
472 |
+
"# # df = df.merge(tmp_grped_median, on = groupby_fields)\n"
|
473 |
+
]
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"cell_type": "code",
|
477 |
+
"execution_count": null,
|
478 |
+
"metadata": {},
|
479 |
+
"outputs": [],
|
480 |
+
"source": [
|
481 |
+
"# # eps = 0.001\n",
|
482 |
+
"# eps = 0.005\n",
|
483 |
+
"# df[\"avg_spike_entropy_mean_minus_eps\"] = df['avg_spike_entropy_mean']-eps\n",
|
484 |
+
"# df[\"avg_spike_entropy_mean_plus_eps\"] = df['avg_spike_entropy_mean']+eps\n",
|
485 |
+
"\n",
|
486 |
+
"# df[\"avg_spike_entropy_median_minus_eps\"] = df['avg_spike_entropy_median']-eps\n",
|
487 |
+
"# df[\"avg_spike_entropy_median_plus_eps\"] = df['avg_spike_entropy_median']+eps\n",
|
488 |
+
"# print(df.columns)"
|
489 |
+
]
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"cell_type": "code",
|
493 |
+
"execution_count": null,
|
494 |
+
"metadata": {},
|
495 |
+
"outputs": [],
|
496 |
+
"source": [
|
497 |
+
"# # df[[\"avg_spike_entropy_25th\",\"avg_spike_entropy_75th\"]]\n",
|
498 |
+
"# df[[\"avg_spike_entropy_mean_minus_eps\",\"avg_spike_entropy_mean\",\"avg_spike_entropy_mean_plus_eps\"]]\n",
|
499 |
+
"# df[[\"avg_spike_entropy_median_minus_eps\",\"avg_spike_entropy_median\",\"avg_spike_entropy_median_plus_eps\"]]"
|
500 |
+
]
|
501 |
+
},
|
502 |
+
{
|
503 |
+
"cell_type": "code",
|
504 |
+
"execution_count": null,
|
505 |
+
"metadata": {},
|
506 |
+
"outputs": [],
|
507 |
+
"source": [
|
508 |
+
"# orig_len = len(df)\n",
|
509 |
+
"\n",
|
510 |
+
"# subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_25th\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_75th\"])]\n",
|
511 |
+
"\n",
|
512 |
+
"# # subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_mean_minus_eps\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_mean_plus_eps\"])]\n",
|
513 |
+
"# # subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_mean_minus_eps\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_mean_plus_eps\"])]\n",
|
514 |
+
"\n",
|
515 |
+
"# print(f\"Dropped {orig_len-len(subdf)} rows, new len {len(subdf)}\")"
|
516 |
+
]
|
517 |
+
},
|
518 |
+
{
|
519 |
+
"cell_type": "code",
|
520 |
+
"execution_count": null,
|
521 |
+
"metadata": {},
|
522 |
+
"outputs": [],
|
523 |
+
"source": [
|
524 |
+
"# subdf.groupby(groupby_fields)['avg_spike_entropy'].describe()\n",
|
525 |
+
"# df.groupby(groupby_fields)['avg_spike_entropy'].describe()"
|
526 |
+
]
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"cell_type": "code",
|
530 |
+
"execution_count": null,
|
531 |
+
"metadata": {},
|
532 |
+
"outputs": [],
|
533 |
+
"source": [
|
534 |
+
"# df = subdf"
|
535 |
+
]
|
536 |
+
},
|
537 |
+
{
|
538 |
+
"attachments": {},
|
539 |
+
"cell_type": "markdown",
|
540 |
+
"metadata": {},
|
541 |
+
"source": [
|
542 |
+
"# Perform the groupby (group rows by their hyperparameter settings)"
|
543 |
+
]
|
544 |
+
},
|
545 |
+
{
|
546 |
+
"cell_type": "code",
|
547 |
+
"execution_count": null,
|
548 |
+
"metadata": {},
|
549 |
+
"outputs": [],
|
550 |
+
"source": [
|
551 |
+
"grouped_df = df.groupby(groupby_fields)"
|
552 |
+
]
|
553 |
+
},
|
554 |
+
{
|
555 |
+
"cell_type": "code",
|
556 |
+
"execution_count": null,
|
557 |
+
"metadata": {},
|
558 |
+
"outputs": [],
|
559 |
+
"source": [
|
560 |
+
"print(f\"Number of rows after filtering: {len(df)}\")\n",
|
561 |
+
"print(f\"Number of groups: {len(grouped_df)}\")"
|
562 |
+
]
|
563 |
+
},
|
564 |
+
{
|
565 |
+
"attachments": {},
|
566 |
+
"cell_type": "markdown",
|
567 |
+
"metadata": {},
|
568 |
+
"source": [
|
569 |
+
"# Loop to compute \"confusion matrix\" (TPR,FPR etc.) at some z scores for tabulation (Table 2 & 8)"
|
570 |
+
]
|
571 |
+
},
|
572 |
+
{
|
573 |
+
"cell_type": "code",
|
574 |
+
"execution_count": null,
|
575 |
+
"metadata": {},
|
576 |
+
"outputs": [],
|
577 |
+
"source": [
|
578 |
+
"import sklearn.metrics as metrics\n",
|
579 |
+
"\n",
|
580 |
+
"def reject_null_hypo(z_score=None,cuttoff=None):\n",
|
581 |
+
" return z_score > cuttoff\n",
|
582 |
+
"\n",
|
583 |
+
"records = []\n",
|
584 |
+
"\n",
|
585 |
+
"for group_params in tqdm(list(grouped_df.groups.keys())):\n",
|
586 |
+
" sub_df = grouped_df.get_group(group_params)\n",
|
587 |
+
" grp_size = len(sub_df)\n",
|
588 |
+
"\n",
|
589 |
+
" # baseline_z_scores = sub_df[\"baseline_z_score\"].values\n",
|
590 |
+
" # w_bl_z_scores = sub_df[\"w_bl_z_score\"].values\n",
|
591 |
+
" # all_scores = np.concatenate([baseline_z_scores,w_bl_z_scores])\n",
|
592 |
+
"\n",
|
593 |
+
" # baseline_labels = np.zeros_like(baseline_z_scores)\n",
|
594 |
+
" # attacked_labels = np.ones_like(w_bl_z_scores)\n",
|
595 |
+
" # all_labels = np.concatenate([baseline_labels,attacked_labels])\n",
|
596 |
+
"\n",
|
597 |
+
" # fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
|
598 |
+
" # roc_auc = metrics.auc(fpr, tpr)\n",
|
599 |
+
" record = {k:v for k,v in zip(groupby_fields,group_params)}\n",
|
600 |
+
"\n",
|
601 |
+
" for thresh in [4.0,5.0]:\n",
|
602 |
+
" \n",
|
603 |
+
" record[\"count\"] = grp_size\n",
|
604 |
+
" record[f\"baseline_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"baseline_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
|
605 |
+
" record[f\"baseline_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"baseline_z_score\"],cuttoff=thresh)).sum() / grp_size\n",
|
606 |
+
" record[f\"no_bl_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
|
607 |
+
" record[f\"no_bl_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
|
608 |
+
" record[f\"w_bl_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
|
609 |
+
" record[f\"w_bl_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
|
610 |
+
"\n",
|
611 |
+
" if \"w_bl_attacked_z_score\" in sub_df.columns:\n",
|
612 |
+
" record[f\"w_bl_attacked_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
|
613 |
+
" record[f\"w_bl_attacked_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
|
614 |
+
"\n",
|
615 |
+
" records.append(record)\n",
|
616 |
+
"\n",
|
617 |
+
" # # df[f\"baseline_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"baseline_z_score\"].values,cuttoff=thresh)\n",
|
618 |
+
" # # df[f\"baseline_tn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"baseline_z_score\"],cuttoff=thresh)\n",
|
619 |
+
" # # df[f\"no_bl_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n",
|
620 |
+
" # # df[f\"no_bl_tn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n",
|
621 |
+
" # # df[f\"w_bl_tp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n",
|
622 |
+
" # # df[f\"w_bl_fn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n",
|
623 |
+
"\n",
|
624 |
+
"\n",
|
625 |
+
"roc_df = pd.DataFrame.from_records(records)\n"
|
626 |
+
]
|
627 |
+
},
|
628 |
+
{
|
629 |
+
"cell_type": "code",
|
630 |
+
"execution_count": null,
|
631 |
+
"metadata": {},
|
632 |
+
"outputs": [],
|
633 |
+
"source": [
|
634 |
+
"# thresh = 6.0\n",
|
635 |
+
"# thresh = 5.0\n",
|
636 |
+
"std_threshes = [4.0, 5.0] #, 6.0]\n",
|
637 |
+
"# std_threshes = [4.0]\n",
|
638 |
+
"\n",
|
639 |
+
"# roc_df[\"params\"] = roc_df.index.to_list()\n",
|
640 |
+
"\n",
|
641 |
+
"# columns = [\"num_beams\", \"delta\", \"gamma\", \"count\"]\n",
|
642 |
+
"# columns = [\"delta\", \"gamma\", \"count\"]\n",
|
643 |
+
"columns = [\"use_sampling\",\"delta\", \"gamma\", \"count\"]\n",
|
644 |
+
"# columns = [\"use_sampling\", \"replace_ratio\", \"count\"]\n",
|
645 |
+
"\n",
|
646 |
+
"for thresh in std_threshes:\n",
|
647 |
+
" # columns += [f\"baseline_fpr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\"]\n",
|
648 |
+
" # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"no_bl_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fn_at_{thresh}\"]\n",
|
649 |
+
"\n",
|
650 |
+
"\n",
|
651 |
+
" # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
|
652 |
+
" \n",
|
653 |
+
" if f\"w_bl_attacked_fnr_at_{thresh}\" in roc_df.columns:\n",
|
654 |
+
" columns += [f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
|
655 |
+
" columns += [f\"w_bl_attacked_tpr_at_{thresh}\",f\"w_bl_attacked_fnr_at_{thresh}\"] # if attack\n",
|
656 |
+
" else:\n",
|
657 |
+
" columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
|
658 |
+
"\n",
|
659 |
+
"# filter ot not\n",
|
660 |
+
"sub_df = roc_df[(roc_df[\"use_sampling\"] == True) & ((roc_df[\"delta\"] == 1.0) | (roc_df[\"delta\"] == 2.0) | (roc_df[\"delta\"] == 5.0)) & ((roc_df[\"gamma\"] == 0.25) |(roc_df[\"gamma\"] == 0.5) )]\n",
|
661 |
+
"# sub_df = roc_df[(roc_df[\"use_sampling\"] == False) & ((roc_df[\"delta\"] == 1.0) | (roc_df[\"delta\"] == 2.0) | (roc_df[\"delta\"] == 5.0)) & ((roc_df[\"gamma\"] == 0.25) |(roc_df[\"gamma\"] == 0.5) ) & (roc_df[\"num_beams\"] == 8)]\n",
|
662 |
+
"# sub_df = roc_df[(roc_df[\"replace_ratio\"] == 0.1) | (roc_df[\"replace_ratio\"] == 0.3) | (roc_df[\"replace_ratio\"] == 0.5) | (roc_df[\"replace_ratio\"] == 0.7)]\n",
|
663 |
+
"# sub_df = roc_df[(roc_df[\"num_beams\"] == 8)]\n",
|
664 |
+
"# sub_df = roc_df\n",
|
665 |
+
"\n",
|
666 |
+
"# sub_df.sort_values(\"delta\")[columns]\n",
|
667 |
+
"# sub_df.sort_values(\"num_beams\")[columns]\n",
|
668 |
+
"sub_df.sort_values(by=[\"delta\",\"gamma\"],ascending=[True, False])[columns]"
|
669 |
+
]
|
670 |
+
},
|
671 |
+
{
|
672 |
+
"attachments": {},
|
673 |
+
"cell_type": "markdown",
|
674 |
+
"metadata": {},
|
675 |
+
"source": [
|
676 |
+
"#### write tables to latex"
|
677 |
+
]
|
678 |
+
},
|
679 |
+
{
|
680 |
+
"cell_type": "code",
|
681 |
+
"execution_count": null,
|
682 |
+
"metadata": {},
|
683 |
+
"outputs": [],
|
684 |
+
"source": [
|
685 |
+
"# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"gamma\").round(3).to_latex(index=False))\n",
|
686 |
+
"# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"delta\").round(3).to_latex(index=False))\n",
|
687 |
+
"# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"num_beams\").round(3).to_latex(index=False))\n",
|
688 |
+
"\n",
|
689 |
+
"# print(sub_df.sort_values(by=[\"delta\",\"gamma\"],ascending=[True, False])[columns].round(3).to_latex(index=False))\n",
|
690 |
+
"# print(sub_df.sort_values(\"num_beams\")[columns].round(3).to_latex(index=False))"
|
691 |
+
]
|
692 |
+
},
|
693 |
+
{
|
694 |
+
"attachments": {},
|
695 |
+
"cell_type": "markdown",
|
696 |
+
"metadata": {},
|
697 |
+
"source": [
|
698 |
+
"# ROC: No Attack (figure 4)"
|
699 |
+
]
|
700 |
+
},
|
701 |
+
{
|
702 |
+
"cell_type": "code",
|
703 |
+
"execution_count": null,
|
704 |
+
"metadata": {},
|
705 |
+
"outputs": [],
|
706 |
+
"source": [
|
707 |
+
"plt.clf()\n",
|
708 |
+
"plt.figure(constrained_layout=True)\n",
|
709 |
+
"plt.figure(figsize=(5, 4))\n",
|
710 |
+
"\n",
|
711 |
+
"import sklearn.metrics as metrics\n",
|
712 |
+
"\n",
|
713 |
+
"zoom = False\n",
|
714 |
+
"# zoom = True\n",
|
715 |
+
"\n",
|
716 |
+
"beam_search = None\n",
|
717 |
+
"# beam_search = 1\n",
|
718 |
+
"# beam_search = 4\n",
|
719 |
+
"# beam_search = 8\n",
|
720 |
+
"\n",
|
721 |
+
"deltas = [1.0,2.0,5.0,10.0]\n",
|
722 |
+
"# gammas = [0.25, 0.5]\n",
|
723 |
+
"gammas = [0.25]\n",
|
724 |
+
"# gammas = [0.5]\n",
|
725 |
+
"\n",
|
726 |
+
"# deltas = [1.0,2.0,5.0,10.0]\n",
|
727 |
+
"# gammas = [0.1,0.5]\n",
|
728 |
+
"\n",
|
729 |
+
"groups = []\n",
|
730 |
+
"names = []\n",
|
731 |
+
"for d in deltas:\n",
|
732 |
+
" for g in gammas:\n",
|
733 |
+
" if beam_search:\n",
|
734 |
+
" groups.append((False, beam_search, d, g))\n",
|
735 |
+
" else:\n",
|
736 |
+
" groups.append((True, 1, d, g))\n",
|
737 |
+
" names.append(f\"$\\delta:{d},\\gamma:{g}$\")\n",
|
738 |
+
"groups=groups[::-1]\n",
|
739 |
+
"names=names[::-1]\n",
|
740 |
+
"\n",
|
741 |
+
"# Make colormap\n",
|
742 |
+
"import matplotlib.pyplot as plt\n",
|
743 |
+
"viridis = plt.colormaps['viridis'].resampled(len(groups)+1) \n",
|
744 |
+
"cmap = viridis.colors[:len(groups)][::-1]\n",
|
745 |
+
"\n",
|
746 |
+
"# plot different parameter levels\n",
|
747 |
+
"for i,(group,name) in enumerate(zip(groups,names)):\n",
|
748 |
+
"\n",
|
749 |
+
" baseline_z_scores = grouped_df.get_group(group)[\"baseline_z_score\"].values\n",
|
750 |
+
" w_bl_z_scores = grouped_df.get_group(group)[\"w_bl_z_score\"].values\n",
|
751 |
+
" all_scores = np.concatenate([baseline_z_scores,w_bl_z_scores])\n",
|
752 |
+
"\n",
|
753 |
+
" baseline_labels = np.zeros_like(baseline_z_scores)\n",
|
754 |
+
" attacked_labels = np.ones_like(w_bl_z_scores)\n",
|
755 |
+
" all_labels = np.concatenate([baseline_labels,attacked_labels])\n",
|
756 |
+
"\n",
|
757 |
+
" fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
|
758 |
+
" roc_auc = metrics.auc(fpr, tpr)\n",
|
759 |
+
"\n",
|
760 |
+
" plt.plot(fpr, tpr, color=cmap[i], label = f'{name}, AUC:%0.3f, PPL:{round(grouped_df[\"w_bl_ppl\"].describe().loc[group][\"mean\"],1)}' % roc_auc, linewidth=3)\n",
|
761 |
+
"\n",
|
762 |
+
"if \"w_bl_attacked_ppl\" in df.columns:\n",
|
763 |
+
" pass\n",
|
764 |
+
"else:\n",
|
765 |
+
" # # vanilla ppl value\n",
|
766 |
+
" plt.scatter([-1],[-1],label=f' $\\delta=0$, PPL: {round(grouped_df[\"no_bl_ppl\"].describe().loc[groups,\"mean\"].mean(),1)}', color=\"white\")\n",
|
767 |
+
"\n",
|
768 |
+
"if zoom:\n",
|
769 |
+
" if not \"w_bl_attacked_ppl\" in df.columns:\n",
|
770 |
+
" plt.legend(loc = 'lower right', fontsize = 12)\n",
|
771 |
+
" plt.xscale(\"log\")\n",
|
772 |
+
" # plt.yscale(\"log\")\n",
|
773 |
+
" plt.xlim([0, 1])\n",
|
774 |
+
" plt.ylim([0.5, 1])\n",
|
775 |
+
" plot_name = (\"roc_auc_zoom\" if not beam_search else f\"roc_auc_zoom_greedy_beams_{beam_search}\")\n",
|
776 |
+
"\n",
|
777 |
+
"else:\n",
|
778 |
+
" if \"w_bl_attacked_ppl\" in df.columns:\n",
|
779 |
+
" plt.legend(loc = 'lower right', fontsize = 12)\n",
|
780 |
+
" plt.plot([0, 1], [0, 1],'r--')\n",
|
781 |
+
" plt.xlim([0, 1])\n",
|
782 |
+
" plt.ylim([0, 1])\n",
|
783 |
+
" plot_name = (\"roc_auc\" if not beam_search else f\"roc_auc_greedy_beams_{beam_search}\")\n",
|
784 |
+
"\n",
|
785 |
+
"plt.ylabel('True Positive Rate', fontsize = 12)\n",
|
786 |
+
"plt.xlabel('False Positive Rate', fontsize = 12)\n",
|
787 |
+
"\n",
|
788 |
+
"print(plot_name)\n",
|
789 |
+
"\n",
|
790 |
+
"# fname = f\"figs/{plot_name}.pdf\"\n",
|
791 |
+
"# plt.savefig(fname, format=\"pdf\")\n",
|
792 |
+
"\n",
|
793 |
+
"plt.show()"
|
794 |
+
]
|
795 |
+
},
|
796 |
+
{
|
797 |
+
"attachments": {},
|
798 |
+
"cell_type": "markdown",
|
799 |
+
"metadata": {},
|
800 |
+
"source": [
|
801 |
+
"\n",
|
802 |
+
"# ROC: Attack (figure 6)"
|
803 |
+
]
|
804 |
+
},
|
805 |
+
{
|
806 |
+
"cell_type": "code",
|
807 |
+
"execution_count": null,
|
808 |
+
"metadata": {},
|
809 |
+
"outputs": [],
|
810 |
+
"source": [
|
811 |
+
"import sklearn.metrics as metrics\n",
|
812 |
+
"\n",
|
813 |
+
"plt.clf()\n",
|
814 |
+
"plt.figure(constrained_layout=True)\n",
|
815 |
+
"plt.figure(figsize=(5, 4))\n",
|
816 |
+
"\n",
|
817 |
+
"# attack_budgets = [0.1,0.2,0.3,0.4,0.5,0.6,0.7]\n",
|
818 |
+
"attack_budgets = [0.1,0.3,0.5,0.7]\n",
|
819 |
+
"groups = [(True, 1, 0.5, 2.0, budget) for budget in attack_budgets]\n",
|
820 |
+
"beams = False\n",
|
821 |
+
"# groups = [(False, 8, 0.5, 2.0, budget) for budget in attack_budgets]\n",
|
822 |
+
"# beams = True\n",
|
823 |
+
"\n",
|
824 |
+
"names = [f\"$\\epsilon={eps}$\" for eps in attack_budgets]\n",
|
825 |
+
"\n",
|
826 |
+
"# Make colormap\n",
|
827 |
+
"import matplotlib.pyplot as plt\n",
|
828 |
+
"viridis = plt.colormaps['viridis'].resampled(len(groups)+1+1) # attack\n",
|
829 |
+
"cmap = viridis.colors[:len(groups)+1][::-1]\n",
|
830 |
+
"\n",
|
831 |
+
"# plot original\n",
|
832 |
+
"group = groups[0] # any will do\n",
|
833 |
+
"baseline_z_scores = grouped_df.get_group(group)[\"baseline_z_score\"].values\n",
|
834 |
+
"baseline_labels = np.zeros_like(baseline_z_scores)\n",
|
835 |
+
"\n",
|
836 |
+
"orig_watermark_z_scores = grouped_df.get_group(group)[\"w_bl_z_score\"].values\n",
|
837 |
+
"watermark_labels = np.ones_like(orig_watermark_z_scores)\n",
|
838 |
+
"\n",
|
839 |
+
"all_scores = np.concatenate([baseline_z_scores,orig_watermark_z_scores])\n",
|
840 |
+
"all_labels = np.concatenate([baseline_labels,watermark_labels])\n",
|
841 |
+
"\n",
|
842 |
+
"fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
|
843 |
+
"roc_auc = metrics.auc(fpr, tpr)\n",
|
844 |
+
"\n",
|
845 |
+
"plt.plot(fpr, tpr, color=cmap[0], label = f'unattacked, AUC:%0.3f, PPL:{round(grouped_df[\"w_bl_ppl\"].describe().loc[group][\"mean\"],1)}' % roc_auc, linewidth=3)\n",
|
846 |
+
"\n",
|
847 |
+
"# plot different attack levels\n",
|
848 |
+
"for i,(group,name) in enumerate(zip(groups,names)):\n",
|
849 |
+
"\n",
|
850 |
+
" baseline_z_scores = grouped_df.get_group(group)[\"baseline_z_score\"].values\n",
|
851 |
+
" attacked_z_scores = grouped_df.get_group(group)[\"w_bl_attacked_z_score\"].values\n",
|
852 |
+
" all_scores = np.concatenate([baseline_z_scores,attacked_z_scores])\n",
|
853 |
+
"\n",
|
854 |
+
" baseline_labels = np.zeros_like(baseline_z_scores)\n",
|
855 |
+
" attacked_labels = np.ones_like(attacked_z_scores)\n",
|
856 |
+
" all_labels = np.concatenate([baseline_labels,attacked_labels])\n",
|
857 |
+
"\n",
|
858 |
+
" fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
|
859 |
+
" roc_auc = metrics.auc(fpr, tpr)\n",
|
860 |
+
"\n",
|
861 |
+
" plt.plot(fpr, tpr, color=cmap[i+1], label = f'{name}, AUC:%0.3f, PPL:{round(grouped_df[\"w_bl_attacked_ppl\"].describe().loc[group][\"mean\"],1)}' % roc_auc, linewidth=3)\n",
|
862 |
+
"\n",
|
863 |
+
"if \"w_bl_attacked_ppl\" in df.columns:\n",
|
864 |
+
" pass\n",
|
865 |
+
"else:\n",
|
866 |
+
" # # vanilla ppl value\n",
|
867 |
+
" plt.scatter([-1],[-1],label=f' $\\delta=0$, PPL: {round(grouped_df[\"no_bl_ppl\"].describe().loc[groups,\"mean\"].mean(),1)}', color=\"white\")\n",
|
868 |
+
"\n",
|
869 |
+
"zoom = False\n",
|
870 |
+
"# zoom = True\n",
|
871 |
+
"if zoom:\n",
|
872 |
+
" if not \"w_bl_attacked_ppl\" in df.columns:\n",
|
873 |
+
" plt.legend(loc = 'lower right')\n",
|
874 |
+
" plt.xscale(\"log\")\n",
|
875 |
+
" # plt.yscale(\"log\")\n",
|
876 |
+
" plt.xlim([0, 1])\n",
|
877 |
+
" plt.ylim([0.5, 1])\n",
|
878 |
+
" if \"w_bl_attacked_ppl\" in df.columns:\n",
|
879 |
+
" plot_name = \"roc_auc_untargeted_attack_no_beams_zoom\"\n",
|
880 |
+
" # plot_name = \"roc_auc_untargeted_attack_with_beams_zoom\"\n",
|
881 |
+
" else:\n",
|
882 |
+
" plot_name = \"roc_auc_zoom\"\n",
|
883 |
+
"else:\n",
|
884 |
+
" if \"w_bl_attacked_ppl\" in df.columns:\n",
|
885 |
+
" plt.legend(loc = 'lower right',fontsize = 9)\n",
|
886 |
+
" plt.plot([0, 1], [0, 1],'r--')\n",
|
887 |
+
" plt.xlim([0, 1])\n",
|
888 |
+
" plt.ylim([0, 1])\n",
|
889 |
+
" if \"w_bl_attacked_ppl\" in df.columns:\n",
|
890 |
+
" if beams: plot_name = \"roc_auc_untargeted_attack_w_beams\"\n",
|
891 |
+
" if not beams: plot_name = \"roc_auc_untargeted_attack_no_beams\"\n",
|
892 |
+
" else:\n",
|
893 |
+
" plot_name = \"roc_auc\"\n",
|
894 |
+
"\n",
|
895 |
+
"plt.ylabel('True Positive Rate')\n",
|
896 |
+
"plt.xlabel('False Positive Rate')\n",
|
897 |
+
"\n",
|
898 |
+
"print(plot_name)\n",
|
899 |
+
"\n",
|
900 |
+
"# fname = f\"figs/{plot_name}.pdf\"\n",
|
901 |
+
"# plt.savefig(fname, format=\"pdf\")\n",
|
902 |
+
"\n",
|
903 |
+
"plt.show()"
|
904 |
+
]
|
905 |
+
},
|
906 |
+
{
|
907 |
+
"cell_type": "code",
|
908 |
+
"execution_count": null,
|
909 |
+
"metadata": {},
|
910 |
+
"outputs": [],
|
911 |
+
"source": []
|
912 |
+
},
|
913 |
+
{
|
914 |
+
"attachments": {},
|
915 |
+
"cell_type": "markdown",
|
916 |
+
"metadata": {},
|
917 |
+
"source": [
|
918 |
+
"# Z vs T (figure 3)"
|
919 |
+
]
|
920 |
+
},
|
921 |
+
{
|
922 |
+
"cell_type": "code",
|
923 |
+
"execution_count": null,
|
924 |
+
"metadata": {},
|
925 |
+
"outputs": [],
|
926 |
+
"source": [
|
927 |
+
"plt.clf()\n",
|
928 |
+
"plt.figure(constrained_layout=True)\n",
|
929 |
+
"plt.figure(figsize=(5, 4))\n",
|
930 |
+
"\n",
|
931 |
+
"# save_fig = True\n",
|
932 |
+
"save_fig = False\n",
|
933 |
+
"\n",
|
934 |
+
"z_scores = True\n",
|
935 |
+
"# z_scores = False\n",
|
936 |
+
"\n",
|
937 |
+
"beam_search = None\n",
|
938 |
+
"# beam_search = 1\n",
|
939 |
+
"# beam_search = 4\n",
|
940 |
+
"# beam_search = 8\n",
|
941 |
+
"\n",
|
942 |
+
"ablate = \"delta\"\n",
|
943 |
+
"delta_gammas = [\n",
|
944 |
+
" # (0.5,0.25),\n",
|
945 |
+
" # (1.0,0.25),\n",
|
946 |
+
" # (2.0,0.25),\n",
|
947 |
+
" # (5.0,0.25),\n",
|
948 |
+
" # (10.0,0.25),\n",
|
949 |
+
" (0.5,0.5),\n",
|
950 |
+
" (1.0,0.5),\n",
|
951 |
+
" (2.0,0.5),\n",
|
952 |
+
" (5.0,0.5),\n",
|
953 |
+
" (10.0,0.5),\n",
|
954 |
+
"]\n",
|
955 |
+
"# ablate = \"gamma\"\n",
|
956 |
+
"# delta_gammas = [\n",
|
957 |
+
"# # (5.0,0.9),\n",
|
958 |
+
"# # (5.0,0.75),\n",
|
959 |
+
"# # (5.0,0.5),\n",
|
960 |
+
"# # (5.0,0.25),\n",
|
961 |
+
"# # (5.0,0.1),\n",
|
962 |
+
"# (2.0,0.9),\n",
|
963 |
+
"# (2.0,0.75),\n",
|
964 |
+
"# (2.0,0.5),\n",
|
965 |
+
"# (2.0,0.25),\n",
|
966 |
+
"# (2.0,0.1),\n",
|
967 |
+
"# ]\n",
|
968 |
+
"# if not z_scores: delta_gammas = delta_gammas[::-1]\n",
|
969 |
+
"\n",
|
970 |
+
"groups = []\n",
|
971 |
+
"names = []\n",
|
972 |
+
"\n",
|
973 |
+
"for d,g in delta_gammas:\n",
|
974 |
+
" if beam_search:\n",
|
975 |
+
" groups.append((False, beam_search, d, g))\n",
|
976 |
+
" else:\n",
|
977 |
+
" groups.append((True, 1, d, g))\n",
|
978 |
+
" names.append(f\"$\\delta:{d},\\gamma:{g}$\")\n",
|
979 |
+
"\n",
|
980 |
+
"groups=groups[::-1]\n",
|
981 |
+
"names=names[::-1]\n",
|
982 |
+
"\n",
|
983 |
+
"\n",
|
984 |
+
"axis_max_t = 200\n",
|
985 |
+
"\n",
|
986 |
+
"max_t = None\n",
|
987 |
+
"# max_t = 200\n",
|
988 |
+
"# max_t = 100\n",
|
989 |
+
"# max_t = 50\n",
|
990 |
+
"\n",
|
991 |
+
"# Make colormap\n",
|
992 |
+
"import matplotlib.pyplot as plt\n",
|
993 |
+
"viridis = plt.colormaps['viridis'].resampled(len(groups)+1) \n",
|
994 |
+
"cmap = viridis.colors[:len(groups)][::-1]\n",
|
995 |
+
"\n",
|
996 |
+
"for grp_idx,(group, name) in enumerate(zip(groups, names)):\n",
|
997 |
+
"\n",
|
998 |
+
" delta, gamma = group[-2],group[-1]\n",
|
999 |
+
"\n",
|
1000 |
+
" # this is the series of bools corresponding to token at T being in whitelist\n",
|
1001 |
+
" w_bl_hit_list = grouped_df.get_group(group)[\"w_bl_hit_list\"].to_list()\n",
|
1002 |
+
"\n",
|
1003 |
+
" lengths = [len(l) for l in w_bl_hit_list]\n",
|
1004 |
+
" diff_lengths = set(lengths) \n",
|
1005 |
+
" counter = {}\n",
|
1006 |
+
" for l in lengths:\n",
|
1007 |
+
" if counter.get(l):\n",
|
1008 |
+
" counter[l] += 1\n",
|
1009 |
+
" else:\n",
|
1010 |
+
" counter[l] = 1\n",
|
1011 |
+
" if max_t:\n",
|
1012 |
+
" min_length = min(min(diff_lengths),max_t)\n",
|
1013 |
+
" max_t = min_length\n",
|
1014 |
+
" else:\n",
|
1015 |
+
" min_length = min(diff_lengths)\n",
|
1016 |
+
" w_bl_hit_list = [l[:min_length] for l in w_bl_hit_list]\n",
|
1017 |
+
"\n",
|
1018 |
+
" # wl_hit_matrix = ~np.matrix(w_bl_hit_list)\n",
|
1019 |
+
" wl_hit_matrix = (~torch.tensor(w_bl_hit_list, dtype=bool)).to(torch.float)\n",
|
1020 |
+
" # wl_hit_matrix\n",
|
1021 |
+
"\n",
|
1022 |
+
" n = wl_hit_matrix.shape[0]\n",
|
1023 |
+
"\n",
|
1024 |
+
" if max_t:\n",
|
1025 |
+
" t_values = torch.arange(0,max_t)\n",
|
1026 |
+
" indices = torch.arange(0,max_t)\n",
|
1027 |
+
" else:\n",
|
1028 |
+
" t_values = torch.arange(0,wl_hit_matrix.shape[1])\n",
|
1029 |
+
" indices = torch.arange(0,wl_hit_matrix.shape[1])\n",
|
1030 |
+
" # print(t_values[:10])\n",
|
1031 |
+
"\n",
|
1032 |
+
" avg_cumulative = list()\n",
|
1033 |
+
" std_cumulative = list()\n",
|
1034 |
+
" prc_25_cumulative = list()\n",
|
1035 |
+
" prc_50_cumulative = list()\n",
|
1036 |
+
" prc_75_cumulative = list()\n",
|
1037 |
+
"\n",
|
1038 |
+
" prc_25_seq_indices = list()\n",
|
1039 |
+
"\n",
|
1040 |
+
" for idx in indices:\n",
|
1041 |
+
"\n",
|
1042 |
+
" hits_upto_t = wl_hit_matrix[:,:idx+1]\n",
|
1043 |
+
" cumulative_sum_to_t = hits_upto_t.sum(axis=1)\n",
|
1044 |
+
" wl_frac_at_t = cumulative_sum_to_t/(t_values[idx]+1)\n",
|
1045 |
+
" \n",
|
1046 |
+
" if z_scores:\n",
|
1047 |
+
" wl_z_score_at_t = compute_z_score(wl_frac_at_t, t_values[idx], gamma)\n",
|
1048 |
+
" avg_at_t = torch.mean(wl_z_score_at_t,axis=0)\n",
|
1049 |
+
" std_at_t = torch.std(wl_z_score_at_t,axis=0)\n",
|
1050 |
+
" prc_25_at_t = torch.quantile(wl_z_score_at_t,q=0.25,axis=0)\n",
|
1051 |
+
" prc_50_at_t = torch.quantile(wl_z_score_at_t,q=0.50,axis=0)\n",
|
1052 |
+
" prc_75_at_t = torch.quantile(wl_z_score_at_t,q=0.75,axis=0)\n",
|
1053 |
+
"\n",
|
1054 |
+
" if gamma == 0.9: # and idx > 20 and idx < 90:\n",
|
1055 |
+
" pcen=np.quantile(wl_z_score_at_t,0.75,interpolation='nearest')\n",
|
1056 |
+
" i_near=abs(wl_z_score_at_t-pcen).argmin()\n",
|
1057 |
+
" # prc_25_seq_indices.append((i_near.item(),pcen))\n",
|
1058 |
+
" prc_25_seq_indices.append((i_near.item()))\n",
|
1059 |
+
" else:\n",
|
1060 |
+
" avg_at_t = torch.mean(wl_frac_at_t,axis=0)\n",
|
1061 |
+
" std_at_t = torch.std(wl_frac_at_t,axis=0)\n",
|
1062 |
+
" prc_25_at_t = torch.quantile(wl_frac_at_t,q=0.25,axis=0)\n",
|
1063 |
+
" prc_50_at_t = torch.quantile(wl_frac_at_t,q=0.50,axis=0)\n",
|
1064 |
+
" prc_75_at_t = torch.quantile(wl_frac_at_t,q=0.75,axis=0)\n",
|
1065 |
+
"\n",
|
1066 |
+
" avg_cumulative.append(avg_at_t.item())\n",
|
1067 |
+
" std_cumulative.append(std_at_t.item())\n",
|
1068 |
+
" prc_25_cumulative.append(prc_25_at_t.item())\n",
|
1069 |
+
" prc_50_cumulative.append(prc_50_at_t.item())\n",
|
1070 |
+
" prc_75_cumulative.append(prc_75_at_t.item())\n",
|
1071 |
+
"\n",
|
1072 |
+
"\n",
|
1073 |
+
" print(prc_25_seq_indices)\n",
|
1074 |
+
"\n",
|
1075 |
+
" avg_cumulative = np.array(avg_cumulative)\n",
|
1076 |
+
" std_cumulative = np.array(std_cumulative)\n",
|
1077 |
+
" std_err_cumulative = std_cumulative/np.sqrt(n)\n",
|
1078 |
+
" var_cumulative = std_cumulative**2\n",
|
1079 |
+
" \n",
|
1080 |
+
" plt.plot(t_values, avg_cumulative, color=cmap[grp_idx], label=name)\n",
|
1081 |
+
"\n",
|
1082 |
+
" # bounds stuff\n",
|
1083 |
+
"\n",
|
1084 |
+
" # plt.plot(t_values, prc_25_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
|
1085 |
+
" # # plt.plot(t_values, prc_50_cumulative, color=cmap[grp_idx], linestyle='--', label=name+',50th') \n",
|
1086 |
+
" # plt.plot(t_values, prc_75_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',75th ') \n",
|
1087 |
+
" # #fill between the upper and lower bands\n",
|
1088 |
+
" # plt.fill_between(t_values, prc_25_cumulative, prc_75_cumulative, alpha = .1,color = cmap[grp_idx])\n",
|
1089 |
+
" # or just lower\n",
|
1090 |
+
" # plt.fill_between(t_values, prc_25_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n",
|
1091 |
+
"\n",
|
1092 |
+
" # plt.plot(t_values, avg_cumulative-std_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
|
1093 |
+
" # plt.plot(t_values, avg_cumulative+std_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
|
1094 |
+
" # plt.plot(t_values, avg_cumulative-std_err_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
|
1095 |
+
" # plt.plot(t_values, avg_cumulative+std_err_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
|
1096 |
+
" # plt.plot(t_values, avg_cumulative-var_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
|
1097 |
+
" # plt.plot(t_values, avg_cumulative+var_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
|
1098 |
+
" # fill between the upper and lower bands\n",
|
1099 |
+
" # plt.fill_between(t_values, avg_cumulative-std_cumulative, avg_cumulative+std_cumulative, alpha = .1,color = cmap[grp_idx])\n",
|
1100 |
+
" # plt.fill_between(t_values, avg_cumulative-std_err_cumulative, avg_cumulative+std_err_cumulative, alpha = .1,color = cmap[grp_idx])\n",
|
1101 |
+
" # or just lower\n",
|
1102 |
+
" # plt.fill_between(t_values, avg_cumulative-std_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n",
|
1103 |
+
" # plt.fill_between(t_values, avg_cumulative-std_err_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n",
|
1104 |
+
"\n",
|
1105 |
+
"# plt.plot([0.0],[0.0],label=f'25th Percentile', linestyle=\"dashed\", color=\"gray\")\n",
|
1106 |
+
"\n",
|
1107 |
+
"# if beam_search:\n",
|
1108 |
+
"# plt.title(f\"Greedy, {beam_search}-way BS\")\n",
|
1109 |
+
"\n",
|
1110 |
+
"legend_font = 11\n",
|
1111 |
+
"\n",
|
1112 |
+
"# zoom_midrange = True\n",
|
1113 |
+
"# zoom = True\n",
|
1114 |
+
"\n",
|
1115 |
+
"zoom = False\n",
|
1116 |
+
"\n",
|
1117 |
+
"if zoom:\n",
|
1118 |
+
" if z_scores:\n",
|
1119 |
+
" plt.legend(loc = 'upper left', fontsize=legend_font)\n",
|
1120 |
+
" else:\n",
|
1121 |
+
" plt.legend(loc = 'lower right', fontsize=legend_font)\n",
|
1122 |
+
" if zoom_midrange:\n",
|
1123 |
+
" plt.xlim([(min_length)/4, (3*(max_t if max_t else min_length)/4)+1])\n",
|
1124 |
+
" else:\n",
|
1125 |
+
" plt.xlim([0, ((max_t if max_t else min_length)/4)+1])\n",
|
1126 |
+
" plot_name = f\"z_vs_t_zoom_ablate_{ablate}\" if z_scores else f\"wl_vs_t_zoom_ablate_{ablate}\"\n",
|
1127 |
+
"else:\n",
|
1128 |
+
" if z_scores:\n",
|
1129 |
+
" plt.legend(loc = 'upper left', fontsize=legend_font)\n",
|
1130 |
+
" else:\n",
|
1131 |
+
" plt.legend(loc = 'lower right', fontsize=legend_font)\n",
|
1132 |
+
" \n",
|
1133 |
+
" plt.xlim([0, ((max_t if max_t else min_length))+1])\n",
|
1134 |
+
"\n",
|
1135 |
+
" plot_name = f\"z_vs_t_ablate_{ablate}\" if z_scores else f\"wl_vs_t_ablate_{ablate}\"\n",
|
1136 |
+
"\n",
|
1137 |
+
"axes_label_fonts = 14\n",
|
1138 |
+
"if z_scores:\n",
|
1139 |
+
" plt.ylabel('z-score',fontsize=axes_label_fonts)\n",
|
1140 |
+
"else:\n",
|
1141 |
+
" plt.ylabel('Whitelist Fraction',fontsize=axes_label_fonts)\n",
|
1142 |
+
"plt.xlabel('T',fontsize=axes_label_fonts)\n",
|
1143 |
+
"\n",
|
1144 |
+
"# import matplotlib.ticker as ticker\n",
|
1145 |
+
"# tick_spacing = 5.0\n",
|
1146 |
+
"# plt.gca().yaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
|
1147 |
+
"\n",
|
1148 |
+
"axes_tick_font = 13\n",
|
1149 |
+
"plt.xticks(fontsize=axes_tick_font)\n",
|
1150 |
+
"plt.yticks(fontsize=axes_tick_font)\n",
|
1151 |
+
"\n",
|
1152 |
+
"plt.grid()\n",
|
1153 |
+
"plt.tight_layout()\n",
|
1154 |
+
"\n",
|
1155 |
+
"if beam_search:\n",
|
1156 |
+
" if ablate == \"gamma\":\n",
|
1157 |
+
" plot_name = f\"greedy_{beam_search}_beams_delta_{delta}\" \n",
|
1158 |
+
" if ablate == \"delta\":\n",
|
1159 |
+
" plot_name = f\"greedy_{beam_search}_beams_gamma_{gamma}\" \n",
|
1160 |
+
"\n",
|
1161 |
+
"# plot_name = \"z_vs_t_ablate_gamma_boosted_delta\"\n",
|
1162 |
+
"# plot_name = \"z_vs_t_ablate_delta_boosted_gamma\"\n",
|
1163 |
+
"\n",
|
1164 |
+
"print(plot_name)\n",
|
1165 |
+
"\n",
|
1166 |
+
"\n",
|
1167 |
+
"if save_fig:\n",
|
1168 |
+
" # fname = f\"figs/{plot_name}.pdf\"\n",
|
1169 |
+
" fname = f\"figs_new/{plot_name}.pdf\"\n",
|
1170 |
+
" plt.savefig(fname, format=\"pdf\")\n",
|
1171 |
+
"\n",
|
1172 |
+
"plt.show()"
|
1173 |
+
]
|
1174 |
+
},
|
1175 |
+
{
|
1176 |
+
"cell_type": "code",
|
1177 |
+
"execution_count": null,
|
1178 |
+
"metadata": {},
|
1179 |
+
"outputs": [],
|
1180 |
+
"source": []
|
1181 |
+
},
|
1182 |
+
{
|
1183 |
+
"attachments": {},
|
1184 |
+
"cell_type": "markdown",
|
1185 |
+
"metadata": {},
|
1186 |
+
"source": [
|
1187 |
+
"# Set up data for charts (setup for figures 2&7)"
|
1188 |
+
]
|
1189 |
+
},
|
1190 |
+
{
|
1191 |
+
"cell_type": "code",
|
1192 |
+
"execution_count": null,
|
1193 |
+
"metadata": {},
|
1194 |
+
"outputs": [],
|
1195 |
+
"source": [
|
1196 |
+
"viz_df = pd.DataFrame()\n",
|
1197 |
+
"\n",
|
1198 |
+
"# aggregating\n",
|
1199 |
+
"\n",
|
1200 |
+
"# set the hparam keys, including an indiv column for each you want to ablate on\n",
|
1201 |
+
"viz_df[\"bl_hparams\"] = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe().index.to_list()\n",
|
1202 |
+
"for i,key in enumerate(groupby_fields):\n",
|
1203 |
+
" viz_df[key] = viz_df[\"bl_hparams\"].apply(lambda tup: tup[i])\n",
|
1204 |
+
"\n",
|
1205 |
+
"# viz_df[\"delta\"] = viz_df[\"bl_logit_bias\"].values\n",
|
1206 |
+
"viz_df[\"gamma\"] = viz_df[\"gamma\"].values\n",
|
1207 |
+
"# viz_df[\"gamma\"] = np.ones_like(viz_df[\"bl_proportion\"].values) - viz_df[\"bl_proportion\"].values\n",
|
1208 |
+
"\n",
|
1209 |
+
"# aggregate each field of interest for each hparam setting (group)\n",
|
1210 |
+
"describe_dict = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe()\n",
|
1211 |
+
"viz_df[\"w_bl_exp_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
1212 |
+
"viz_df[\"w_bl_exp_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
|
1213 |
+
"\n",
|
1214 |
+
"describe_dict = grouped_df[\"w_bl_var_whitelist_fraction\"].describe()\n",
|
1215 |
+
"viz_df[\"w_bl_var_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
1216 |
+
"viz_df[\"w_bl_var_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
|
1217 |
+
"\n",
|
1218 |
+
"describe_dict = grouped_df[\"w_bl_whitelist_fraction\"].describe()\n",
|
1219 |
+
"viz_df[\"w_bl_whitelist_fraction_min\"] = describe_dict[\"min\"].to_list()\n",
|
1220 |
+
"viz_df[\"w_bl_whitelist_fraction_25\"] = describe_dict[\"25%\"].to_list()\n",
|
1221 |
+
"viz_df[\"w_bl_whitelist_fraction_50\"] = describe_dict[\"50%\"].to_list()\n",
|
1222 |
+
"viz_df[\"w_bl_whitelist_fraction_75\"] = describe_dict[\"75%\"].to_list()\n",
|
1223 |
+
"viz_df[\"w_bl_whitelist_fraction_max\"] = describe_dict[\"max\"].to_list()\n",
|
1224 |
+
"viz_df[\"w_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
1225 |
+
"viz_df[\"w_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
|
1226 |
+
"\n",
|
1227 |
+
"describe_dict = grouped_df[\"no_bl_whitelist_fraction\"].describe()\n",
|
1228 |
+
"viz_df[\"no_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
1229 |
+
"viz_df[\"no_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
|
1230 |
+
"\n",
|
1231 |
+
"\n",
|
1232 |
+
"describe_dict = grouped_df[\"w_bl_z_score\"].describe()\n",
|
1233 |
+
"viz_df[\"w_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
1234 |
+
"viz_df[\"w_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
|
1235 |
+
"\n",
|
1236 |
+
"describe_dict = grouped_df[\"no_bl_z_score\"].describe()\n",
|
1237 |
+
"viz_df[\"no_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
1238 |
+
"viz_df[\"no_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
|
1239 |
+
"\n",
|
1240 |
+
"describe_dict = grouped_df[\"baseline_z_score\"].describe()\n",
|
1241 |
+
"viz_df[\"baseline_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
1242 |
+
"viz_df[\"baseline_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
|
1243 |
+
"\n",
|
1244 |
+
"\n",
|
1245 |
+
"describe_dict = grouped_df[\"w_bl_ppl\"].describe()\n",
|
1246 |
+
"viz_df[\"w_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
1247 |
+
"viz_df[\"w_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
|
1248 |
+
"\n",
|
1249 |
+
"describe_dict = grouped_df[\"no_bl_ppl\"].describe()\n",
|
1250 |
+
"viz_df[\"no_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
1251 |
+
"viz_df[\"no_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
|
1252 |
+
"\n",
|
1253 |
+
"describe_dict = grouped_df[\"baseline_ppl\"].describe()\n",
|
1254 |
+
"viz_df[\"baseline_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
1255 |
+
"viz_df[\"baseline_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
|
1256 |
+
"\n",
|
1257 |
+
"describe_dict = grouped_df[\"avg_spike_entropy\"].describe()\n",
|
1258 |
+
"viz_df[\"avg_spike_entropy_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
1259 |
+
"viz_df[\"avg_spike_entropy_std\"] = describe_dict[\"std\"].to_list()\n",
|
1260 |
+
"\n",
|
1261 |
+
"print(f\"groupby legend: {groupby_fields}\")\n"
|
1262 |
+
]
|
1263 |
+
},
|
1264 |
+
{
|
1265 |
+
"cell_type": "code",
|
1266 |
+
"execution_count": null,
|
1267 |
+
"metadata": {},
|
1268 |
+
"outputs": [],
|
1269 |
+
"source": [
|
1270 |
+
"# filtering\n",
|
1271 |
+
"\n",
|
1272 |
+
"viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == True))] # sampling\n",
|
1273 |
+
"\n",
|
1274 |
+
"# viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False))] # greedy\n",
|
1275 |
+
"\n",
|
1276 |
+
"\n",
|
1277 |
+
"# fix one of the bl params for analytic chart\n",
|
1278 |
+
"# viz_df = viz_df[(viz_df[\"gamma\"]==0.9) & (viz_df[\"delta\"]<=10.0)]\n",
|
1279 |
+
"# viz_df = viz_df[(viz_df[\"gamma\"]==0.75) & (viz_df[\"delta\"]<=10.0)]\n",
|
1280 |
+
"# viz_df = viz_df[(viz_df[\"gamma\"]==0.5) & (viz_df[\"delta\"]<=10.0)]\n",
|
1281 |
+
"# viz_df = viz_df[(viz_df[\"gamma\"]==0.25) & (viz_df[\"delta\"]<=10.0)]\n",
|
1282 |
+
"# viz_df = viz_df[(viz_df[\"gamma\"]==0.1) & (viz_df[\"delta\"]<=10.0)]\n",
|
1283 |
+
"\n",
|
1284 |
+
"# for the sample pareto chart\n",
|
1285 |
+
"viz_df = viz_df[(viz_df[\"delta\"] > 0.5) & (viz_df[\"delta\"]<=10.0)]\n",
|
1286 |
+
"# viz_df = viz_df[(viz_df[\"delta\"]<=2.0)] # zoom in on lower deltas\n",
|
1287 |
+
"# viz_df = viz_df[(viz_df[\"delta\"] >= 2.0) & (viz_df[\"delta\"]<=10.0)] # mid deltas\n",
|
1288 |
+
"# viz_df = viz_df[(viz_df[\"gamma\"] != 0.25) & (viz_df[\"gamma\"] != 0.75) & (viz_df[\"delta\"]<=2.0)]\n",
|
1289 |
+
"# viz_df = viz_df[(viz_df[\"gamma\"] != 0.1) & (viz_df[\"gamma\"] != 0.9) & (viz_df[\"delta\"]<=2.0)]\n",
|
1290 |
+
"\n",
|
1291 |
+
"# viz_df = viz_df[(viz_df[\"delta\"]==0.5) | (viz_df[\"delta\"]==2.0) | (viz_df[\"delta\"]==10.0)]\n",
|
1292 |
+
"\n",
|
1293 |
+
"# viz_df = viz_df[(viz_df[\"delta\"]!=0.1)&(viz_df[\"delta\"]!=0.5)&(viz_df[\"delta\"]!=50.0)]\n",
|
1294 |
+
"\n",
|
1295 |
+
"# for the beams pareto\n",
|
1296 |
+
"# viz_df = viz_df[(viz_df[\"delta\"]!=50.0)]\n",
|
1297 |
+
"# viz_df = viz_df[(viz_df[\"delta\"]!=50.0) & (viz_df[\"num_beams\"]!=1)]\n",
|
1298 |
+
"\n",
|
1299 |
+
"print(len(viz_df))\n",
|
1300 |
+
"\n",
|
1301 |
+
"viz_df"
|
1302 |
+
]
|
1303 |
+
},
|
1304 |
+
{
|
1305 |
+
"cell_type": "code",
|
1306 |
+
"execution_count": null,
|
1307 |
+
"metadata": {},
|
1308 |
+
"outputs": [],
|
1309 |
+
"source": [
|
1310 |
+
"# grouped_df[\"avg_spike_entropy\"]"
|
1311 |
+
]
|
1312 |
+
},
|
1313 |
+
{
|
1314 |
+
"cell_type": "code",
|
1315 |
+
"execution_count": null,
|
1316 |
+
"metadata": {},
|
1317 |
+
"outputs": [],
|
1318 |
+
"source": [
|
1319 |
+
"# viz_df[[\"gamma\",\"avg_spike_entropy_mean\"]]"
|
1320 |
+
]
|
1321 |
+
},
|
1322 |
+
{
|
1323 |
+
"attachments": {},
|
1324 |
+
"cell_type": "markdown",
|
1325 |
+
"metadata": {},
|
1326 |
+
"source": [
|
1327 |
+
"# Basic Exp vs Empirical WL fraction chart (figure 7)"
|
1328 |
+
]
|
1329 |
+
},
|
1330 |
+
{
|
1331 |
+
"cell_type": "code",
|
1332 |
+
"execution_count": null,
|
1333 |
+
"metadata": {},
|
1334 |
+
"outputs": [],
|
1335 |
+
"source": [
|
1336 |
+
"\n",
|
1337 |
+
"# plt.style.use(\"classic\")\n",
|
1338 |
+
"plt.style.use(\"default\")\n",
|
1339 |
+
"# plt.style.use('ggplot') \n",
|
1340 |
+
"# plt.style.use('seaborn')\n",
|
1341 |
+
"\n",
|
1342 |
+
"rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
|
1343 |
+
"rc('text', usetex=True)\n",
|
1344 |
+
"\n",
|
1345 |
+
"\n",
|
1346 |
+
"plt.clf()\n",
|
1347 |
+
"# plt.figure(figsize=(16, 4))\n",
|
1348 |
+
"# plt.figure(figsize=(8, 4))\n",
|
1349 |
+
"plt.figure(constrained_layout=True)\n",
|
1350 |
+
"plt.figure(figsize=(5, 4))\n",
|
1351 |
+
"\n",
|
1352 |
+
"\n",
|
1353 |
+
"# x_col = 'bl_hparams'\n",
|
1354 |
+
"# a = viz_df[x_col].apply(str)\n",
|
1355 |
+
"\n",
|
1356 |
+
"# x_col = 'bl_logit_bias'\n",
|
1357 |
+
"# x_col = 'bl_proportion'\n",
|
1358 |
+
"x_col = \"delta\"\n",
|
1359 |
+
"# x_col = \"gamma\"\n",
|
1360 |
+
"\n",
|
1361 |
+
"a = viz_df[x_col]\n",
|
1362 |
+
"print(f\"Num configurations: {len(a)}\")\n",
|
1363 |
+
"\n",
|
1364 |
+
"y_col = 'w_bl_whitelist_fraction_mean'\n",
|
1365 |
+
"y_col_err = 'w_bl_whitelist_fraction_std'\n",
|
1366 |
+
"\n",
|
1367 |
+
"viridis = plt.colormaps['viridis'].resampled(4)\n",
|
1368 |
+
"# cmap = viridis.colors[::-1]\n",
|
1369 |
+
"cmap = viridis.colors\n",
|
1370 |
+
"\n",
|
1371 |
+
"plt.plot(a, viz_df[\"w_bl_whitelist_fraction_mean\"].values, color=cmap[1], marker='o', label='Mean') \n",
|
1372 |
+
"plt.plot(a, viz_df[\"w_bl_whitelist_fraction_25\"].values, color=cmap[1], linestyle='-.', label='25th Percentile') \n",
|
1373 |
+
"plt.plot(a, viz_df[\"w_bl_whitelist_fraction_75\"].values, color=cmap[1], linestyle='-.', label='75th Percentile') \n",
|
1374 |
+
"# plt.plot(a, viz_df[\"w_bl_whitelist_fraction_min\"].values, color=cmap[1], linestyle='-.', label='min') \n",
|
1375 |
+
"# plt.plot(a, viz_df[\"w_bl_whitelist_fraction_max\"].values, color=cmap[1], linestyle='-.', label='max') \n",
|
1376 |
+
"\n",
|
1377 |
+
"#fill between the upper and lower bands\n",
|
1378 |
+
"plt.fill_between(a, viz_df[\"w_bl_whitelist_fraction_25\"], viz_df[\"w_bl_whitelist_fraction_75\"], alpha = .1,color = cmap[1])\n",
|
1379 |
+
"# plt.fill_between(a, viz_df[\"w_bl_whitelist_fraction_25\"], viz_df[\"w_bl_whitelist_fraction_75\"], alpha = .1,color = 'darkorchid')\n",
|
1380 |
+
"# plt.fill_between(a, y1_low, y1_high, alpha = .1,color = 'goldenrod')\n",
|
1381 |
+
"\n",
|
1382 |
+
"\n",
|
1383 |
+
"y_col = 'w_bl_exp_whitelist_fraction_mean'\n",
|
1384 |
+
"# y_col_err = 'w_bl_var_whitelist_fraction_mean'\n",
|
1385 |
+
"# d = viz_df[x_col].apply(str)\n",
|
1386 |
+
"\n",
|
1387 |
+
"# sub_df = viz_df[viz_df[\"num_beams\"]==1]\n",
|
1388 |
+
"\n",
|
1389 |
+
"a = viz_df[x_col]\n",
|
1390 |
+
"e = viz_df[y_col].values\n",
|
1391 |
+
"# plt.plot(a, e, label=\"Predicted Lower Bound\", color=cmap[-1])\n",
|
1392 |
+
"plt.plot(a, e, label=\"Analytic Bound\", color=\"r\")\n",
|
1393 |
+
"# f = viz_df[y_col_err].values\n",
|
1394 |
+
"# # f = np.sqrt(viz_df[y_col_err].values)\n",
|
1395 |
+
"# plt.errorbar(d, e, yerr=f, fmt=\"o\")\n",
|
1396 |
+
"\n",
|
1397 |
+
"plt.legend(loc=\"lower right\",frameon=True, facecolor=\"white\")\n",
|
1398 |
+
"\n",
|
1399 |
+
"# for logit bias x axis\n",
|
1400 |
+
"# log_axis = True\n",
|
1401 |
+
"log_axis = False\n",
|
1402 |
+
"if log_axis:\n",
|
1403 |
+
" plt.xscale(\"log\")\n",
|
1404 |
+
"\n",
|
1405 |
+
"ax = plt.gca()\n",
|
1406 |
+
"plt.draw()\n",
|
1407 |
+
"\n",
|
1408 |
+
"\n",
|
1409 |
+
"\n",
|
1410 |
+
"plt.xlabel(f\"Green List Bias, $\\delta$\")\n",
|
1411 |
+
"# plt.xlabel(f\"Whitelist size := $\\gamma$\")\n",
|
1412 |
+
"\n",
|
1413 |
+
"plt.ylabel(\"Fraction in Green List\")\n",
|
1414 |
+
"\n",
|
1415 |
+
"\n",
|
1416 |
+
"plt.grid()\n",
|
1417 |
+
"\n",
|
1418 |
+
"plt.tight_layout()\n",
|
1419 |
+
"\n",
|
1420 |
+
"if log_axis:\n",
|
1421 |
+
" plot_name = \"analytic_w_sampling_log.pdf\"\n",
|
1422 |
+
"else:\n",
|
1423 |
+
" plot_name = \"analytic_w_sampling_linear.pdf\"\n",
|
1424 |
+
" # plot_name = f\"analytic_w_sampling_linear_gamma_{viz_df['gamma'].values[0]}.pdf\"\n",
|
1425 |
+
"\n",
|
1426 |
+
"# plot_name = \"analytic_w_sampling_linear_greenlist.pdf\"\n",
|
1427 |
+
"print(plot_name)\n",
|
1428 |
+
"\n",
|
1429 |
+
"# fname = f\"figs/{plot_name}\"\n",
|
1430 |
+
"# plt.savefig(fname, format=\"pdf\")\n",
|
1431 |
+
"plt.show()\n"
|
1432 |
+
]
|
1433 |
+
},
|
1434 |
+
{
|
1435 |
+
"attachments": {},
|
1436 |
+
"cell_type": "markdown",
|
1437 |
+
"metadata": {},
|
1438 |
+
"source": [
|
1439 |
+
"# delta gamma sampling pareto plot (figure 2 left)"
|
1440 |
+
]
|
1441 |
+
},
|
1442 |
+
{
|
1443 |
+
"cell_type": "code",
|
1444 |
+
"execution_count": null,
|
1445 |
+
"metadata": {},
|
1446 |
+
"outputs": [],
|
1447 |
+
"source": [
|
1448 |
+
"rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
|
1449 |
+
"rc('text', usetex=True)\n",
|
1450 |
+
"\n",
|
1451 |
+
"plt.clf()\n",
|
1452 |
+
"plt.figure(constrained_layout=True)\n",
|
1453 |
+
"plt.figure(figsize=(5, 4))\n",
|
1454 |
+
"\n",
|
1455 |
+
"\n",
|
1456 |
+
"x_col = 'w_bl_ppl_mean'\n",
|
1457 |
+
"y_col = 'w_bl_z_score_mean'\n",
|
1458 |
+
"\n",
|
1459 |
+
"# markers = [\"x\", \"p\", \"*\", \"P\"]\n",
|
1460 |
+
"\n",
|
1461 |
+
"deltas = sorted(np.unique(viz_df[\"delta\"].values))\n",
|
1462 |
+
"gammas = sorted(np.unique(viz_df[\"gamma\"].values), reverse=True)\n",
|
1463 |
+
"print(deltas, gammas)\n",
|
1464 |
+
"gamma_labels = [(g if g > 0.1 else 0.1) for g in gammas]\n",
|
1465 |
+
"\n",
|
1466 |
+
"markers = [\"x\", \"p\", \"*\", \"P\"][:len(deltas)]\n",
|
1467 |
+
"\n",
|
1468 |
+
"num_colors = len(gammas)\n",
|
1469 |
+
"cmap = cmr.get_sub_cmap('viridis', 0.0, 0.66, N=num_colors)\n",
|
1470 |
+
"# cmap = cmr.get_sub_cmap('plasma', 0.0, 0.66, N=num_colors)\n",
|
1471 |
+
"colors = cmap.colors#[::-1]\n",
|
1472 |
+
"\n",
|
1473 |
+
"\n",
|
1474 |
+
"for i,delta in enumerate(deltas):\n",
|
1475 |
+
" for j,gamma in enumerate(gammas):\n",
|
1476 |
+
" sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"gamma\"] == gamma)]\n",
|
1477 |
+
" a = sub_df[x_col].values\n",
|
1478 |
+
" b = sub_df[y_col].values\n",
|
1479 |
+
" # plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n",
|
1480 |
+
" plt.plot(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n",
|
1481 |
+
"\n",
|
1482 |
+
"\n",
|
1483 |
+
"x_col = 'no_bl_ppl_mean'\n",
|
1484 |
+
"y_col = 'no_bl_z_score_mean'\n",
|
1485 |
+
"# x_col = 'baseline_ppl_mean'\n",
|
1486 |
+
"# y_col = 'baseline_z_score_mean'\n",
|
1487 |
+
"\n",
|
1488 |
+
"\n",
|
1489 |
+
"for i,delta in enumerate(deltas):\n",
|
1490 |
+
" for j,gamma in enumerate(gammas):\n",
|
1491 |
+
" sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"gamma\"] == gamma)]\n",
|
1492 |
+
" a = sub_df[x_col].values\n",
|
1493 |
+
" b = sub_df[y_col].values\n",
|
1494 |
+
" plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j])\n",
|
1495 |
+
"\n",
|
1496 |
+
"# # # for manual legend\n",
|
1497 |
+
"plt.scatter([-1],[-1], label=\"Vanilla\", color=\"gray\", marker=\"o\")\n",
|
1498 |
+
"\n",
|
1499 |
+
"ax = plt.gca()\n",
|
1500 |
+
"\n",
|
1501 |
+
"from matplotlib.cm import ScalarMappable\n",
|
1502 |
+
"from matplotlib.colors import Normalize, NoNorm, ListedColormap\n",
|
1503 |
+
"cmap = ListedColormap(colors)\n",
|
1504 |
+
"cmappable = ScalarMappable(norm=NoNorm(),cmap=cmap)\n",
|
1505 |
+
"cbar = plt.colorbar(cmappable,ticks=[i for i in range(len(gammas))],shrink=0.6, pad = 0.03)\n",
|
1506 |
+
"cbar.ax.set_yticklabels(gamma_labels) \n",
|
1507 |
+
"cbar.set_label('$\\gamma$', rotation=0)\n",
|
1508 |
+
"\n",
|
1509 |
+
"\n",
|
1510 |
+
"all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['no_bl_ppl_mean'].values])\n",
|
1511 |
+
"all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['no_bl_z_score_mean'].values])\n",
|
1512 |
+
"# all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['baseline_ppl_mean'].values])\n",
|
1513 |
+
"# all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['baseline_z_score_mean'].values])\n",
|
1514 |
+
"\n",
|
1515 |
+
"min_x, max_x = np.min(all_x), np.max(all_x)\n",
|
1516 |
+
"min_y, max_y = np.min(all_y), np.max(all_y)\n",
|
1517 |
+
"\n",
|
1518 |
+
"# x_min_tick = 1.0\n",
|
1519 |
+
"x_min_tick = 3.0\n",
|
1520 |
+
"x_max_tick = np.ceil([max_x])[0]+1.0\n",
|
1521 |
+
"y_min_tick = 0.0\n",
|
1522 |
+
"y_max_tick = np.ceil([max_y])[0]+1.0\n",
|
1523 |
+
"\n",
|
1524 |
+
"x_ticks = np.arange(x_min_tick,x_max_tick,1.0)\n",
|
1525 |
+
"y_ticks = np.arange(y_min_tick,y_max_tick,5.0)\n",
|
1526 |
+
"\n",
|
1527 |
+
"\n",
|
1528 |
+
"x_lim_min = 3.0\n",
|
1529 |
+
"x_lim_max = x_max_tick\n",
|
1530 |
+
"y_lim_min = 0.45\n",
|
1531 |
+
"# y_lim_max = 1.09\n",
|
1532 |
+
"y_lim_max = 1.005\n",
|
1533 |
+
"\n",
|
1534 |
+
"\n",
|
1535 |
+
"# plt.xlim((x_min_tick-0.5,x_max_tick))\n",
|
1536 |
+
"plt.xlim((x_lim_min,x_lim_max))\n",
|
1537 |
+
"# plt.xlim((4.0,8.0))\n",
|
1538 |
+
"# plt.ylim((-1.0,20.0))\n",
|
1539 |
+
"# plt.ylim((y_lim_min,y_lim_max))\n",
|
1540 |
+
"\n",
|
1541 |
+
"ax.set_xticks(x_ticks)\n",
|
1542 |
+
"# ax.set_yticks(y_ticks)\n",
|
1543 |
+
"\n",
|
1544 |
+
"ax.invert_xaxis()\n",
|
1545 |
+
"\n",
|
1546 |
+
"# # manual legend for dual parameter visualization\n",
|
1547 |
+
"f = lambda m,c: plt.plot([],[],marker=m, color=c, ls=\"none\")[0]\n",
|
1548 |
+
"handles = [f(markers[::-1][i], \"gray\") for i in range(len(deltas))]\n",
|
1549 |
+
"handles += [f(\"o\", \"gray\")]\n",
|
1550 |
+
"labels = [f\"$\\delta={delta}$\" for delta in deltas[::-1]]+[f\"$\\delta=0.0$\"]\n",
|
1551 |
+
"plt.legend(handles, labels, loc=\"upper right\", framealpha=1)\n",
|
1552 |
+
"\n",
|
1553 |
+
"plt.grid()\n",
|
1554 |
+
"\n",
|
1555 |
+
"plt.xlabel(\"Oracle Model PPL (better →)\")\n",
|
1556 |
+
"plt.ylabel(\"z-score (better →)\")\n",
|
1557 |
+
"\n",
|
1558 |
+
"\n",
|
1559 |
+
"plt.tight_layout()\n",
|
1560 |
+
"\n",
|
1561 |
+
"# plot_name = \"pareto_sampling_no_beams\"\n",
|
1562 |
+
"# fname = f\"figs/{plot_name}.pdf\"\n",
|
1563 |
+
"# plt.savefig(fname, format=\"pdf\")\n",
|
1564 |
+
"plt.show()"
|
1565 |
+
]
|
1566 |
+
},
|
1567 |
+
{
|
1568 |
+
"attachments": {},
|
1569 |
+
"cell_type": "markdown",
|
1570 |
+
"metadata": {},
|
1571 |
+
"source": [
|
1572 |
+
"# beams pareto plot (figure 2 right)"
|
1573 |
+
]
|
1574 |
+
},
|
1575 |
+
{
|
1576 |
+
"cell_type": "code",
|
1577 |
+
"execution_count": null,
|
1578 |
+
"metadata": {},
|
1579 |
+
"outputs": [],
|
1580 |
+
"source": [
|
1581 |
+
"num_colors = 3\n",
|
1582 |
+
"cmap = cmr.get_sub_cmap('viridis', 0.0, 0.66, N=num_colors)\n",
|
1583 |
+
"colors = cmap.colors#[::-1]\n",
|
1584 |
+
"\n",
|
1585 |
+
"# plt.style.use('ggplot')\n",
|
1586 |
+
"# plt.style.use('seaborn')\n",
|
1587 |
+
"\n",
|
1588 |
+
"rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
|
1589 |
+
"rc('text', usetex=True)\n",
|
1590 |
+
"\n",
|
1591 |
+
"plt.clf()\n",
|
1592 |
+
"plt.figure(constrained_layout=True)\n",
|
1593 |
+
"plt.figure(figsize=(5, 4))\n",
|
1594 |
+
"\n",
|
1595 |
+
"\n",
|
1596 |
+
"x_col = 'w_bl_ppl_mean'\n",
|
1597 |
+
"y_col = 'w_bl_z_score_mean'\n",
|
1598 |
+
"\n",
|
1599 |
+
"markers = [\"s\",\"D\", \"x\", \"p\", \"*\", \"P\"] # <--- seems to match other pareto fig ordering\n",
|
1600 |
+
"\n",
|
1601 |
+
"deltas = sorted(np.unique(viz_df[\"delta\"].values))\n",
|
1602 |
+
"num_beams = sorted(np.unique(viz_df[\"num_beams\"].values))\n",
|
1603 |
+
"# gamma_labels = [(g if g > 0.1 else 0.1) for g in np.unique(viz_df[\"gamma\"].values)]\n",
|
1604 |
+
"\n",
|
1605 |
+
"for i,n_beams in enumerate(num_beams):\n",
|
1606 |
+
" for j,delta in enumerate(deltas):\n",
|
1607 |
+
" sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"num_beams\"] == n_beams)]\n",
|
1608 |
+
" a = sub_df[x_col].values\n",
|
1609 |
+
" b = sub_df[y_col].values\n",
|
1610 |
+
" # plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n",
|
1611 |
+
" plt.plot(a, b, label=f\"$\\delta={delta}$\", color=colors[i], marker=markers[j])\n",
|
1612 |
+
"\n",
|
1613 |
+
"\n",
|
1614 |
+
"x_col = 'no_bl_ppl_mean'\n",
|
1615 |
+
"y_col = 'no_bl_z_score_mean'\n",
|
1616 |
+
"\n",
|
1617 |
+
"\n",
|
1618 |
+
"\n",
|
1619 |
+
"for i,n_beams in enumerate(num_beams):\n",
|
1620 |
+
" for j,delta in enumerate(deltas):\n",
|
1621 |
+
" sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"num_beams\"] == n_beams)]\n",
|
1622 |
+
" a = sub_df[x_col].values\n",
|
1623 |
+
" b = sub_df[y_col].values\n",
|
1624 |
+
" plt.scatter(a, b, label=f\"$\\delta={delta}$\", color=colors[i])\n",
|
1625 |
+
"\n",
|
1626 |
+
"# # # for manual legend\n",
|
1627 |
+
"plt.scatter([-10],[-10], label=\"$\\delta=0$\", color=\"gray\", marker=\"o\")\n",
|
1628 |
+
"\n",
|
1629 |
+
"ax = plt.gca()\n",
|
1630 |
+
"\n",
|
1631 |
+
"from matplotlib.cm import ScalarMappable\n",
|
1632 |
+
"from matplotlib.colors import Normalize, NoNorm, ListedColormap\n",
|
1633 |
+
"cmap = ListedColormap(colors)\n",
|
1634 |
+
"cmappable = ScalarMappable(norm=NoNorm(),cmap=cmap)\n",
|
1635 |
+
"cbar = plt.colorbar(cmappable,ticks=[i for i in range(len(num_beams))],shrink=0.6, pad = 0.04)\n",
|
1636 |
+
"# cbar.set_ticks(num_beams)\n",
|
1637 |
+
"cbar.set_ticklabels(num_beams)\n",
|
1638 |
+
"# cbar.ax.set_yticklabels(num_beams) \n",
|
1639 |
+
"cbar.set_label('Num Beams', rotation=90)\n",
|
1640 |
+
"\n",
|
1641 |
+
"\n",
|
1642 |
+
"all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['no_bl_ppl_mean'].values])\n",
|
1643 |
+
"all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['no_bl_z_score_mean'].values])\n",
|
1644 |
+
"\n",
|
1645 |
+
"min_x, max_x = np.min(all_x), np.max(all_x)\n",
|
1646 |
+
"min_y, max_y = np.min(all_y), np.max(all_y)\n",
|
1647 |
+
"\n",
|
1648 |
+
"# x_max_tick = np.ceil([max_x])[0]+1.0\n",
|
1649 |
+
"x_max_tick = np.ceil([max_x])[0]\n",
|
1650 |
+
"y_max_tick = np.ceil([max_y])[0]+1.0\n",
|
1651 |
+
"\n",
|
1652 |
+
"\n",
|
1653 |
+
"plt.xlim((1.0,x_max_tick))\n",
|
1654 |
+
"plt.ylim((-1.0,y_max_tick))\n",
|
1655 |
+
"\n",
|
1656 |
+
"# x_ticks = np.arange(x_min_tick,x_max_tick,1.0)\n",
|
1657 |
+
"# y_ticks = np.arange(y_min_tick,y_max_tick,5.0)\n",
|
1658 |
+
"\n",
|
1659 |
+
"# ax.set_xticks(x_ticks)\n",
|
1660 |
+
"# ax.set_yticks(y_ticks)\n",
|
1661 |
+
"\n",
|
1662 |
+
"ax.invert_xaxis()\n",
|
1663 |
+
"\n",
|
1664 |
+
"# # manual legend for dual parameter visualization\n",
|
1665 |
+
"f = lambda m,c: plt.plot([],[],marker=m, color=c, ls=\"none\")[0]\n",
|
1666 |
+
"handles = [f(markers[::-1][i], \"gray\") for i in range(len(deltas))]\n",
|
1667 |
+
"handles += [f(\"o\", \"gray\")]\n",
|
1668 |
+
"labels = [f\"$\\delta={delta}$\" for delta in deltas[::-1]]+[f\"$\\delta=0.0$\"]\n",
|
1669 |
+
"plt.legend(handles, labels, loc=\"lower left\", framealpha=1)\n",
|
1670 |
+
"\n",
|
1671 |
+
"plt.grid()\n",
|
1672 |
+
"\n",
|
1673 |
+
"plt.xlabel(\"Oracle Model PPL (better →)\")\n",
|
1674 |
+
"plt.ylabel(\"z-score (better →)\")\n",
|
1675 |
+
"\n",
|
1676 |
+
"\n",
|
1677 |
+
"plt.tight_layout()\n",
|
1678 |
+
"\n",
|
1679 |
+
"\n",
|
1680 |
+
"plot_name = \"pareto_greedy_w_beams\"\n",
|
1681 |
+
"print(plot_name)\n",
|
1682 |
+
"\n",
|
1683 |
+
"# fname = f\"figs/{plot_name}.pdf\"\n",
|
1684 |
+
"# plt.savefig(fname, format=\"pdf\")\n",
|
1685 |
+
"plt.show()"
|
1686 |
+
]
|
1687 |
+
},
|
1688 |
+
{
|
1689 |
+
"cell_type": "code",
|
1690 |
+
"execution_count": null,
|
1691 |
+
"metadata": {},
|
1692 |
+
"outputs": [],
|
1693 |
+
"source": []
|
1694 |
+
},
|
1695 |
+
{
|
1696 |
+
"attachments": {},
|
1697 |
+
"cell_type": "markdown",
|
1698 |
+
"metadata": {},
|
1699 |
+
"source": [
|
1700 |
+
"## z vs entropy (not in paper)"
|
1701 |
+
]
|
1702 |
+
},
|
1703 |
+
{
|
1704 |
+
"cell_type": "code",
|
1705 |
+
"execution_count": null,
|
1706 |
+
"metadata": {},
|
1707 |
+
"outputs": [],
|
1708 |
+
"source": [
|
1709 |
+
"print(f\"groupby legend: {groupby_fields}\")\n",
|
1710 |
+
"# hist_subset = grouped_df.get_group((True,1,2.0,0.1)) # needs to match the groupby keys and order\n",
|
1711 |
+
"# hist_subset = grouped_df.get_group((True,1,2.0,0.25)) \n",
|
1712 |
+
"hist_subset = grouped_df.get_group((True,1,2.0,0.5)) \n",
|
1713 |
+
"# hist_subset = grouped_df.get_group((True,1,2.0,0.75)) \n",
|
1714 |
+
"# hist_subset = grouped_df.get_group((True,1,2.0,0.9)) "
|
1715 |
+
]
|
1716 |
+
},
|
1717 |
+
{
|
1718 |
+
"cell_type": "code",
|
1719 |
+
"execution_count": null,
|
1720 |
+
"metadata": {},
|
1721 |
+
"outputs": [],
|
1722 |
+
"source": [
|
1723 |
+
"print(len(hist_subset))\n",
|
1724 |
+
"# hist_subset = hist_subset[hist_subset[\"w_bl_space_frac\"] <= 0.9]\n",
|
1725 |
+
"# hist_subset = hist_subset[hist_subset[\"no_bl_space_frac\"] <= 0.9]\n",
|
1726 |
+
"# print(len(hist_subset))"
|
1727 |
+
]
|
1728 |
+
},
|
1729 |
+
{
|
1730 |
+
"cell_type": "code",
|
1731 |
+
"execution_count": null,
|
1732 |
+
"metadata": {},
|
1733 |
+
"outputs": [],
|
1734 |
+
"source": [
|
1735 |
+
"# y = hist_subset[\"w_bl_z_score\"]\n",
|
1736 |
+
"# y = hist_subset[\"no_bl_z_score\"]\n",
|
1737 |
+
"y = hist_subset[\"baseline_z_score\"]\n",
|
1738 |
+
"\n",
|
1739 |
+
"x = hist_subset[\"avg_spike_entropy\"]\n",
|
1740 |
+
"\n",
|
1741 |
+
"plt.clf()\n",
|
1742 |
+
"\n",
|
1743 |
+
"\n",
|
1744 |
+
"plt.scatter(x, y)\n",
|
1745 |
+
"\n",
|
1746 |
+
"\n",
|
1747 |
+
"plt.grid()\n",
|
1748 |
+
"\n",
|
1749 |
+
"plt.xlabel(\"Entropy\")\n",
|
1750 |
+
"plt.ylabel(\"z-score\")\n",
|
1751 |
+
"\n",
|
1752 |
+
"plt.show()"
|
1753 |
+
]
|
1754 |
+
},
|
1755 |
+
{
|
1756 |
+
"cell_type": "code",
|
1757 |
+
"execution_count": null,
|
1758 |
+
"metadata": {},
|
1759 |
+
"outputs": [],
|
1760 |
+
"source": [
|
1761 |
+
"cols_to_tabulate = [\n",
|
1762 |
+
" 'idx', \n",
|
1763 |
+
" 'truncated_input', \n",
|
1764 |
+
" 'baseline_completion',\n",
|
1765 |
+
" 'no_bl_output', \n",
|
1766 |
+
" 'w_bl_output', \n",
|
1767 |
+
" 'avg_spike_entropy',\n",
|
1768 |
+
" 'no_bl_z_score',\n",
|
1769 |
+
" 'w_bl_z_score',\n",
|
1770 |
+
" 'w_bl_whitelist_fraction',\n",
|
1771 |
+
" 'no_bl_whitelist_fraction',\n",
|
1772 |
+
" 'baseline_ppl',\n",
|
1773 |
+
" 'no_bl_ppl',\n",
|
1774 |
+
" 'w_bl_ppl'\n",
|
1775 |
+
"]\n",
|
1776 |
+
"\n",
|
1777 |
+
"slice_size = 10\n",
|
1778 |
+
"\n",
|
1779 |
+
"num_examples = len(hist_subset)\n",
|
1780 |
+
"midpt = num_examples//5\n",
|
1781 |
+
"lower = midpt - (slice_size//2)\n",
|
1782 |
+
"upper = midpt + (slice_size//2)+1\n",
|
1783 |
+
"\n",
|
1784 |
+
"high_entropy_examples = hist_subset[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).tail(slice_size)\n",
|
1785 |
+
"mid_entropy_examples = hist_subset[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).iloc[lower:upper]\n",
|
1786 |
+
"low_entropy_examples = hist_subset[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).head(slice_size)"
|
1787 |
+
]
|
1788 |
+
},
|
1789 |
+
{
|
1790 |
+
"cell_type": "code",
|
1791 |
+
"execution_count": null,
|
1792 |
+
"metadata": {},
|
1793 |
+
"outputs": [],
|
1794 |
+
"source": [
|
1795 |
+
"# hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"w_bl_z_score\"]>=14.0)]\n",
|
1796 |
+
"hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"baseline_z_score\"]>=7.0)]\n",
|
1797 |
+
"# hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"w_bl_z_score\"]>=12.0)]\n",
|
1798 |
+
"# print(hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"w_bl_z_score\"]>=14.0)].iloc[6][\"w_bl_output\"])\n",
|
1799 |
+
"# .to_csv(\"input/pile_low_S_high_z_outliers.csv\")"
|
1800 |
+
]
|
1801 |
+
},
|
1802 |
+
{
|
1803 |
+
"cell_type": "code",
|
1804 |
+
"execution_count": null,
|
1805 |
+
"metadata": {},
|
1806 |
+
"outputs": [],
|
1807 |
+
"source": [
|
1808 |
+
"high_entropy_examples"
|
1809 |
+
]
|
1810 |
+
},
|
1811 |
+
{
|
1812 |
+
"cell_type": "code",
|
1813 |
+
"execution_count": null,
|
1814 |
+
"metadata": {},
|
1815 |
+
"outputs": [],
|
1816 |
+
"source": [
|
1817 |
+
"mid_entropy_examples"
|
1818 |
+
]
|
1819 |
+
},
|
1820 |
+
{
|
1821 |
+
"cell_type": "code",
|
1822 |
+
"execution_count": null,
|
1823 |
+
"metadata": {},
|
1824 |
+
"outputs": [],
|
1825 |
+
"source": [
|
1826 |
+
"low_entropy_examples"
|
1827 |
+
]
|
1828 |
+
},
|
1829 |
+
{
|
1830 |
+
"cell_type": "code",
|
1831 |
+
"execution_count": null,
|
1832 |
+
"metadata": {},
|
1833 |
+
"outputs": [],
|
1834 |
+
"source": []
|
1835 |
+
},
|
1836 |
+
{
|
1837 |
+
"attachments": {},
|
1838 |
+
"cell_type": "markdown",
|
1839 |
+
"metadata": {},
|
1840 |
+
"source": [
|
1841 |
+
"# plotting histograms of the metric for single runs (not in paper)"
|
1842 |
+
]
|
1843 |
+
},
|
1844 |
+
{
|
1845 |
+
"cell_type": "code",
|
1846 |
+
"execution_count": null,
|
1847 |
+
"metadata": {},
|
1848 |
+
"outputs": [],
|
1849 |
+
"source": [
|
1850 |
+
"print(f\"groupby legend: {groupby_fields}\")\n",
|
1851 |
+
"# hist_subset = grouped_df.get_group((True,1,2.0,0.1)) # needs to match the groupby keys and order\n",
|
1852 |
+
"# hist_subset = grouped_df.get_group((True,1,2.0,0.25)) \n",
|
1853 |
+
"hist_subset = grouped_df.get_group((True,1,2.0,0.5)) \n",
|
1854 |
+
"# hist_subset = grouped_df.get_group((True,1,2.0,0.75)) \n",
|
1855 |
+
"# hist_subset = grouped_df.get_group((True,1,2.0,0.9)) "
|
1856 |
+
]
|
1857 |
+
},
|
1858 |
+
{
|
1859 |
+
"cell_type": "markdown",
|
1860 |
+
"metadata": {},
|
1861 |
+
"source": [
|
1862 |
+
"##### old filters to smooth the histograms"
|
1863 |
+
]
|
1864 |
+
},
|
1865 |
+
{
|
1866 |
+
"cell_type": "code",
|
1867 |
+
"execution_count": null,
|
1868 |
+
"metadata": {},
|
1869 |
+
"outputs": [],
|
1870 |
+
"source": [
|
1871 |
+
"# hist_subset = hist_subset[(hist_subset[\"no_bl_num_tokens_generated\"] == hist_subset[\"max_new_tokens\"]) & (hist_subset[\"w_bl_num_tokens_generated\"] == hist_subset[\"max_new_tokens\"])]\n",
|
1872 |
+
"# hist_subset = hist_subset[hist_subset[\"truncated_input\"] != \"\"]"
|
1873 |
+
]
|
1874 |
+
},
|
1875 |
+
{
|
1876 |
+
"cell_type": "code",
|
1877 |
+
"execution_count": null,
|
1878 |
+
"metadata": {},
|
1879 |
+
"outputs": [],
|
1880 |
+
"source": [
|
1881 |
+
"all_no_bl_wl_fractions = hist_subset[\"no_bl_whitelist_fraction\"]\n",
|
1882 |
+
"all_w_bl_wl_fractions = hist_subset[\"w_bl_whitelist_fraction\"]\n",
|
1883 |
+
"all_baseline_wl_fractions = hist_subset[\"baseline_whitelist_fraction\"]\n",
|
1884 |
+
"# all_no_bl_wl_fractions = hist_subset[\"no_bl_z_score\"]\n",
|
1885 |
+
"# all_w_bl_wl_fractions = hist_subset[\"w_bl_z_score\"]\n",
|
1886 |
+
"# all_baseline_wl_fractions = hist_subset[\"baseline_z_score\"]\n",
|
1887 |
+
"\n",
|
1888 |
+
"plt.clf()\n",
|
1889 |
+
"\n",
|
1890 |
+
"all_vals = np.concatenate([all_baseline_wl_fractions, all_w_bl_wl_fractions, all_no_bl_wl_fractions])\n",
|
1891 |
+
"n_bins = 50\n",
|
1892 |
+
"bins = np.linspace(np.min(all_vals), np.max(all_vals), n_bins)\n",
|
1893 |
+
"# bins = np.linspace(0.0, 1.0, n_bins)\n",
|
1894 |
+
"\n",
|
1895 |
+
"# plt.hist(all_no_bl_wl_fractions, \n",
|
1896 |
+
"# bins=bins,\n",
|
1897 |
+
"# alpha=0.6,\n",
|
1898 |
+
"# label='no blacklisting')\n",
|
1899 |
+
"\n",
|
1900 |
+
"\n",
|
1901 |
+
"plt.hist(all_w_bl_wl_fractions, \n",
|
1902 |
+
" bins=bins,\n",
|
1903 |
+
" alpha=0.6,\n",
|
1904 |
+
" label='with blacklisting')\n",
|
1905 |
+
"\n",
|
1906 |
+
"plt.hist(all_baseline_wl_fractions,\n",
|
1907 |
+
" bins=bins,\n",
|
1908 |
+
" alpha=0.4,\n",
|
1909 |
+
" # label='wl')\n",
|
1910 |
+
" label='ground truth/real text')\n",
|
1911 |
+
"\n",
|
1912 |
+
"# plt.hist(all_baseline_bl_fractions, \n",
|
1913 |
+
"# bins=bins,\n",
|
1914 |
+
"# alpha=0.5,\n",
|
1915 |
+
"# label='bl')\n",
|
1916 |
+
"\n",
|
1917 |
+
"plt.legend(loc='upper right')\n",
|
1918 |
+
"\n",
|
1919 |
+
"# plt.xlim((-0.1,1.1))\n",
|
1920 |
+
"# plt.xticks(np.arange(0.0,1.0,0.1))\n",
|
1921 |
+
"plt.xlabel(\"fraction of total toks gen'd in WL\")\n",
|
1922 |
+
"plt.ylabel(\"freq\")\n",
|
1923 |
+
"\n",
|
1924 |
+
"# plt.title('baseline wl/bl fractions')\n",
|
1925 |
+
"plt.title(\"Output Whitelist Token Distribution\")\n",
|
1926 |
+
"\n",
|
1927 |
+
"# plot_name = \"wl_distro\"\n",
|
1928 |
+
"# fname = f\"figs/{plot_name}.png\"\n",
|
1929 |
+
"# plt.savefig(fname, dpi=600)"
|
1930 |
+
]
|
1931 |
+
},
|
1932 |
+
{
|
1933 |
+
"cell_type": "code",
|
1934 |
+
"execution_count": null,
|
1935 |
+
"metadata": {},
|
1936 |
+
"outputs": [],
|
1937 |
+
"source": [
|
1938 |
+
"plt.clf()\n",
|
1939 |
+
"\n",
|
1940 |
+
"all_no_bl_ppls = hist_subset[\"no_bl_ppl\"]\n",
|
1941 |
+
"all_w_bl_ppls = hist_subset[\"w_bl_ppl\"]\n",
|
1942 |
+
"all_baseline_ppls = hist_subset[\"baseline_ppl\"]\n",
|
1943 |
+
"\n",
|
1944 |
+
"all_vals = list(np.concatenate([all_no_bl_ppls, all_w_bl_ppls]))\n",
|
1945 |
+
"all_vals = sorted(all_vals)\n",
|
1946 |
+
"n_bins = 50\n",
|
1947 |
+
"# bins = np.linspace(all_vals[0], all_vals[-1], n_bins)\n",
|
1948 |
+
"bins = np.linspace(all_vals[0], 20, n_bins)\n",
|
1949 |
+
"\n",
|
1950 |
+
"plt.hist(all_no_bl_ppls, \n",
|
1951 |
+
" bins=bins,\n",
|
1952 |
+
" alpha=0.6,\n",
|
1953 |
+
" label='no blacklisting')\n",
|
1954 |
+
"\n",
|
1955 |
+
"plt.hist(all_w_bl_ppls, \n",
|
1956 |
+
" bins=bins,\n",
|
1957 |
+
" alpha=0.6,\n",
|
1958 |
+
" label='with blacklisting')\n",
|
1959 |
+
"\n",
|
1960 |
+
"plt.legend(loc='upper right')\n",
|
1961 |
+
"\n",
|
1962 |
+
"# plt.xlim((0,1))\n",
|
1963 |
+
"plt.xlabel(\"perplexity (lower is better)\")\n",
|
1964 |
+
"plt.ylabel(\"freq\")\n",
|
1965 |
+
"\n",
|
1966 |
+
"plt.title('Model-based Output Quality/Fluency')\n",
|
1967 |
+
"\n",
|
1968 |
+
"# plot_name = \"ppl_no_baseline\"\n",
|
1969 |
+
"# fname = f\"figs/{plot_name}.png\"\n",
|
1970 |
+
"# plt.savefig(fname, dpi=600)"
|
1971 |
+
]
|
1972 |
+
},
|
1973 |
+
{
|
1974 |
+
"cell_type": "code",
|
1975 |
+
"execution_count": null,
|
1976 |
+
"metadata": {},
|
1977 |
+
"outputs": [],
|
1978 |
+
"source": [
|
1979 |
+
"plt.clf()\n",
|
1980 |
+
"\n",
|
1981 |
+
"all_vals = list(np.concatenate([all_no_bl_ppls, all_w_bl_ppls]))\n",
|
1982 |
+
"all_vals = sorted(all_vals)\n",
|
1983 |
+
"n_bins = 50\n",
|
1984 |
+
"# bins = np.linspace(all_vals[0], all_vals[-1], n_bins)\n",
|
1985 |
+
"bins = np.linspace(all_vals[0], 20, n_bins)\n",
|
1986 |
+
"\n",
|
1987 |
+
"plt.hist(all_no_bl_ppls, \n",
|
1988 |
+
" bins=bins,\n",
|
1989 |
+
" alpha=0.6,\n",
|
1990 |
+
" label='no blacklisting')\n",
|
1991 |
+
"\n",
|
1992 |
+
"plt.hist(all_w_bl_ppls, \n",
|
1993 |
+
" bins=bins,\n",
|
1994 |
+
" alpha=0.6,\n",
|
1995 |
+
" label='with blacklisting')\n",
|
1996 |
+
"\n",
|
1997 |
+
"plt.hist(all_baseline_ppls, \n",
|
1998 |
+
" bins=bins, \n",
|
1999 |
+
" alpha=0.4,\n",
|
2000 |
+
" label='ground truth/real text')\n",
|
2001 |
+
"\n",
|
2002 |
+
"plt.legend(loc='upper right')\n",
|
2003 |
+
"\n",
|
2004 |
+
"# plt.xlim((0,1))\n",
|
2005 |
+
"plt.xlabel(\"perplexity (lower is better)\")\n",
|
2006 |
+
"plt.ylabel(\"freq\")\n",
|
2007 |
+
"\n",
|
2008 |
+
"plt.title('Model-based Output Quality/Fluency')\n",
|
2009 |
+
"\n",
|
2010 |
+
"# plot_name = \"ppl_w_baseline\"\n",
|
2011 |
+
"# fname = f\"figs/{plot_name}.png\"\n",
|
2012 |
+
"# plt.savefig(fname, dpi=600)"
|
2013 |
+
]
|
2014 |
+
},
|
2015 |
+
{
|
2016 |
+
"cell_type": "code",
|
2017 |
+
"execution_count": null,
|
2018 |
+
"metadata": {},
|
2019 |
+
"outputs": [],
|
2020 |
+
"source": []
|
2021 |
+
}
|
2022 |
+
],
|
2023 |
+
"metadata": {
|
2024 |
+
"kernelspec": {
|
2025 |
+
"display_name": "Python 3",
|
2026 |
+
"language": "python",
|
2027 |
+
"name": "python3"
|
2028 |
+
},
|
2029 |
+
"language_info": {
|
2030 |
+
"codemirror_mode": {
|
2031 |
+
"name": "ipython",
|
2032 |
+
"version": 3
|
2033 |
+
},
|
2034 |
+
"file_extension": ".py",
|
2035 |
+
"mimetype": "text/x-python",
|
2036 |
+
"name": "python",
|
2037 |
+
"nbconvert_exporter": "python",
|
2038 |
+
"pygments_lexer": "ipython3",
|
2039 |
+
"version": "3.10.6"
|
2040 |
+
},
|
2041 |
+
"vscode": {
|
2042 |
+
"interpreter": {
|
2043 |
+
"hash": "365524a309ad80022da286f2ec5d2060ce5cb229abb6076cf68d9a1ab14bd8fe"
|
2044 |
+
}
|
2045 |
+
}
|
2046 |
+
},
|
2047 |
+
"nbformat": 4,
|
2048 |
+
"nbformat_minor": 4
|
2049 |
+
}
|
lm-watermarking-main/experiments/watermarking_example_finding.ipynb
ADDED
@@ -0,0 +1,1007 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Watermark Analysis\n",
|
8 |
+
"\n",
|
9 |
+
"Notebook for performing analysis and visualization of the effects of watermarking schemes"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"from datasets import load_from_disk"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"attachments": {},
|
23 |
+
"cell_type": "markdown",
|
24 |
+
"metadata": {},
|
25 |
+
"source": [
|
26 |
+
"### Load the processed dataset/frame"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": null,
|
32 |
+
"metadata": {},
|
33 |
+
"outputs": [],
|
34 |
+
"source": [
|
35 |
+
"save_name = \"analysis_ds_1-19_realnews_1-3_v1\" # in figure\n",
|
36 |
+
"# save_name = \"analysis_ds_1-21_greedy_redo\" \n",
|
37 |
+
"# save_name = \"analysis_ds_1-21_greedy_redo_truncated\" # in figure\n",
|
38 |
+
"\n",
|
39 |
+
"# save_name = \"analysis_ds_1-20_more_attack\" # in figure\n",
|
40 |
+
"\n",
|
41 |
+
"save_dir = f\"input/{save_name}\""
|
42 |
+
]
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"cell_type": "code",
|
46 |
+
"execution_count": null,
|
47 |
+
"metadata": {},
|
48 |
+
"outputs": [],
|
49 |
+
"source": [
|
50 |
+
"raw_data = load_from_disk(save_dir)"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"cell_type": "markdown",
|
55 |
+
"metadata": {},
|
56 |
+
"source": [
|
57 |
+
"#### convert to pandas df"
|
58 |
+
]
|
59 |
+
},
|
60 |
+
{
|
61 |
+
"cell_type": "code",
|
62 |
+
"execution_count": null,
|
63 |
+
"metadata": {},
|
64 |
+
"outputs": [],
|
65 |
+
"source": [
|
66 |
+
"df = raw_data.to_pandas()\n",
|
67 |
+
"\n",
|
68 |
+
"retok_problematic_rows = df[(df['w_bl_whitelist_fraction'] != -1.0) & (df['w_bl_whitelist_fraction'] != 1.0) & (df['bl_type'] == 'hard')]\n",
|
69 |
+
"print(f\"Num rows that are hard-blacklisted, and measureable, but still have a non-100% WL fraction: {len(retok_problematic_rows)} out of {len(df[df['bl_type'] == 'hard'])}\")\n",
|
70 |
+
"\n",
|
71 |
+
"orig_len = len(df)\n",
|
72 |
+
"\n",
|
73 |
+
"df = df[df[\"no_bl_whitelist_fraction\"] != -1.0]\n",
|
74 |
+
"df = df[df[\"w_bl_whitelist_fraction\"] != -1.0]\n",
|
75 |
+
"\n",
|
76 |
+
"print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
|
77 |
+
"\n",
|
78 |
+
"orig_len = len(df)\n",
|
79 |
+
"# df = df[df[\"no_bl_ppl\"].isna()]\n",
|
80 |
+
"# df = df[df[\"w_bl_ppl\"].isna()]\n",
|
81 |
+
"df = df[~(df[\"no_bl_ppl\"].isna() | df[\"w_bl_ppl\"].isna())]\n",
|
82 |
+
"print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
|
83 |
+
"\n",
|
84 |
+
"orig_len = len(df)\n",
|
85 |
+
"\n",
|
86 |
+
"df = df[df[\"bl_logit_bias\"] <= 100.0]\n",
|
87 |
+
"\n",
|
88 |
+
"print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
|
89 |
+
"\n",
|
90 |
+
"\n",
|
91 |
+
"orig_len = len(df)\n",
|
92 |
+
"\n",
|
93 |
+
"# df = df[df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False and tup[2] != 1) or (tup[0] == True and tup[2] == 1) or (tup[0] == False))]\n",
|
94 |
+
"df = df[((df[\"use_sampling\"]==True) & (df[\"num_beams\"] == 1)) | (df[\"use_sampling\"]==False)]\n",
|
95 |
+
"\n",
|
96 |
+
"print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
|
97 |
+
"\n",
|
98 |
+
"\n",
|
99 |
+
"df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"].fillna(0.0)\n",
|
100 |
+
"df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"].fillna(1.0)\n",
|
101 |
+
"\n",
|
102 |
+
"\n",
|
103 |
+
"df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = np.inf\n",
|
104 |
+
"# df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = 10000 # crosscheck with whats hardcoded in the bl processor\n",
|
105 |
+
"\n",
|
106 |
+
"\n",
|
107 |
+
"df[\"delta\"] = df[\"bl_logit_bias\"].values\n",
|
108 |
+
"df[\"gamma\"] = 1 - df[\"bl_proportion\"].values\n",
|
109 |
+
"df[\"gamma\"] = df[\"gamma\"].round(3)\n",
|
110 |
+
"\n",
|
111 |
+
"df[\"no_bl_act_num_wl_tokens\"] = np.round(df[\"no_bl_whitelist_fraction\"].values*df[\"no_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
|
112 |
+
"df[\"w_bl_act_num_wl_tokens\"] = np.round(df[\"w_bl_whitelist_fraction\"].values*df[\"w_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
|
113 |
+
"\n",
|
114 |
+
"df[\"w_bl_std_num_wl_tokens\"] = np.sqrt(df[\"w_bl_var_num_wl_tokens\"].values)\n",
|
115 |
+
"\n",
|
116 |
+
"if \"real_completion_length\":\n",
|
117 |
+
" df[\"baseline_num_tokens_generated\"] = df[\"real_completion_length\"].values\n",
|
118 |
+
"\n",
|
119 |
+
"if \"actual_attacked_ratio\" in df.columns:\n",
|
120 |
+
" df[\"actual_attacked_fraction\"] = df[\"actual_attacked_ratio\"].values*df[\"replace_ratio\"].values\n",
|
121 |
+
"\n",
|
122 |
+
"\n",
|
123 |
+
"\n",
|
124 |
+
"df[\"baseline_hit_list_length\"] = df[\"baseline_hit_list\"].apply(len)\n",
|
125 |
+
"df[\"no_bl_hit_list_length\"] = df[\"no_bl_hit_list\"].apply(len)\n",
|
126 |
+
"df[\"w_bl_hit_list_length\"] = df[\"w_bl_hit_list\"].apply(len)"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"attachments": {},
|
131 |
+
"cell_type": "markdown",
|
132 |
+
"metadata": {},
|
133 |
+
"source": [
|
134 |
+
"## Filter for the generation lengths we want to look at"
|
135 |
+
]
|
136 |
+
},
|
137 |
+
{
|
138 |
+
"cell_type": "code",
|
139 |
+
"execution_count": null,
|
140 |
+
"metadata": {},
|
141 |
+
"outputs": [],
|
142 |
+
"source": [
|
143 |
+
"orig_len = len(df)\n",
|
144 |
+
"\n",
|
145 |
+
"upper_T = 205\n",
|
146 |
+
"lower_T = 195\n",
|
147 |
+
"df = df[(df[\"baseline_hit_list_length\"] >= lower_T) & (df[\"no_bl_hit_list_length\"] >= lower_T) & (df[\"w_bl_hit_list_length\"] >= lower_T)] # now also applies to the truncated version\n",
|
148 |
+
"df = df[(df[\"baseline_hit_list_length\"] <= upper_T) & (df[\"no_bl_hit_list_length\"] <= upper_T) & (df[\"w_bl_hit_list_length\"] <= upper_T)] # now also applies to the truncated version\n",
|
149 |
+
"\n",
|
150 |
+
"\n",
|
151 |
+
"print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"attachments": {},
|
156 |
+
"cell_type": "markdown",
|
157 |
+
"metadata": {},
|
158 |
+
"source": [
|
159 |
+
"#### Add z-scores"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": null,
|
165 |
+
"metadata": {},
|
166 |
+
"outputs": [],
|
167 |
+
"source": [
|
168 |
+
"from math import sqrt\n",
|
169 |
+
"import scipy.stats\n",
|
170 |
+
"def compute_z_score(observed_wl_frac, T, gamma):\n",
|
171 |
+
" numer = observed_wl_frac - gamma\n",
|
172 |
+
" denom = sqrt(gamma*(1-gamma)/T)\n",
|
173 |
+
" z = numer/denom\n",
|
174 |
+
" return z\n",
|
175 |
+
"\n",
|
176 |
+
"def compute_wl_for_z(z, T, gamma):\n",
|
177 |
+
" denom = sqrt(gamma*(1-gamma)/T)\n",
|
178 |
+
" numer = ((z*denom)+gamma)*T\n",
|
179 |
+
" return numer\n",
|
180 |
+
"\n",
|
181 |
+
"def compute_p_value(z):\n",
|
182 |
+
" p_value = scipy.stats.norm.sf(abs(z))\n",
|
183 |
+
" return p_value\n",
|
184 |
+
"\n",
|
185 |
+
"df[\"baseline_z_score\"] = df[[\"baseline_whitelist_fraction\", \"baseline_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
|
186 |
+
"df[\"no_bl_z_score\"] = df[[\"no_bl_whitelist_fraction\", \"no_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
|
187 |
+
"df[\"w_bl_z_score\"] = df[[\"w_bl_whitelist_fraction\", \"w_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
|
188 |
+
"\n",
|
189 |
+
"if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
|
190 |
+
" df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)"
|
191 |
+
]
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"cell_type": "code",
|
195 |
+
"execution_count": null,
|
196 |
+
"metadata": {},
|
197 |
+
"outputs": [],
|
198 |
+
"source": [
|
199 |
+
"# # if attacked in df\n",
|
200 |
+
"if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
|
201 |
+
" df[\"w_bl_attacked_act_num_wl_tokens\"] = np.round(df[\"w_bl_attacked_whitelist_fraction\"].values*df[\"w_bl_attacked_num_tokens_generated\"],1) # round to 1 for sanity\n",
|
202 |
+
"\n",
|
203 |
+
" df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
|
204 |
+
"\n",
|
205 |
+
" df[[\"bl_proportion\",\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\",\"w_bl_attacked_act_num_wl_tokens\", \"w_bl_attacked_z_score\"]]"
|
206 |
+
]
|
207 |
+
},
|
208 |
+
{
|
209 |
+
"cell_type": "markdown",
|
210 |
+
"metadata": {},
|
211 |
+
"source": [
|
212 |
+
"#### Groupby"
|
213 |
+
]
|
214 |
+
},
|
215 |
+
{
|
216 |
+
"cell_type": "code",
|
217 |
+
"execution_count": null,
|
218 |
+
"metadata": {},
|
219 |
+
"outputs": [],
|
220 |
+
"source": [
|
221 |
+
"if \"w_bl_attacked_whitelist_fraction\" in df.columns: \n",
|
222 |
+
" groupby_fields = ['use_sampling','num_beams','gamma','delta', 'replace_ratio'] # attack grouping\n",
|
223 |
+
"else:\n",
|
224 |
+
" groupby_fields = ['use_sampling','num_beams','delta','gamma'] # regular grouping\n",
|
225 |
+
" # groupby_fields = ['use_sampling','delta','gamma'] # regular grouping, but no beam variation\n",
|
226 |
+
" # groupby_fields = ['delta','gamma'] # regular grouping, but no beam variation, and all sampling"
|
227 |
+
]
|
228 |
+
},
|
229 |
+
{
|
230 |
+
"attachments": {},
|
231 |
+
"cell_type": "markdown",
|
232 |
+
"metadata": {},
|
233 |
+
"source": [
|
234 |
+
"#### Main groupby"
|
235 |
+
]
|
236 |
+
},
|
237 |
+
{
|
238 |
+
"cell_type": "code",
|
239 |
+
"execution_count": null,
|
240 |
+
"metadata": {},
|
241 |
+
"outputs": [],
|
242 |
+
"source": [
|
243 |
+
"grouped_df = df.groupby(groupby_fields)"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "code",
|
248 |
+
"execution_count": null,
|
249 |
+
"metadata": {},
|
250 |
+
"outputs": [],
|
251 |
+
"source": [
|
252 |
+
"print(f\"Number of rows after filtering: {len(df)}\")\n",
|
253 |
+
"print(f\"Number of groups: {len(grouped_df)}\")"
|
254 |
+
]
|
255 |
+
},
|
256 |
+
{
|
257 |
+
"attachments": {},
|
258 |
+
"cell_type": "markdown",
|
259 |
+
"metadata": {},
|
260 |
+
"source": [
|
261 |
+
"### Loop to compute confusion matrix at some z scores for tabulation"
|
262 |
+
]
|
263 |
+
},
|
264 |
+
{
|
265 |
+
"cell_type": "code",
|
266 |
+
"execution_count": null,
|
267 |
+
"metadata": {},
|
268 |
+
"outputs": [],
|
269 |
+
"source": [
|
270 |
+
"import sklearn.metrics as metrics\n",
|
271 |
+
"\n",
|
272 |
+
"def reject_null_hypo(z_score=None,cuttoff=None):\n",
|
273 |
+
" return z_score > cuttoff\n",
|
274 |
+
"\n",
|
275 |
+
"records = []\n",
|
276 |
+
"\n",
|
277 |
+
"for group_params in tqdm(list(grouped_df.groups.keys())):\n",
|
278 |
+
" sub_df = grouped_df.get_group(group_params)\n",
|
279 |
+
" grp_size = len(sub_df)\n",
|
280 |
+
"\n",
|
281 |
+
" # baseline_z_scores = sub_df[\"baseline_z_score\"].values\n",
|
282 |
+
" # w_bl_z_scores = sub_df[\"w_bl_z_score\"].values\n",
|
283 |
+
" # all_scores = np.concatenate([baseline_z_scores,w_bl_z_scores])\n",
|
284 |
+
"\n",
|
285 |
+
" # baseline_labels = np.zeros_like(baseline_z_scores)\n",
|
286 |
+
" # attacked_labels = np.ones_like(w_bl_z_scores)\n",
|
287 |
+
" # all_labels = np.concatenate([baseline_labels,attacked_labels])\n",
|
288 |
+
"\n",
|
289 |
+
" # fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
|
290 |
+
" # roc_auc = metrics.auc(fpr, tpr)\n",
|
291 |
+
" record = {k:v for k,v in zip(groupby_fields,group_params)}\n",
|
292 |
+
"\n",
|
293 |
+
" for thresh in [4.0,5.0]:\n",
|
294 |
+
" \n",
|
295 |
+
" record[\"count\"] = grp_size\n",
|
296 |
+
" record[f\"baseline_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"baseline_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
|
297 |
+
" record[f\"baseline_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"baseline_z_score\"],cuttoff=thresh)).sum() / grp_size\n",
|
298 |
+
" record[f\"no_bl_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
|
299 |
+
" record[f\"no_bl_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
|
300 |
+
" record[f\"w_bl_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
|
301 |
+
" record[f\"w_bl_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
|
302 |
+
"\n",
|
303 |
+
" if \"w_bl_attacked_z_score\" in sub_df.columns:\n",
|
304 |
+
" record[f\"w_bl_attacked_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
|
305 |
+
" record[f\"w_bl_attacked_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
|
306 |
+
"\n",
|
307 |
+
" records.append(record)\n",
|
308 |
+
"\n",
|
309 |
+
" # # df[f\"baseline_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"baseline_z_score\"].values,cuttoff=thresh)\n",
|
310 |
+
" # # df[f\"baseline_tn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"baseline_z_score\"],cuttoff=thresh)\n",
|
311 |
+
" # # df[f\"no_bl_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n",
|
312 |
+
" # # df[f\"no_bl_tn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n",
|
313 |
+
" # # df[f\"w_bl_tp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n",
|
314 |
+
" # # df[f\"w_bl_fn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n",
|
315 |
+
"\n",
|
316 |
+
"\n",
|
317 |
+
"roc_df = pd.DataFrame.from_records(records)\n"
|
318 |
+
]
|
319 |
+
},
|
320 |
+
{
|
321 |
+
"cell_type": "code",
|
322 |
+
"execution_count": null,
|
323 |
+
"metadata": {},
|
324 |
+
"outputs": [],
|
325 |
+
"source": [
|
326 |
+
"# thresh = 6.0\n",
|
327 |
+
"# thresh = 5.0\n",
|
328 |
+
"std_threshes = [4.0, 5.0] #, 6.0]\n",
|
329 |
+
"# std_threshes = [4.0]\n",
|
330 |
+
"\n",
|
331 |
+
"# roc_df[\"params\"] = roc_df.index.to_list()\n",
|
332 |
+
"\n",
|
333 |
+
"columns = [\"delta\", \"gamma\", \"count\"]\n",
|
334 |
+
"# columns = [\"use_sampling\", \"replace_ratio\", \"count\"]\n",
|
335 |
+
"\n",
|
336 |
+
"for thresh in std_threshes:\n",
|
337 |
+
" # columns += [f\"baseline_fpr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\"]\n",
|
338 |
+
" # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"no_bl_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fn_at_{thresh}\"]\n",
|
339 |
+
"\n",
|
340 |
+
"\n",
|
341 |
+
" # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
|
342 |
+
" \n",
|
343 |
+
" if f\"w_bl_attacked_fnr_at_{thresh}\" in roc_df.columns:\n",
|
344 |
+
" columns += [f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
|
345 |
+
" columns += [f\"w_bl_attacked_tpr_at_{thresh}\",f\"w_bl_attacked_fnr_at_{thresh}\"] # if attack\n",
|
346 |
+
" else:\n",
|
347 |
+
" columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
|
348 |
+
"\n",
|
349 |
+
"# filter ot not\n",
|
350 |
+
"sub_df = roc_df[(roc_df[\"use_sampling\"] == True) & ((roc_df[\"delta\"] == 1.0) | (roc_df[\"delta\"] == 2.0) | (roc_df[\"delta\"] == 10.0)) & ((roc_df[\"gamma\"] == 0.1) | (roc_df[\"gamma\"] == 0.25) |(roc_df[\"gamma\"] == 0.5) )]\n",
|
351 |
+
"# sub_df = roc_df[(roc_df[\"replace_ratio\"] == 0.1) | (roc_df[\"replace_ratio\"] == 0.3) | (roc_df[\"replace_ratio\"] == 0.5) | (roc_df[\"replace_ratio\"] == 0.7)]\n",
|
352 |
+
"# sub_df = roc_df\n",
|
353 |
+
"\n",
|
354 |
+
"sub_df.sort_values(\"delta\")[columns]\n",
|
355 |
+
"# sub_df.sort_values(\"num_beams\")[columns]"
|
356 |
+
]
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"cell_type": "code",
|
360 |
+
"execution_count": null,
|
361 |
+
"metadata": {},
|
362 |
+
"outputs": [],
|
363 |
+
"source": [
|
364 |
+
"# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"gamma\").round(3).to_latex(index=False))\n",
|
365 |
+
"# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"delta\").round(3).to_latex(index=False))\n",
|
366 |
+
"# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"num_beams\").round(3).to_latex(index=False))\n",
|
367 |
+
"\n",
|
368 |
+
"print(sub_df.sort_values(\"delta\")[columns].round(3).to_latex(index=False))\n",
|
369 |
+
"# print(sub_df.sort_values(\"num_beams\")[columns].round(3).to_latex(index=False))"
|
370 |
+
]
|
371 |
+
},
|
372 |
+
{
|
373 |
+
"attachments": {},
|
374 |
+
"cell_type": "markdown",
|
375 |
+
"metadata": {},
|
376 |
+
"source": [
|
377 |
+
"### write to csv maybe"
|
378 |
+
]
|
379 |
+
},
|
380 |
+
{
|
381 |
+
"cell_type": "code",
|
382 |
+
"execution_count": null,
|
383 |
+
"metadata": {},
|
384 |
+
"outputs": [],
|
385 |
+
"source": [
|
386 |
+
"# cols_to_drop = ['no_bl_gen_time',\n",
|
387 |
+
"# 'w_bl_gen_time', 'spike_entropies', \n",
|
388 |
+
"# 'no_bl_sec_per_tok', 'no_bl_tok_per_sec', 'w_bl_sec_per_tok',\n",
|
389 |
+
"# 'w_bl_tok_per_sec', 'baseline_loss','no_bl_loss',\n",
|
390 |
+
"# 'w_bl_loss', 'model_name', 'dataset_name',\n",
|
391 |
+
"# 'dataset_config_name', 'shuffle_dataset', 'shuffle_seed',\n",
|
392 |
+
"# 'shuffle_buffer_size', 'max_new_tokens', 'min_prompt_tokens',\n",
|
393 |
+
"# 'limit_indices', 'input_truncation_strategy',\n",
|
394 |
+
"# 'input_filtering_strategy', 'output_filtering_strategy', 'initial_seed',\n",
|
395 |
+
"# 'dynamic_seed','no_repeat_ngram_size', 'early_stopping',\n",
|
396 |
+
"# 'oracle_model_name', 'no_wandb', 'wandb_project', 'wandb_entity', 'output_dir', 'load_prev_generations', 'store_bl_ids',\n",
|
397 |
+
"# 'store_spike_ents', 'generate_only',\n",
|
398 |
+
"# 'SLURM_JOB_ID', 'SLURM_ARRAY_JOB_ID', 'SLURM_ARRAY_TASK_ID',\n",
|
399 |
+
"# 'gen_table_already_existed', 'baseline_num_toks_gend_eq_0',\n",
|
400 |
+
"# 'baseline_hit_list', 'no_bl_num_toks_gend_eq_0',\n",
|
401 |
+
"# 'no_bl_hit_list', 'w_bl_num_toks_gend_eq_0', 'w_bl_hit_list']\n",
|
402 |
+
"# df.drop(cols_to_drop,axis=1).to_csv(\"input/for_poking.csv\")\n",
|
403 |
+
"# df"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
{
|
407 |
+
"cell_type": "code",
|
408 |
+
"execution_count": null,
|
409 |
+
"metadata": {},
|
410 |
+
"outputs": [],
|
411 |
+
"source": [
|
412 |
+
"df.columns"
|
413 |
+
]
|
414 |
+
},
|
415 |
+
{
|
416 |
+
"cell_type": "code",
|
417 |
+
"execution_count": null,
|
418 |
+
"metadata": {},
|
419 |
+
"outputs": [],
|
420 |
+
"source": []
|
421 |
+
},
|
422 |
+
{
|
423 |
+
"attachments": {},
|
424 |
+
"cell_type": "markdown",
|
425 |
+
"metadata": {},
|
426 |
+
"source": [
|
427 |
+
"# Extract examples (actual text) for tabulation based on entropy and z scores (tables 1,3,4,5,6)"
|
428 |
+
]
|
429 |
+
},
|
430 |
+
{
|
431 |
+
"cell_type": "code",
|
432 |
+
"execution_count": null,
|
433 |
+
"metadata": {},
|
434 |
+
"outputs": [],
|
435 |
+
"source": [
|
436 |
+
"print(f\"groupby legend: {groupby_fields}\")"
|
437 |
+
]
|
438 |
+
},
|
439 |
+
{
|
440 |
+
"cell_type": "code",
|
441 |
+
"execution_count": null,
|
442 |
+
"metadata": {},
|
443 |
+
"outputs": [],
|
444 |
+
"source": [
|
445 |
+
"groups = [\n",
|
446 |
+
" (True, 1, 2.0, 0.5),\n",
|
447 |
+
" # (True, 1, 10.0, 0.5),\n",
|
448 |
+
" # (False, 8, 2.0, 0.5),\n",
|
449 |
+
" # (False, 8, 10.0, 0.5),\n",
|
450 |
+
"]\n",
|
451 |
+
"group_dfs = []\n",
|
452 |
+
"for group in groups:\n",
|
453 |
+
" sub_df = grouped_df.get_group(group)\n",
|
454 |
+
" group_dfs.append(sub_df)\n",
|
455 |
+
"\n",
|
456 |
+
"subset_df = pd.concat(group_dfs,axis=0)\n",
|
457 |
+
"\n",
|
458 |
+
"print(len(subset_df))\n",
|
459 |
+
"# subset_df\n",
|
460 |
+
"\n",
|
461 |
+
"# cols_to_tabulate = groupby_fields + [\n",
|
462 |
+
"cols_to_tabulate = [\n",
|
463 |
+
" 'idx', \n",
|
464 |
+
" 'truncated_input', \n",
|
465 |
+
" # 'prompt_length',\n",
|
466 |
+
" 'baseline_completion',\n",
|
467 |
+
" 'no_bl_output', \n",
|
468 |
+
" 'w_bl_output', \n",
|
469 |
+
" # 'real_completion_length',\n",
|
470 |
+
" # 'no_bl_num_tokens_generated',\n",
|
471 |
+
" # 'w_bl_num_tokens_generated',\n",
|
472 |
+
" 'avg_spike_entropy',\n",
|
473 |
+
" # 'baseline_whitelist_fraction',\n",
|
474 |
+
" # 'no_bl_whitelist_fraction',\n",
|
475 |
+
" # 'w_bl_whitelist_fraction',\n",
|
476 |
+
" # 'baseline_z_score',\n",
|
477 |
+
" 'no_bl_z_score',\n",
|
478 |
+
" 'w_bl_z_score',\n",
|
479 |
+
" # 'baseline_ppl',\n",
|
480 |
+
" 'no_bl_ppl',\n",
|
481 |
+
" 'w_bl_ppl'\n",
|
482 |
+
"]\n",
|
483 |
+
"\n",
|
484 |
+
"# subset_df[cols_to_tabulate][\"idx\"].value_counts()\n",
|
485 |
+
"\n",
|
486 |
+
"for idx,occurrences in subset_df[\"idx\"].value_counts().to_dict().items():\n",
|
487 |
+
" subset_df.loc[(subset_df[\"idx\"]==idx),\"occurences\"] = occurrences\n",
|
488 |
+
"\n",
|
489 |
+
"subset_df[\"occurences\"] = subset_df[\"occurences\"].apply(int)\n",
|
490 |
+
"\n",
|
491 |
+
"# cols_to_tabulate = [\"occurences\"] + cols_to_tabulate"
|
492 |
+
]
|
493 |
+
},
|
494 |
+
{
|
495 |
+
"cell_type": "code",
|
496 |
+
"execution_count": null,
|
497 |
+
"metadata": {},
|
498 |
+
"outputs": [],
|
499 |
+
"source": [
|
500 |
+
"# subset_df[cols_to_tabulate].sort_values([\"occurences\", \"idx\"],ascending=False)\n",
|
501 |
+
"# subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=False)"
|
502 |
+
]
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"cell_type": "code",
|
506 |
+
"execution_count": null,
|
507 |
+
"metadata": {},
|
508 |
+
"outputs": [],
|
509 |
+
"source": [
|
510 |
+
"max_prompt_chars = 200\n",
|
511 |
+
"max_output_chars = 200\n",
|
512 |
+
"# subset_df[\"truncated_input\"] = subset_df[\"truncated_input\"].apply(lambda s: f\"[...]{s[-max_prompt_chars:]}\")\n",
|
513 |
+
"# subset_df[\"baseline_completion\"] = subset_df[\"baseline_completion\"].apply(lambda s: f\"{s[:max_output_chars]}[...truncated]\")\n",
|
514 |
+
"# subset_df[\"no_bl_output\"] = subset_df[\"no_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...truncated]\")\n",
|
515 |
+
"# subset_df[\"w_bl_output\"] = subset_df[\"w_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...truncated]\")\n",
|
516 |
+
"\n",
|
517 |
+
"# if you dont have the indexx you cant start with brackets\n",
|
518 |
+
"subset_df[\"truncated_input\"] = subset_df[\"truncated_input\"].apply(lambda s: f\"(...){s[-max_prompt_chars:]}\")\n",
|
519 |
+
"subset_df[\"baseline_completion\"] = subset_df[\"baseline_completion\"].apply(lambda s: f\"{s[:max_output_chars]}[...continues]\")\n",
|
520 |
+
"subset_df[\"no_bl_output\"] = subset_df[\"no_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...continues]\")\n",
|
521 |
+
"subset_df[\"w_bl_output\"] = subset_df[\"w_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...continues]\")\n"
|
522 |
+
]
|
523 |
+
},
|
524 |
+
{
|
525 |
+
"cell_type": "code",
|
526 |
+
"execution_count": null,
|
527 |
+
"metadata": {},
|
528 |
+
"outputs": [],
|
529 |
+
"source": [
|
530 |
+
"slice_size = 2\n",
|
531 |
+
"\n",
|
532 |
+
"# subset_df[cols_to_tabulate][\"avg_spike_entropy\"].describe()[]"
|
533 |
+
]
|
534 |
+
},
|
535 |
+
{
|
536 |
+
"cell_type": "code",
|
537 |
+
"execution_count": null,
|
538 |
+
"metadata": {},
|
539 |
+
"outputs": [],
|
540 |
+
"source": [
|
541 |
+
"num_examples = len(subset_df)\n",
|
542 |
+
"midpt = num_examples//5\n",
|
543 |
+
"lower = midpt - (slice_size//2)\n",
|
544 |
+
"upper = midpt + (slice_size//2)+1\n",
|
545 |
+
"\n",
|
546 |
+
"high_entropy_examples = subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).tail(slice_size)\n",
|
547 |
+
"mid_entropy_examples = subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).iloc[lower:upper]\n",
|
548 |
+
"low_entropy_examples = subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).head(slice_size)\n",
|
549 |
+
"\n",
|
550 |
+
"num_examples = len(subset_df)\n",
|
551 |
+
"midpt = num_examples//65\n",
|
552 |
+
"lower = midpt - (slice_size//2)\n",
|
553 |
+
"upper = midpt + (slice_size//2)+1\n",
|
554 |
+
"\n",
|
555 |
+
"high_z_examples = subset_df[cols_to_tabulate].sort_values([\"w_bl_z_score\"],ascending=True).tail(slice_size)\n",
|
556 |
+
"mid_z_examples = subset_df[cols_to_tabulate].sort_values([\"w_bl_z_score\"],ascending=True).iloc[lower:upper]\n",
|
557 |
+
"low_z_examples = subset_df[cols_to_tabulate].sort_values([\"w_bl_z_score\"],ascending=True).head(slice_size)"
|
558 |
+
]
|
559 |
+
},
|
560 |
+
{
|
561 |
+
"cell_type": "code",
|
562 |
+
"execution_count": null,
|
563 |
+
"metadata": {},
|
564 |
+
"outputs": [],
|
565 |
+
"source": [
|
566 |
+
"# high_entropy_examples.head()\n",
|
567 |
+
"high_z_examples.head()"
|
568 |
+
]
|
569 |
+
},
|
570 |
+
{
|
571 |
+
"cell_type": "code",
|
572 |
+
"execution_count": null,
|
573 |
+
"metadata": {},
|
574 |
+
"outputs": [],
|
575 |
+
"source": [
|
576 |
+
"# mid_entropy_examples.head()\n",
|
577 |
+
"mid_z_examples.head()\n"
|
578 |
+
]
|
579 |
+
},
|
580 |
+
{
|
581 |
+
"cell_type": "code",
|
582 |
+
"execution_count": null,
|
583 |
+
"metadata": {},
|
584 |
+
"outputs": [],
|
585 |
+
"source": [
|
586 |
+
"# low_entropy_examples.head()\n",
|
587 |
+
"low_z_examples.head()"
|
588 |
+
]
|
589 |
+
},
|
590 |
+
{
|
591 |
+
"cell_type": "code",
|
592 |
+
"execution_count": null,
|
593 |
+
"metadata": {},
|
594 |
+
"outputs": [],
|
595 |
+
"source": [
|
596 |
+
"# slices_set_df = pd.concat([high_entropy_examples,low_entropy_examples],axis=0)\n",
|
597 |
+
"slices_set_df = pd.concat([high_z_examples,low_z_examples],axis=0).sort_values(\"w_bl_z_score\",ascending=False)\n",
|
598 |
+
"slices_set_df"
|
599 |
+
]
|
600 |
+
},
|
601 |
+
{
|
602 |
+
"cell_type": "code",
|
603 |
+
"execution_count": null,
|
604 |
+
"metadata": {},
|
605 |
+
"outputs": [],
|
606 |
+
"source": [
|
607 |
+
"# slices_set_df.T.iloc[:,0:2]"
|
608 |
+
]
|
609 |
+
},
|
610 |
+
{
|
611 |
+
"cell_type": "code",
|
612 |
+
"execution_count": null,
|
613 |
+
"metadata": {},
|
614 |
+
"outputs": [],
|
615 |
+
"source": [
|
616 |
+
"# print(slices_set_df.to_latex(index=False))\n",
|
617 |
+
"# print(low_entropy_examples.to_latex(index=False))\n",
|
618 |
+
"# print(mid_entropy_examples.to_latex(index=False))\n",
|
619 |
+
"# print(high_entropy_examples.to_latex(index=False))"
|
620 |
+
]
|
621 |
+
},
|
622 |
+
{
|
623 |
+
"cell_type": "code",
|
624 |
+
"execution_count": null,
|
625 |
+
"metadata": {},
|
626 |
+
"outputs": [],
|
627 |
+
"source": [
|
628 |
+
"# for c,t in zip(low_entropy_examples.columns,low_entropy_examples.dtypes):\n",
|
629 |
+
"# if t==object:\n",
|
630 |
+
"# low_entropy_examples[c] = low_entropy_examples[c].apply(lambda s: f\"{s[:100]}[...truncated]\")"
|
631 |
+
]
|
632 |
+
},
|
633 |
+
{
|
634 |
+
"cell_type": "code",
|
635 |
+
"execution_count": null,
|
636 |
+
"metadata": {},
|
637 |
+
"outputs": [],
|
638 |
+
"source": [
|
639 |
+
"# low_entropy_examples.T.to_latex(buf=open(\"figs/low_ent_examples.txt\", \"w\"),index=False)"
|
640 |
+
]
|
641 |
+
},
|
642 |
+
{
|
643 |
+
"cell_type": "code",
|
644 |
+
"execution_count": null,
|
645 |
+
"metadata": {},
|
646 |
+
"outputs": [],
|
647 |
+
"source": [
|
648 |
+
"# df_to_write = high_entropy_examples\n",
|
649 |
+
"# df_to_write = mid_entropy_examples\n",
|
650 |
+
"# df_to_write = low_entropy_examples\n",
|
651 |
+
"# df_to_write = high_z_examples\n",
|
652 |
+
"# df_to_write = mid_z_examples\n",
|
653 |
+
"# df_to_write = low_z_examples\n",
|
654 |
+
"\n",
|
655 |
+
"cols_to_drop = [\"idx\", \"avg_spike_entropy\", \"no_bl_z_score\"] #, \"no_bl_ppl\", \"w_bl_ppl\"]\n",
|
656 |
+
"df_to_write = slices_set_df.drop(cols_to_drop,axis=1)\n",
|
657 |
+
"\n",
|
658 |
+
"\n",
|
659 |
+
"with pd.option_context(\"max_colwidth\", 1000):\n",
|
660 |
+
" column_format=\"\".join([(r'p{3cm}|' if t==object else r'p{0.4cm}|') for c,t in zip(df_to_write.columns,df_to_write.dtypes)])[:-1]\n",
|
661 |
+
" # low_entropy_examples.round(2).to_latex(buf=open(\"figs/low_ent_examples.txt\", \"w\"),column_format=column_format,index=False)\n",
|
662 |
+
" latex_str = df_to_write.round(2).to_latex(column_format=column_format,index=False)\n",
|
663 |
+
"\n",
|
664 |
+
"print(latex_str)"
|
665 |
+
]
|
666 |
+
},
|
667 |
+
{
|
668 |
+
"cell_type": "code",
|
669 |
+
"execution_count": null,
|
670 |
+
"metadata": {},
|
671 |
+
"outputs": [],
|
672 |
+
"source": []
|
673 |
+
},
|
674 |
+
{
|
675 |
+
"cell_type": "code",
|
676 |
+
"execution_count": null,
|
677 |
+
"metadata": {},
|
678 |
+
"outputs": [],
|
679 |
+
"source": []
|
680 |
+
},
|
681 |
+
{
|
682 |
+
"cell_type": "code",
|
683 |
+
"execution_count": null,
|
684 |
+
"metadata": {},
|
685 |
+
"outputs": [],
|
686 |
+
"source": [
|
687 |
+
"# column_format=\"\".join([r'p{2cm}|' for c in low_entropy_examples.columns])\n",
|
688 |
+
"# column_format"
|
689 |
+
]
|
690 |
+
},
|
691 |
+
{
|
692 |
+
"cell_type": "code",
|
693 |
+
"execution_count": null,
|
694 |
+
"metadata": {},
|
695 |
+
"outputs": [],
|
696 |
+
"source": [
|
697 |
+
"# low_entropy_examples.dtypes"
|
698 |
+
]
|
699 |
+
},
|
700 |
+
{
|
701 |
+
"cell_type": "code",
|
702 |
+
"execution_count": null,
|
703 |
+
"metadata": {},
|
704 |
+
"outputs": [],
|
705 |
+
"source": [
|
706 |
+
"with pd.option_context(\"max_colwidth\", 1000):\n",
|
707 |
+
" print(grouped_df.get_group((True, 1, 2.0, 0.9)).head(10)[\"w_bl_output\"])"
|
708 |
+
]
|
709 |
+
},
|
710 |
+
{
|
711 |
+
"cell_type": "code",
|
712 |
+
"execution_count": null,
|
713 |
+
"metadata": {},
|
714 |
+
"outputs": [],
|
715 |
+
"source": []
|
716 |
+
},
|
717 |
+
{
|
718 |
+
"cell_type": "code",
|
719 |
+
"execution_count": null,
|
720 |
+
"metadata": {},
|
721 |
+
"outputs": [],
|
722 |
+
"source": []
|
723 |
+
},
|
724 |
+
{
|
725 |
+
"attachments": {},
|
726 |
+
"cell_type": "markdown",
|
727 |
+
"metadata": {},
|
728 |
+
"source": [
|
729 |
+
"### Set up data for charts"
|
730 |
+
]
|
731 |
+
},
|
732 |
+
{
|
733 |
+
"cell_type": "code",
|
734 |
+
"execution_count": null,
|
735 |
+
"metadata": {},
|
736 |
+
"outputs": [],
|
737 |
+
"source": [
|
738 |
+
"# viz_df = pd.DataFrame()\n",
|
739 |
+
"\n",
|
740 |
+
"# # set the hparam keys, including an indiv column for each you want to ablate on\n",
|
741 |
+
"# viz_df[\"bl_hparams\"] = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe().index.to_list()\n",
|
742 |
+
"# for i,key in enumerate(groupby_fields):\n",
|
743 |
+
"# viz_df[key] = viz_df[\"bl_hparams\"].apply(lambda tup: tup[i])\n",
|
744 |
+
"\n",
|
745 |
+
"# describe_dict = grouped_df[\"w_bl_whitelist_fraction\"].describe()\n",
|
746 |
+
"# viz_df[\"w_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
747 |
+
"# viz_df[\"w_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
|
748 |
+
"\n",
|
749 |
+
"# describe_dict = grouped_df[\"no_bl_whitelist_fraction\"].describe()\n",
|
750 |
+
"# viz_df[\"no_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
751 |
+
"# viz_df[\"no_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
|
752 |
+
"\n",
|
753 |
+
"\n",
|
754 |
+
"# describe_dict = grouped_df[\"w_bl_z_score\"].describe()\n",
|
755 |
+
"# viz_df[\"w_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
756 |
+
"# viz_df[\"w_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
|
757 |
+
"\n",
|
758 |
+
"# describe_dict = grouped_df[\"no_bl_z_score\"].describe()\n",
|
759 |
+
"# viz_df[\"no_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
760 |
+
"# viz_df[\"no_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
|
761 |
+
"\n",
|
762 |
+
"\n",
|
763 |
+
"# describe_dict = grouped_df[\"w_bl_ppl\"].describe()\n",
|
764 |
+
"# viz_df[\"w_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
765 |
+
"# viz_df[\"w_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
|
766 |
+
"\n",
|
767 |
+
"# describe_dict = grouped_df[\"no_bl_ppl\"].describe()\n",
|
768 |
+
"# viz_df[\"no_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
769 |
+
"# viz_df[\"no_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
|
770 |
+
"\n",
|
771 |
+
"# describe_dict = grouped_df[\"avg_spike_entropy\"].describe()\n",
|
772 |
+
"# viz_df[\"avg_spike_entropy_mean\"] = describe_dict[\"mean\"].to_list()\n",
|
773 |
+
"# viz_df[\"avg_spike_entropy_std\"] = describe_dict[\"std\"].to_list()\n",
|
774 |
+
"\n",
|
775 |
+
"# print(f\"groupby legend: {groupby_fields}\")\n"
|
776 |
+
]
|
777 |
+
},
|
778 |
+
{
|
779 |
+
"cell_type": "code",
|
780 |
+
"execution_count": null,
|
781 |
+
"metadata": {},
|
782 |
+
"outputs": [],
|
783 |
+
"source": [
|
784 |
+
"# # filtering\n",
|
785 |
+
"\n",
|
786 |
+
"# viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == True))] # sampling\n",
|
787 |
+
"\n",
|
788 |
+
"# # viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False))] # greedy\n",
|
789 |
+
"\n",
|
790 |
+
"\n",
|
791 |
+
"# # fix one of the bl params for analytic chart\n",
|
792 |
+
"# viz_df = viz_df[(viz_df[\"gamma\"]==0.5) & (viz_df[\"delta\"]<=10.0)]\n",
|
793 |
+
"\n",
|
794 |
+
"# # viz_df = viz_df[(viz_df[\"delta\"] > 0.5) & (viz_df[\"delta\"]<=10.0)]\n",
|
795 |
+
"\n",
|
796 |
+
"# # viz_df = viz_df[(viz_df[\"delta\"]==0.5) | (viz_df[\"delta\"]==2.0) | (viz_df[\"delta\"]==10.0)]\n",
|
797 |
+
"\n",
|
798 |
+
"# # viz_df = viz_df[(viz_df[\"delta\"]!=0.1)&(viz_df[\"delta\"]!=0.5)&(viz_df[\"delta\"]!=50.0)]\n",
|
799 |
+
"\n",
|
800 |
+
"# # viz_df = viz_df[(viz_df[\"delta\"]!=50.0)]\n",
|
801 |
+
"# # viz_df = viz_df[(viz_df[\"delta\"]!=50.0) & (viz_df[\"num_beams\"]!=1)]\n",
|
802 |
+
"\n",
|
803 |
+
"# print(len(viz_df))\n",
|
804 |
+
"\n",
|
805 |
+
"# viz_df"
|
806 |
+
]
|
807 |
+
},
|
808 |
+
{
|
809 |
+
"cell_type": "markdown",
|
810 |
+
"metadata": {},
|
811 |
+
"source": [
|
812 |
+
"# Visualize the WL/BL hits via highlighting in html"
|
813 |
+
]
|
814 |
+
},
|
815 |
+
{
|
816 |
+
"cell_type": "code",
|
817 |
+
"execution_count": null,
|
818 |
+
"metadata": {},
|
819 |
+
"outputs": [],
|
820 |
+
"source": [
|
821 |
+
"# idx = 75\n",
|
822 |
+
"# # idx = 62\n",
|
823 |
+
"\n",
|
824 |
+
"# # debug\n",
|
825 |
+
"# # idx = 7\n",
|
826 |
+
"# # idx = 18\n",
|
827 |
+
"# # idx = 231\n",
|
828 |
+
"\n",
|
829 |
+
"# print(gen_table_w_bl_stats[idx])\n",
|
830 |
+
"# print(f\"\\nPrompt:\",gen_table_w_bl_stats[idx][\"truncated_input\"])\n",
|
831 |
+
"# print(f\"\\nBaseline (real text):{gen_table_w_bl_stats[idx]['baseline_completion']}\")\n",
|
832 |
+
"# print(f\"\\nNo Blacklist:{gen_table_w_bl_stats[idx]['no_bl_output']}\")\n",
|
833 |
+
"# print(f\"\\nw/ Blacklist:{gen_table_w_bl_stats[idx]['w_bl_output']}\")"
|
834 |
+
]
|
835 |
+
},
|
836 |
+
{
|
837 |
+
"cell_type": "code",
|
838 |
+
"execution_count": null,
|
839 |
+
"metadata": {},
|
840 |
+
"outputs": [],
|
841 |
+
"source": [
|
842 |
+
"# from ipymarkup import show_span_box_markup, get_span_box_markup\n",
|
843 |
+
"# from ipymarkup.palette import palette, RED, GREEN, BLUE\n",
|
844 |
+
"\n",
|
845 |
+
"# from IPython.display import display, HTML\n",
|
846 |
+
"\n",
|
847 |
+
"# from transformers import GPT2TokenizerFast\n",
|
848 |
+
"# # fast_tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n",
|
849 |
+
"# fast_tokenizer = GPT2TokenizerFast.from_pretrained(\"facebook/opt-2.7b\")"
|
850 |
+
]
|
851 |
+
},
|
852 |
+
{
|
853 |
+
"cell_type": "code",
|
854 |
+
"execution_count": null,
|
855 |
+
"metadata": {},
|
856 |
+
"outputs": [],
|
857 |
+
"source": [
|
858 |
+
"# %autoreload\n",
|
859 |
+
"\n",
|
860 |
+
"# vis_bl = partial(\n",
|
861 |
+
"# compute_bl_metrics,\n",
|
862 |
+
"# tokenizer=fast_tokenizer,\n",
|
863 |
+
"# hf_model_name=gen_table_meta[\"model_name\"],\n",
|
864 |
+
"# initial_seed=gen_table_meta[\"initial_seed\"],\n",
|
865 |
+
"# dynamic_seed=gen_table_meta[\"dynamic_seed\"],\n",
|
866 |
+
"# bl_proportion=gen_table_meta[\"bl_proportion\"],\n",
|
867 |
+
"# record_hits = True,\n",
|
868 |
+
"# use_cuda=True, # this is obvi critical to match the pseudorandomness\n",
|
869 |
+
"# )"
|
870 |
+
]
|
871 |
+
},
|
872 |
+
{
|
873 |
+
"cell_type": "code",
|
874 |
+
"execution_count": null,
|
875 |
+
"metadata": {},
|
876 |
+
"outputs": [],
|
877 |
+
"source": [
|
878 |
+
"# stats = vis_bl(gen_table_w_bl_stats[idx], 0)\n",
|
879 |
+
"\n",
|
880 |
+
"# baseline_hit_list = stats[\"baseline_hit_list\"]\n",
|
881 |
+
"# no_bl_hit_list = stats[\"no_bl_hit_list\"]\n",
|
882 |
+
"# w_bl_hit_list = stats[\"w_bl_hit_list\"]"
|
883 |
+
]
|
884 |
+
},
|
885 |
+
{
|
886 |
+
"cell_type": "code",
|
887 |
+
"execution_count": null,
|
888 |
+
"metadata": {},
|
889 |
+
"outputs": [],
|
890 |
+
"source": [
|
891 |
+
"# text = stats[\"truncated_input\"]\n",
|
892 |
+
"# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
|
893 |
+
"# hit_list = baseline_hit_list\n",
|
894 |
+
"\n",
|
895 |
+
"# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
|
896 |
+
"# charspans = [cs for cs in charspans if cs is not None]\n",
|
897 |
+
"# # spans = [(cs.start,cs.end, \"PR\") for i,cs in enumerate(charspans)]\n",
|
898 |
+
"# spans = []\n",
|
899 |
+
"\n",
|
900 |
+
"# html = get_span_box_markup(text, spans, palette=palette(PR=BLUE), background='white', text_color=\"black\")\n",
|
901 |
+
"\n",
|
902 |
+
"\n",
|
903 |
+
"# with open(\"figs/prompt_html.html\", \"w\") as f:\n",
|
904 |
+
"# f.write(HTML(html).data)\n",
|
905 |
+
"\n",
|
906 |
+
"# HTML(html)"
|
907 |
+
]
|
908 |
+
},
|
909 |
+
{
|
910 |
+
"cell_type": "code",
|
911 |
+
"execution_count": null,
|
912 |
+
"metadata": {},
|
913 |
+
"outputs": [],
|
914 |
+
"source": [
|
915 |
+
"# text = stats[\"baseline_completion\"]\n",
|
916 |
+
"# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
|
917 |
+
"# hit_list = baseline_hit_list\n",
|
918 |
+
"\n",
|
919 |
+
"# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
|
920 |
+
"# charspans = [cs for cs in charspans if cs is not None]\n",
|
921 |
+
"# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n",
|
922 |
+
"\n",
|
923 |
+
"# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n",
|
924 |
+
"\n",
|
925 |
+
"\n",
|
926 |
+
"# with open(\"figs/baseline_html.html\", \"w\") as f:\n",
|
927 |
+
"# f.write(HTML(html).data)\n",
|
928 |
+
"\n",
|
929 |
+
"# HTML(html)\n"
|
930 |
+
]
|
931 |
+
},
|
932 |
+
{
|
933 |
+
"cell_type": "code",
|
934 |
+
"execution_count": null,
|
935 |
+
"metadata": {},
|
936 |
+
"outputs": [],
|
937 |
+
"source": [
|
938 |
+
"\n",
|
939 |
+
"# text = stats[\"no_bl_output\"]\n",
|
940 |
+
"# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
|
941 |
+
"# hit_list = no_bl_hit_list\n",
|
942 |
+
"\n",
|
943 |
+
"# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
|
944 |
+
"# charspans = [cs for cs in charspans if cs is not None]\n",
|
945 |
+
"# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n",
|
946 |
+
"\n",
|
947 |
+
"# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n",
|
948 |
+
"\n",
|
949 |
+
"\n",
|
950 |
+
"# with open(\"figs/no_bl_html.html\", \"w\") as f:\n",
|
951 |
+
"# f.write(HTML(html).data)\n",
|
952 |
+
"\n",
|
953 |
+
"# HTML(html)\n"
|
954 |
+
]
|
955 |
+
},
|
956 |
+
{
|
957 |
+
"cell_type": "code",
|
958 |
+
"execution_count": null,
|
959 |
+
"metadata": {},
|
960 |
+
"outputs": [],
|
961 |
+
"source": [
|
962 |
+
"\n",
|
963 |
+
"# text = stats[\"w_bl_output\"]\n",
|
964 |
+
"# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
|
965 |
+
"# hit_list = w_bl_hit_list\n",
|
966 |
+
"\n",
|
967 |
+
"# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
|
968 |
+
"# charspans = [cs for cs in charspans if cs is not None]\n",
|
969 |
+
"# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n",
|
970 |
+
"\n",
|
971 |
+
"# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n",
|
972 |
+
"\n",
|
973 |
+
"\n",
|
974 |
+
"# with open(\"figs/w_bl_html.html\", \"w\") as f:\n",
|
975 |
+
"# f.write(HTML(html).data)\n",
|
976 |
+
"\n",
|
977 |
+
"# HTML(html)"
|
978 |
+
]
|
979 |
+
}
|
980 |
+
],
|
981 |
+
"metadata": {
|
982 |
+
"kernelspec": {
|
983 |
+
"display_name": "Python 3",
|
984 |
+
"language": "python",
|
985 |
+
"name": "python3"
|
986 |
+
},
|
987 |
+
"language_info": {
|
988 |
+
"codemirror_mode": {
|
989 |
+
"name": "ipython",
|
990 |
+
"version": 3
|
991 |
+
},
|
992 |
+
"file_extension": ".py",
|
993 |
+
"mimetype": "text/x-python",
|
994 |
+
"name": "python",
|
995 |
+
"nbconvert_exporter": "python",
|
996 |
+
"pygments_lexer": "ipython3",
|
997 |
+
"version": "3.10.6"
|
998 |
+
},
|
999 |
+
"vscode": {
|
1000 |
+
"interpreter": {
|
1001 |
+
"hash": "365524a309ad80022da286f2ec5d2060ce5cb229abb6076cf68d9a1ab14bd8fe"
|
1002 |
+
}
|
1003 |
+
}
|
1004 |
+
},
|
1005 |
+
"nbformat": 4,
|
1006 |
+
"nbformat_minor": 4
|
1007 |
+
}
|
lm-watermarking-main/extended_watermark_processor.py
ADDED
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Authors of "A Watermark for Large Language Models"
|
3 |
+
# available at https://arxiv.org/abs/2301.10226
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
from __future__ import annotations
|
18 |
+
import collections
|
19 |
+
from math import sqrt
|
20 |
+
from itertools import chain, tee
|
21 |
+
from functools import lru_cache
|
22 |
+
|
23 |
+
import scipy.stats
|
24 |
+
import torch
|
25 |
+
from tokenizers import Tokenizer
|
26 |
+
from transformers import LogitsProcessor
|
27 |
+
|
28 |
+
from normalizers import normalization_strategy_lookup
|
29 |
+
from alternative_prf_schemes import prf_lookup, seeding_scheme_lookup
|
30 |
+
|
31 |
+
|
32 |
+
class WatermarkBase:
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
vocab: list[int] = None,
|
36 |
+
gamma: float = 0.25,
|
37 |
+
delta: float = 2.0,
|
38 |
+
seeding_scheme: str = "selfhash", # simple default, find more schemes in alternative_prf_schemes.py
|
39 |
+
select_green_tokens: bool = True, # should always be the default if not running in legacy mode
|
40 |
+
):
|
41 |
+
# patch now that None could now maybe be passed as seeding_scheme
|
42 |
+
if seeding_scheme is None:
|
43 |
+
seeding_scheme = "selfhash"
|
44 |
+
|
45 |
+
# Vocabulary setup
|
46 |
+
self.vocab = vocab
|
47 |
+
self.vocab_size = len(vocab)
|
48 |
+
|
49 |
+
# Watermark behavior:
|
50 |
+
self.gamma = gamma
|
51 |
+
self.delta = delta
|
52 |
+
self.rng = None
|
53 |
+
self._initialize_seeding_scheme(seeding_scheme)
|
54 |
+
# Legacy behavior:
|
55 |
+
self.select_green_tokens = select_green_tokens
|
56 |
+
|
57 |
+
def _initialize_seeding_scheme(self, seeding_scheme: str) -> None:
|
58 |
+
"""Initialize all internal settings of the seeding strategy from a colloquial, "public" name for the scheme."""
|
59 |
+
self.prf_type, self.context_width, self.self_salt, self.hash_key = seeding_scheme_lookup(seeding_scheme)
|
60 |
+
|
61 |
+
def _seed_rng(self, input_ids: torch.LongTensor) -> None:
|
62 |
+
"""Seed RNG from local context. Not batched, because the generators we use (like cuda.random) are not batched."""
|
63 |
+
# Need to have enough context for seed generation
|
64 |
+
if input_ids.shape[-1] < self.context_width:
|
65 |
+
raise ValueError(f"seeding_scheme requires at least a {self.context_width} token prefix to seed the RNG.")
|
66 |
+
|
67 |
+
prf_key = prf_lookup[self.prf_type](input_ids[-self.context_width :], salt_key=self.hash_key)
|
68 |
+
# enable for long, interesting streams of pseudorandom numbers: print(prf_key)
|
69 |
+
self.rng.manual_seed(prf_key % (2**64 - 1)) # safeguard against overflow from long
|
70 |
+
|
71 |
+
def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
|
72 |
+
"""Seed rng based on local context width and use this information to generate ids on the green list."""
|
73 |
+
self._seed_rng(input_ids)
|
74 |
+
|
75 |
+
greenlist_size = int(self.vocab_size * self.gamma)
|
76 |
+
vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
|
77 |
+
if self.select_green_tokens: # directly
|
78 |
+
greenlist_ids = vocab_permutation[:greenlist_size] # new
|
79 |
+
else: # select green via red
|
80 |
+
greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :] # legacy behavior
|
81 |
+
return greenlist_ids
|
82 |
+
|
83 |
+
|
84 |
+
class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
|
85 |
+
"""LogitsProcessor modifying model output scores in a pipe. Can be used in any HF pipeline to modify scores to fit the watermark,
|
86 |
+
but can also be used as a standalone tool inserted for any model producing scores inbetween model outputs and next token sampler.
|
87 |
+
"""
|
88 |
+
|
89 |
+
def __init__(self, *args, store_spike_ents: bool = False, **kwargs):
|
90 |
+
super().__init__(*args, **kwargs)
|
91 |
+
|
92 |
+
self.store_spike_ents = store_spike_ents
|
93 |
+
self.spike_entropies = None
|
94 |
+
if self.store_spike_ents:
|
95 |
+
self._init_spike_entropies()
|
96 |
+
|
97 |
+
def _init_spike_entropies(self):
|
98 |
+
alpha = torch.exp(torch.tensor(self.delta)).item()
|
99 |
+
gamma = self.gamma
|
100 |
+
|
101 |
+
self.z_value = ((1 - gamma) * (alpha - 1)) / (1 - gamma + (alpha * gamma))
|
102 |
+
self.expected_gl_coef = (gamma * alpha) / (1 - gamma + (alpha * gamma))
|
103 |
+
|
104 |
+
# catch for overflow when bias is "infinite"
|
105 |
+
if alpha == torch.inf:
|
106 |
+
self.z_value = 1.0
|
107 |
+
self.expected_gl_coef = 1.0
|
108 |
+
|
109 |
+
def _get_spike_entropies(self):
|
110 |
+
spike_ents = [[] for _ in range(len(self.spike_entropies))]
|
111 |
+
for b_idx, ent_tensor_list in enumerate(self.spike_entropies):
|
112 |
+
for ent_tensor in ent_tensor_list:
|
113 |
+
spike_ents[b_idx].append(ent_tensor.item())
|
114 |
+
return spike_ents
|
115 |
+
|
116 |
+
def _get_and_clear_stored_spike_ents(self):
|
117 |
+
spike_ents = self._get_spike_entropies()
|
118 |
+
self.spike_entropies = None
|
119 |
+
return spike_ents
|
120 |
+
|
121 |
+
def _compute_spike_entropy(self, scores):
|
122 |
+
# precomputed z value in init
|
123 |
+
probs = scores.softmax(dim=-1)
|
124 |
+
denoms = 1 + (self.z_value * probs)
|
125 |
+
renormed_probs = probs / denoms
|
126 |
+
sum_renormed_probs = renormed_probs.sum()
|
127 |
+
return sum_renormed_probs
|
128 |
+
|
129 |
+
def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
|
130 |
+
# Cannot lose loop, greenlists might have different lengths
|
131 |
+
green_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
|
132 |
+
for b_idx, greenlist in enumerate(greenlist_token_ids):
|
133 |
+
if len(greenlist) > 0:
|
134 |
+
green_tokens_mask[b_idx][greenlist] = True
|
135 |
+
return green_tokens_mask
|
136 |
+
|
137 |
+
def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
|
138 |
+
scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
|
139 |
+
return scores
|
140 |
+
|
141 |
+
def _score_rejection_sampling(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, tail_rule="fixed_compute") -> list[int]:
|
142 |
+
"""Generate greenlist based on current candidate next token. Reject and move on if necessary. Method not batched.
|
143 |
+
This is only a partial version of Alg.3 "Robust Private Watermarking", as it always assumes greedy sampling. It will still (kinda)
|
144 |
+
work for all types of sampling, but less effectively.
|
145 |
+
To work efficiently, this function can switch between a number of rules for handling the distribution tail.
|
146 |
+
These are not exposed by default.
|
147 |
+
"""
|
148 |
+
sorted_scores, greedy_predictions = scores.sort(dim=-1, descending=True)
|
149 |
+
|
150 |
+
final_greenlist = []
|
151 |
+
for idx, prediction_candidate in enumerate(greedy_predictions):
|
152 |
+
greenlist_ids = self._get_greenlist_ids(torch.cat([input_ids, prediction_candidate[None]], dim=0)) # add candidate to prefix
|
153 |
+
if prediction_candidate in greenlist_ids: # test for consistency
|
154 |
+
final_greenlist.append(prediction_candidate)
|
155 |
+
|
156 |
+
# What follows below are optional early-stopping rules for efficiency
|
157 |
+
if tail_rule == "fixed_score":
|
158 |
+
if sorted_scores[0] - sorted_scores[idx + 1] > self.delta:
|
159 |
+
break
|
160 |
+
elif tail_rule == "fixed_list_length":
|
161 |
+
if len(final_greenlist) == 10:
|
162 |
+
break
|
163 |
+
elif tail_rule == "fixed_compute":
|
164 |
+
if idx == 40:
|
165 |
+
break
|
166 |
+
else:
|
167 |
+
pass # do not break early
|
168 |
+
return torch.as_tensor(final_greenlist, device=input_ids.device)
|
169 |
+
|
170 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
171 |
+
"""Call with previous context as input_ids, and scores for next token."""
|
172 |
+
|
173 |
+
# this is lazy to allow us to co-locate on the watermarked model's device
|
174 |
+
self.rng = torch.Generator(device=input_ids.device) if self.rng is None else self.rng
|
175 |
+
|
176 |
+
# NOTE, it would be nice to get rid of this batch loop, but currently,
|
177 |
+
# the seed and partition operations are not tensor/vectorized, thus
|
178 |
+
# each sequence in the batch needs to be treated separately.
|
179 |
+
|
180 |
+
list_of_greenlist_ids = [None for _ in input_ids] # Greenlists could differ in length
|
181 |
+
for b_idx, input_seq in enumerate(input_ids):
|
182 |
+
if self.self_salt:
|
183 |
+
greenlist_ids = self._score_rejection_sampling(input_seq, scores[b_idx])
|
184 |
+
else:
|
185 |
+
greenlist_ids = self._get_greenlist_ids(input_seq)
|
186 |
+
list_of_greenlist_ids[b_idx] = greenlist_ids
|
187 |
+
|
188 |
+
# logic for computing and storing spike entropies for analysis
|
189 |
+
if self.store_spike_ents:
|
190 |
+
if self.spike_entropies is None:
|
191 |
+
self.spike_entropies = [[] for _ in range(input_ids.shape[0])]
|
192 |
+
self.spike_entropies[b_idx].append(self._compute_spike_entropy(scores[b_idx]))
|
193 |
+
|
194 |
+
green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=list_of_greenlist_ids)
|
195 |
+
scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
|
196 |
+
|
197 |
+
return scores
|
198 |
+
|
199 |
+
|
200 |
+
class WatermarkDetector(WatermarkBase):
|
201 |
+
"""This is the detector for all watermarks imprinted with WatermarkLogitsProcessor.
|
202 |
+
|
203 |
+
The detector needs to be given the exact same settings that were given during text generation to replicate the watermark
|
204 |
+
greenlist generation and so detect the watermark.
|
205 |
+
This includes the correct device that was used during text generation, the correct tokenizer, the correct
|
206 |
+
seeding_scheme name, and parameters (delta, gamma).
|
207 |
+
|
208 |
+
Optional arguments are
|
209 |
+
* normalizers ["unicode", "homoglyphs", "truecase"] -> These can mitigate modifications to generated text that could trip the watermark
|
210 |
+
* ignore_repeated_ngrams -> This option changes the detection rules to count every unique ngram only once.
|
211 |
+
* z_threshold -> Changing this threshold will change the sensitivity of the detector.
|
212 |
+
"""
|
213 |
+
|
214 |
+
def __init__(
|
215 |
+
self,
|
216 |
+
*args,
|
217 |
+
device: torch.device = None,
|
218 |
+
tokenizer: Tokenizer = None,
|
219 |
+
z_threshold: float = 4.0,
|
220 |
+
normalizers: list[str] = ["unicode"], # or also: ["unicode", "homoglyphs", "truecase"]
|
221 |
+
ignore_repeated_ngrams: bool = True,
|
222 |
+
**kwargs,
|
223 |
+
):
|
224 |
+
super().__init__(*args, **kwargs)
|
225 |
+
# also configure the metrics returned/preprocessing options
|
226 |
+
assert device, "Must pass device"
|
227 |
+
assert tokenizer, "Need an instance of the generating tokenizer to perform detection"
|
228 |
+
|
229 |
+
self.tokenizer = tokenizer
|
230 |
+
self.device = device
|
231 |
+
self.z_threshold = z_threshold
|
232 |
+
self.rng = torch.Generator(device=self.device)
|
233 |
+
|
234 |
+
self.normalizers = []
|
235 |
+
for normalization_strategy in normalizers:
|
236 |
+
self.normalizers.append(normalization_strategy_lookup(normalization_strategy))
|
237 |
+
self.ignore_repeated_ngrams = ignore_repeated_ngrams
|
238 |
+
|
239 |
+
def dummy_detect(
|
240 |
+
self,
|
241 |
+
return_prediction: bool = True,
|
242 |
+
return_scores: bool = True,
|
243 |
+
z_threshold: float = None,
|
244 |
+
return_num_tokens_scored: bool = True,
|
245 |
+
return_num_green_tokens: bool = True,
|
246 |
+
return_green_fraction: bool = True,
|
247 |
+
return_green_token_mask: bool = False,
|
248 |
+
return_all_window_scores: bool = False,
|
249 |
+
return_z_score: bool = True,
|
250 |
+
return_z_at_T: bool = True,
|
251 |
+
return_p_value: bool = True,
|
252 |
+
):
|
253 |
+
# HF-style output dictionary
|
254 |
+
score_dict = dict()
|
255 |
+
if return_num_tokens_scored:
|
256 |
+
score_dict.update(dict(num_tokens_scored=float("nan")))
|
257 |
+
if return_num_green_tokens:
|
258 |
+
score_dict.update(dict(num_green_tokens=float("nan")))
|
259 |
+
if return_green_fraction:
|
260 |
+
score_dict.update(dict(green_fraction=float("nan")))
|
261 |
+
if return_z_score:
|
262 |
+
score_dict.update(dict(z_score=float("nan")))
|
263 |
+
if return_p_value:
|
264 |
+
z_score = score_dict.get("z_score")
|
265 |
+
if z_score is None:
|
266 |
+
z_score = float("nan")
|
267 |
+
score_dict.update(dict(p_value=float("nan")))
|
268 |
+
if return_green_token_mask:
|
269 |
+
score_dict.update(dict(green_token_mask=[]))
|
270 |
+
if return_all_window_scores:
|
271 |
+
score_dict.update(dict(window_list=[]))
|
272 |
+
if return_z_at_T:
|
273 |
+
score_dict.update(dict(z_score_at_T=torch.tensor([])))
|
274 |
+
|
275 |
+
output_dict = {}
|
276 |
+
if return_scores:
|
277 |
+
output_dict.update(score_dict)
|
278 |
+
# if passed return_prediction then perform the hypothesis test and return the outcome
|
279 |
+
if return_prediction:
|
280 |
+
z_threshold = z_threshold if z_threshold else self.z_threshold
|
281 |
+
assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
|
282 |
+
output_dict["prediction"] = False
|
283 |
+
|
284 |
+
return output_dict
|
285 |
+
|
286 |
+
def _compute_z_score(self, observed_count, T):
|
287 |
+
# count refers to number of green tokens, T is total number of tokens
|
288 |
+
expected_count = self.gamma
|
289 |
+
numer = observed_count - expected_count * T
|
290 |
+
denom = sqrt(T * expected_count * (1 - expected_count))
|
291 |
+
z = numer / denom
|
292 |
+
return z
|
293 |
+
|
294 |
+
def _compute_p_value(self, z):
|
295 |
+
p_value = scipy.stats.norm.sf(z)
|
296 |
+
return p_value
|
297 |
+
|
298 |
+
@lru_cache(maxsize=2**32)
|
299 |
+
def _get_ngram_score_cached(self, prefix: tuple[int], target: int):
|
300 |
+
"""Expensive re-seeding and sampling is cached."""
|
301 |
+
# Handle with care, should ideally reset on __getattribute__ access to self.prf_type, self.context_width, self.self_salt, self.hash_key
|
302 |
+
greenlist_ids = self._get_greenlist_ids(torch.as_tensor(prefix, device=self.device))
|
303 |
+
return True if target in greenlist_ids else False
|
304 |
+
|
305 |
+
def _score_ngrams_in_passage(self, input_ids: torch.Tensor):
|
306 |
+
"""Core function to gather all ngrams in the input and compute their watermark."""
|
307 |
+
if len(input_ids) - self.context_width < 1:
|
308 |
+
raise ValueError(
|
309 |
+
f"Must have at least {1} token to score after "
|
310 |
+
f"the first min_prefix_len={self.context_width} tokens required by the seeding scheme."
|
311 |
+
)
|
312 |
+
|
313 |
+
# Compute scores for all ngrams contexts in the passage:
|
314 |
+
token_ngram_generator = ngrams(input_ids.cpu().tolist(), self.context_width + 1 - self.self_salt)
|
315 |
+
frequencies_table = collections.Counter(token_ngram_generator)
|
316 |
+
ngram_to_watermark_lookup = {}
|
317 |
+
for idx, ngram_example in enumerate(frequencies_table.keys()):
|
318 |
+
prefix = ngram_example if self.self_salt else ngram_example[:-1]
|
319 |
+
target = ngram_example[-1]
|
320 |
+
ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target)
|
321 |
+
|
322 |
+
return ngram_to_watermark_lookup, frequencies_table
|
323 |
+
|
324 |
+
def _get_green_at_T_booleans(self, input_ids, ngram_to_watermark_lookup) -> tuple[torch.Tensor]:
|
325 |
+
"""Generate binary list of green vs. red per token, a separate list that ignores repeated ngrams, and a list of offsets to
|
326 |
+
convert between both representations:
|
327 |
+
green_token_mask = green_token_mask_unique[offsets] except for all locations where otherwise a repeat would be counted
|
328 |
+
"""
|
329 |
+
green_token_mask, green_token_mask_unique, offsets = [], [], []
|
330 |
+
used_ngrams = {}
|
331 |
+
unique_ngram_idx = 0
|
332 |
+
ngram_examples = ngrams(input_ids.cpu().tolist(), self.context_width + 1 - self.self_salt)
|
333 |
+
|
334 |
+
for idx, ngram_example in enumerate(ngram_examples):
|
335 |
+
green_token_mask.append(ngram_to_watermark_lookup[ngram_example])
|
336 |
+
if self.ignore_repeated_ngrams:
|
337 |
+
if ngram_example in used_ngrams:
|
338 |
+
pass
|
339 |
+
else:
|
340 |
+
used_ngrams[ngram_example] = True
|
341 |
+
unique_ngram_idx += 1
|
342 |
+
green_token_mask_unique.append(ngram_to_watermark_lookup[ngram_example])
|
343 |
+
else:
|
344 |
+
green_token_mask_unique.append(ngram_to_watermark_lookup[ngram_example])
|
345 |
+
unique_ngram_idx += 1
|
346 |
+
offsets.append(unique_ngram_idx - 1)
|
347 |
+
return (
|
348 |
+
torch.tensor(green_token_mask),
|
349 |
+
torch.tensor(green_token_mask_unique),
|
350 |
+
torch.tensor(offsets),
|
351 |
+
)
|
352 |
+
|
353 |
+
def _score_sequence(
|
354 |
+
self,
|
355 |
+
input_ids: torch.Tensor,
|
356 |
+
return_num_tokens_scored: bool = True,
|
357 |
+
return_num_green_tokens: bool = True,
|
358 |
+
return_green_fraction: bool = True,
|
359 |
+
return_green_token_mask: bool = False,
|
360 |
+
return_z_score: bool = True,
|
361 |
+
return_z_at_T: bool = True,
|
362 |
+
return_p_value: bool = True,
|
363 |
+
):
|
364 |
+
ngram_to_watermark_lookup, frequencies_table = self._score_ngrams_in_passage(input_ids)
|
365 |
+
green_token_mask, green_unique, offsets = self._get_green_at_T_booleans(input_ids, ngram_to_watermark_lookup)
|
366 |
+
|
367 |
+
# Count up scores over all ngrams
|
368 |
+
if self.ignore_repeated_ngrams:
|
369 |
+
# Method that only counts a green/red hit once per unique ngram.
|
370 |
+
# New num total tokens scored (T) becomes the number unique ngrams.
|
371 |
+
# We iterate over all unqiue token ngrams in the input, computing the greenlist
|
372 |
+
# induced by the context in each, and then checking whether the last
|
373 |
+
# token falls in that greenlist.
|
374 |
+
num_tokens_scored = len(frequencies_table.keys())
|
375 |
+
green_token_count = sum(ngram_to_watermark_lookup.values())
|
376 |
+
else:
|
377 |
+
num_tokens_scored = sum(frequencies_table.values())
|
378 |
+
assert num_tokens_scored == len(input_ids) - self.context_width + self.self_salt
|
379 |
+
green_token_count = sum(freq * outcome for freq, outcome in zip(frequencies_table.values(), ngram_to_watermark_lookup.values()))
|
380 |
+
assert green_token_count == green_unique.sum()
|
381 |
+
|
382 |
+
# HF-style output dictionary
|
383 |
+
score_dict = dict()
|
384 |
+
if return_num_tokens_scored:
|
385 |
+
score_dict.update(dict(num_tokens_scored=num_tokens_scored))
|
386 |
+
if return_num_green_tokens:
|
387 |
+
score_dict.update(dict(num_green_tokens=green_token_count))
|
388 |
+
if return_green_fraction:
|
389 |
+
score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
|
390 |
+
if return_z_score:
|
391 |
+
score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
|
392 |
+
if return_p_value:
|
393 |
+
z_score = score_dict.get("z_score")
|
394 |
+
if z_score is None:
|
395 |
+
z_score = self._compute_z_score(green_token_count, num_tokens_scored)
|
396 |
+
score_dict.update(dict(p_value=self._compute_p_value(z_score)))
|
397 |
+
if return_green_token_mask:
|
398 |
+
score_dict.update(dict(green_token_mask=green_token_mask.tolist()))
|
399 |
+
if return_z_at_T:
|
400 |
+
# Score z_at_T separately:
|
401 |
+
sizes = torch.arange(1, len(green_unique) + 1)
|
402 |
+
seq_z_score_enum = torch.cumsum(green_unique, dim=0) - self.gamma * sizes
|
403 |
+
seq_z_score_denom = torch.sqrt(sizes * self.gamma * (1 - self.gamma))
|
404 |
+
z_score_at_effective_T = seq_z_score_enum / seq_z_score_denom
|
405 |
+
z_score_at_T = z_score_at_effective_T[offsets]
|
406 |
+
assert torch.isclose(z_score_at_T[-1], torch.tensor(z_score))
|
407 |
+
|
408 |
+
score_dict.update(dict(z_score_at_T=z_score_at_T))
|
409 |
+
|
410 |
+
return score_dict
|
411 |
+
|
412 |
+
def _score_windows_impl_batched(
|
413 |
+
self,
|
414 |
+
input_ids: torch.Tensor,
|
415 |
+
window_size: str,
|
416 |
+
window_stride: int = 1,
|
417 |
+
):
|
418 |
+
# Implementation details:
|
419 |
+
# 1) --ignore_repeated_ngrams is applied globally, and windowing is then applied over the reduced binary vector
|
420 |
+
# this is only one way of doing it, another would be to ignore bigrams within each window (maybe harder to parallelize that)
|
421 |
+
# 2) These windows on the binary vector of green/red hits, independent of context_width, in contrast to Kezhi's first implementation
|
422 |
+
# 3) z-scores from this implementation cannot be directly converted to p-values, and should only be used as labels for a
|
423 |
+
# ROC chart that calibrates to a chosen FPR. Due, to windowing, the multiple hypotheses will increase scores across the board#
|
424 |
+
# naive_count_correction=True is a partial remedy to this
|
425 |
+
|
426 |
+
ngram_to_watermark_lookup, frequencies_table = self._score_ngrams_in_passage(input_ids)
|
427 |
+
green_mask, green_ids, offsets = self._get_green_at_T_booleans(input_ids, ngram_to_watermark_lookup)
|
428 |
+
len_full_context = len(green_ids)
|
429 |
+
|
430 |
+
partial_sum_id_table = torch.cumsum(green_ids, dim=0)
|
431 |
+
|
432 |
+
if window_size == "max":
|
433 |
+
# could start later, small window sizes cannot generate enough power
|
434 |
+
# more principled: solve (T * Spike_Entropy - g * T) / sqrt(T * g * (1 - g)) = z_thresh for T
|
435 |
+
sizes = range(1, len_full_context)
|
436 |
+
else:
|
437 |
+
sizes = [int(x) for x in window_size.split(",") if len(x) > 0]
|
438 |
+
|
439 |
+
z_score_max_per_window = torch.zeros(len(sizes))
|
440 |
+
cumulative_eff_z_score = torch.zeros(len_full_context)
|
441 |
+
s = window_stride
|
442 |
+
|
443 |
+
window_fits = False
|
444 |
+
for idx, size in enumerate(sizes):
|
445 |
+
if size <= len_full_context:
|
446 |
+
# Compute hits within window for all positions in parallel:
|
447 |
+
window_score = torch.zeros(len_full_context - size + 1, dtype=torch.long)
|
448 |
+
# Include 0-th window
|
449 |
+
window_score[0] = partial_sum_id_table[size - 1]
|
450 |
+
# All other windows from the 1st:
|
451 |
+
window_score[1:] = partial_sum_id_table[size::s] - partial_sum_id_table[:-size:s]
|
452 |
+
|
453 |
+
# Now compute batched z_scores
|
454 |
+
batched_z_score_enum = window_score - self.gamma * size
|
455 |
+
z_score_denom = sqrt(size * self.gamma * (1 - self.gamma))
|
456 |
+
batched_z_score = batched_z_score_enum / z_score_denom
|
457 |
+
|
458 |
+
# And find the maximal hit
|
459 |
+
maximal_z_score = batched_z_score.max()
|
460 |
+
z_score_max_per_window[idx] = maximal_z_score
|
461 |
+
|
462 |
+
z_score_at_effective_T = torch.cummax(batched_z_score, dim=0)[0]
|
463 |
+
cumulative_eff_z_score[size::s] = torch.maximum(cumulative_eff_z_score[size::s], z_score_at_effective_T[:-1])
|
464 |
+
window_fits = True # successful computation for any window in sizes
|
465 |
+
|
466 |
+
if not window_fits:
|
467 |
+
raise ValueError(
|
468 |
+
f"Could not find a fitting window with window sizes {window_size} for (effective) context length {len_full_context}."
|
469 |
+
)
|
470 |
+
|
471 |
+
# Compute optimal window size and z-score
|
472 |
+
cumulative_z_score = cumulative_eff_z_score[offsets]
|
473 |
+
optimal_z, optimal_window_size_idx = z_score_max_per_window.max(dim=0)
|
474 |
+
optimal_window_size = sizes[optimal_window_size_idx]
|
475 |
+
return (
|
476 |
+
optimal_z,
|
477 |
+
optimal_window_size,
|
478 |
+
z_score_max_per_window,
|
479 |
+
cumulative_z_score,
|
480 |
+
green_mask,
|
481 |
+
)
|
482 |
+
|
483 |
+
def _score_sequence_window(
|
484 |
+
self,
|
485 |
+
input_ids: torch.Tensor,
|
486 |
+
return_num_tokens_scored: bool = True,
|
487 |
+
return_num_green_tokens: bool = True,
|
488 |
+
return_green_fraction: bool = True,
|
489 |
+
return_green_token_mask: bool = False,
|
490 |
+
return_z_score: bool = True,
|
491 |
+
return_z_at_T: bool = True,
|
492 |
+
return_p_value: bool = True,
|
493 |
+
window_size: str = None,
|
494 |
+
window_stride: int = 1,
|
495 |
+
):
|
496 |
+
(
|
497 |
+
optimal_z,
|
498 |
+
optimal_window_size,
|
499 |
+
_,
|
500 |
+
z_score_at_T,
|
501 |
+
green_mask,
|
502 |
+
) = self._score_windows_impl_batched(input_ids, window_size, window_stride)
|
503 |
+
|
504 |
+
# HF-style output dictionary
|
505 |
+
score_dict = dict()
|
506 |
+
if return_num_tokens_scored:
|
507 |
+
score_dict.update(dict(num_tokens_scored=optimal_window_size))
|
508 |
+
|
509 |
+
denom = sqrt(optimal_window_size * self.gamma * (1 - self.gamma))
|
510 |
+
green_token_count = int(optimal_z * denom + self.gamma * optimal_window_size)
|
511 |
+
green_fraction = green_token_count / optimal_window_size
|
512 |
+
if return_num_green_tokens:
|
513 |
+
score_dict.update(dict(num_green_tokens=green_token_count))
|
514 |
+
if return_green_fraction:
|
515 |
+
score_dict.update(dict(green_fraction=green_fraction))
|
516 |
+
if return_z_score:
|
517 |
+
score_dict.update(dict(z_score=optimal_z))
|
518 |
+
if return_z_at_T:
|
519 |
+
score_dict.update(dict(z_score_at_T=z_score_at_T))
|
520 |
+
if return_p_value:
|
521 |
+
z_score = score_dict.get("z_score", optimal_z)
|
522 |
+
score_dict.update(dict(p_value=self._compute_p_value(z_score)))
|
523 |
+
|
524 |
+
# Return per-token results for mask. This is still the same, just scored by windows
|
525 |
+
# todo would be to mark the actually counted tokens differently
|
526 |
+
if return_green_token_mask:
|
527 |
+
score_dict.update(dict(green_token_mask=green_mask.tolist()))
|
528 |
+
|
529 |
+
return score_dict
|
530 |
+
|
531 |
+
def detect(
|
532 |
+
self,
|
533 |
+
text: str = None,
|
534 |
+
tokenized_text: list[int] = None,
|
535 |
+
window_size: str = None,
|
536 |
+
window_stride: int = None,
|
537 |
+
return_prediction: bool = True,
|
538 |
+
return_scores: bool = True,
|
539 |
+
z_threshold: float = None,
|
540 |
+
convert_to_float: bool = False,
|
541 |
+
**kwargs,
|
542 |
+
) -> dict:
|
543 |
+
"""Scores a given string of text and returns a dictionary of results."""
|
544 |
+
|
545 |
+
assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
|
546 |
+
if return_prediction:
|
547 |
+
kwargs["return_p_value"] = True # to return the "confidence":=1-p of positive detections
|
548 |
+
|
549 |
+
# run optional normalizers on text
|
550 |
+
for normalizer in self.normalizers:
|
551 |
+
text = normalizer(text)
|
552 |
+
if len(self.normalizers) > 0:
|
553 |
+
print(f"Text after normalization:\n\n{text}\n")
|
554 |
+
|
555 |
+
if tokenized_text is None:
|
556 |
+
assert self.tokenizer is not None, (
|
557 |
+
"Watermark detection on raw string ",
|
558 |
+
"requires an instance of the tokenizer ",
|
559 |
+
"that was used at generation time.",
|
560 |
+
)
|
561 |
+
tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
|
562 |
+
if tokenized_text[0] == self.tokenizer.bos_token_id:
|
563 |
+
tokenized_text = tokenized_text[1:]
|
564 |
+
else:
|
565 |
+
# try to remove the bos_tok at beginning if it's there
|
566 |
+
if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
|
567 |
+
tokenized_text = tokenized_text[1:]
|
568 |
+
|
569 |
+
# call score method
|
570 |
+
output_dict = {}
|
571 |
+
|
572 |
+
if window_size is not None:
|
573 |
+
# assert window_size <= len(tokenized_text) cannot assert for all new types
|
574 |
+
score_dict = self._score_sequence_window(
|
575 |
+
tokenized_text,
|
576 |
+
window_size=window_size,
|
577 |
+
window_stride=window_stride,
|
578 |
+
**kwargs,
|
579 |
+
)
|
580 |
+
output_dict.update(score_dict)
|
581 |
+
else:
|
582 |
+
score_dict = self._score_sequence(tokenized_text, **kwargs)
|
583 |
+
if return_scores:
|
584 |
+
output_dict.update(score_dict)
|
585 |
+
# if passed return_prediction then perform the hypothesis test and return the outcome
|
586 |
+
if return_prediction:
|
587 |
+
z_threshold = z_threshold if z_threshold else self.z_threshold
|
588 |
+
assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
|
589 |
+
output_dict["prediction"] = score_dict["z_score"] > z_threshold
|
590 |
+
if output_dict["prediction"]:
|
591 |
+
output_dict["confidence"] = 1 - score_dict["p_value"]
|
592 |
+
|
593 |
+
# convert any numerical values to float if requested
|
594 |
+
if convert_to_float:
|
595 |
+
for key, value in output_dict.items():
|
596 |
+
if isinstance(value, int):
|
597 |
+
output_dict[key] = float(value)
|
598 |
+
|
599 |
+
return output_dict
|
600 |
+
|
601 |
+
|
602 |
+
##########################################################################
|
603 |
+
# Ngram iteration from nltk, extracted to remove the dependency
|
604 |
+
# Natural Language Toolkit: Utility functions
|
605 |
+
#
|
606 |
+
# Copyright (C) 2001-2023 NLTK Project
|
607 |
+
# Author: Steven Bird <[email protected]>
|
608 |
+
# Eric Kafe <[email protected]> (acyclic closures)
|
609 |
+
# URL: <https://www.nltk.org/>
|
610 |
+
# For license information, see https://github.com/nltk/nltk/blob/develop/LICENSE.txt
|
611 |
+
##########################################################################
|
612 |
+
|
613 |
+
|
614 |
+
def ngrams(sequence, n, pad_left=False, pad_right=False, pad_symbol=None):
|
615 |
+
sequence = iter(sequence)
|
616 |
+
if pad_left:
|
617 |
+
sequence = chain((pad_symbol,) * (n - 1), sequence)
|
618 |
+
if pad_right:
|
619 |
+
sequence = chain(sequence, (pad_symbol,) * (n - 1))
|
620 |
+
iterables = tee(sequence, n)
|
621 |
+
|
622 |
+
for i, sub_iterable in enumerate(iterables): # For each window,
|
623 |
+
for _ in range(i): # iterate through every order of ngrams
|
624 |
+
next(sub_iterable, None) # generate the ngrams within the window.
|
625 |
+
return zip(*iterables) # Unpack and flattens the iterables.
|
lm-watermarking-main/homoglyph_data/__init__.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is data for homoglyph finding
|
2 |
+
|
3 |
+
"""Original package info:
|
4 |
+
|
5 |
+
Homoglyphs
|
6 |
+
* Get similar letters
|
7 |
+
* Convert string to ASCII letters
|
8 |
+
* Detect possible letter languages
|
9 |
+
* Detect letter UTF-8 group.
|
10 |
+
|
11 |
+
# main package info
|
12 |
+
__title__ = 'Homoglyphs'
|
13 |
+
__version__ = '2.0.4'
|
14 |
+
__author__ = 'Gram Orsinium'
|
15 |
+
__license__ = 'MIT'
|
16 |
+
|
17 |
+
# License:
|
18 |
+
|
19 |
+
MIT License 2019 orsinium <[email protected]>
|
20 |
+
|
21 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
22 |
+
of this software and associated documentation files (the "Software"), to deal
|
23 |
+
in the Software without restriction, including without limitation the rights
|
24 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
25 |
+
copies of the Software, and to permit persons to whom the Software is
|
26 |
+
furnished to do so, subject to the following conditions:
|
27 |
+
|
28 |
+
The above copyright notice and this permission notice (including the next
|
29 |
+
paragraph) shall be included in all copies or substantial portions of the
|
30 |
+
Software.
|
31 |
+
|
32 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
33 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
34 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
35 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
36 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
37 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
38 |
+
SOFTWARE.
|
39 |
+
|
40 |
+
"""
|
lm-watermarking-main/homoglyph_data/categories.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
lm-watermarking-main/homoglyph_data/confusables_sept2022.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
lm-watermarking-main/homoglyph_data/languages.json
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"ar": "ءآأؤإئابةتثجحخدذرزسشصضطظعغػؼؽؾؿـفقكلمنهوىيًٌٍَُِّ",
|
3 |
+
"be": "ʼЁІЎАБВГДЕЖЗЙКЛМНОПРСТУФХЦЧШЫЬЭЮЯабвгдежзйклмнопрстуфхцчшыьэюяёіў",
|
4 |
+
"bg": "АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЬЮЯабвгдежзийклмнопрстуфхцчшщъьюя",
|
5 |
+
"ca": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÈÉÍÏÒÓÚÜÇàèéíïòóúüç·",
|
6 |
+
"cz": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÉÍÓÚÝáéíóúýČčĎďĚěŇňŘřŠšŤťŮůŽž",
|
7 |
+
"da": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÅÆØåæø",
|
8 |
+
"de": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÄÖÜßäöü",
|
9 |
+
"el": "ΪΫΆΈΉΊΌΎΏΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩΐΰϊϋάέήίαβγδεζηθικλμνξοπρςστυφχψωόύώ",
|
10 |
+
"en": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
|
11 |
+
"eo": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĈĉĜĝĤĥĴĵŜŝŬŭ",
|
12 |
+
"es": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÉÍÑÓÚÜáéíñóúü",
|
13 |
+
"et": "ABDEGHIJKLMNOPRSTUVabdeghijklmnoprstuvÄÕÖÜäõöü",
|
14 |
+
"fi": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÄÅÖäåöŠšŽž",
|
15 |
+
"fr": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÂÇÈÉÊÎÏÙÛàâçèéêîïùûŒœ",
|
16 |
+
"he": "אבגדהוזחטיךכלםמןנסעףפץצקרשתװױײ",
|
17 |
+
"hr": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĆćČčĐ𩹮ž",
|
18 |
+
"hu": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzÁÉÍÓÖÚÜáéíóöúüŐőŰű",
|
19 |
+
"it": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÈÉÌÒÓÙàèéìòóù",
|
20 |
+
"lt": "ABCDEFGHIJKLMNOPRSTUVYZabcdefghijklmnoprstuvyzĄąČčĖėĘęĮįŠšŪūŲųŽž",
|
21 |
+
"lv": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĀāČčĒēĢģĪīĶķĻļŅņŠšŪūŽž",
|
22 |
+
"mk": "ЃЅЈЉЊЌЏАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШабвгдежзиклмнопрстуфхцчшѓѕјљњќџ",
|
23 |
+
"nl": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
|
24 |
+
"pl": "ABCDEFGHIJKLMNOPRSTUWYZabcdefghijklmnoprstuwyzÓóĄąĆćĘꣳŃńŚśŹźŻż",
|
25 |
+
"pt": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÁÂÃÇÉÊÍÓÔÕÚàáâãçéêíóôõú",
|
26 |
+
"ro": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÂÎâîĂăȘșȚț",
|
27 |
+
"ru": "ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё",
|
28 |
+
"sk": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÄÉÍÓÔÚÝáäéíóôúýČčĎďĹ弾ŇňŔ੹ŤťŽž",
|
29 |
+
"sl": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzČ芚Žž",
|
30 |
+
"sr": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzЂЈЉЊЋЏАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШабвгдежзиклмнопрстуфхцчшђјљњћџ",
|
31 |
+
"th": "กขฃคฅฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลฦวศษสหฬอฮฯะัาำิีึืฺุู฿เแโใไๅๆ็่้๊๋์ํ๎๏๐๑๒๓๔๕๖๗๘๙๚๛",
|
32 |
+
"tr": "ABCDEFGHIJKLMNOPRSTUVYZabcdefghijklmnoprstuvyzÂÇÎÖÛÜâçîöûüĞğİıŞş",
|
33 |
+
"vi": "ABCDEGHIKLMNOPQRSTUVXYabcdeghiklmnopqrstuvxyÂÊÔâêôĂăĐđƠơƯư"
|
34 |
+
}
|
lm-watermarking-main/homoglyphs.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Updated version of core.py from
|
2 |
+
https://github.com/yamatt/homoglyphs/tree/main/homoglyphs_fork
|
3 |
+
for modern python3
|
4 |
+
"""
|
5 |
+
|
6 |
+
from collections import defaultdict
|
7 |
+
import json
|
8 |
+
from itertools import product
|
9 |
+
import os
|
10 |
+
import unicodedata
|
11 |
+
|
12 |
+
# Actions if char not in alphabet
|
13 |
+
STRATEGY_LOAD = 1 # load category for this char
|
14 |
+
STRATEGY_IGNORE = 2 # add char to result
|
15 |
+
STRATEGY_REMOVE = 3 # remove char from result
|
16 |
+
|
17 |
+
ASCII_RANGE = range(128)
|
18 |
+
|
19 |
+
|
20 |
+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
21 |
+
DATA_LOCATION = os.path.join(CURRENT_DIR, "homoglyph_data")
|
22 |
+
|
23 |
+
|
24 |
+
class Categories:
|
25 |
+
"""
|
26 |
+
Work with aliases from ISO 15924.
|
27 |
+
https://en.wikipedia.org/wiki/ISO_15924#List_of_codes
|
28 |
+
"""
|
29 |
+
|
30 |
+
fpath = os.path.join(DATA_LOCATION, "categories.json")
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def _get_ranges(cls, categories):
|
34 |
+
"""
|
35 |
+
:return: iter: (start code, end code)
|
36 |
+
:rtype: list
|
37 |
+
"""
|
38 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
39 |
+
data = json.load(f)
|
40 |
+
|
41 |
+
for category in categories:
|
42 |
+
if category not in data["aliases"]:
|
43 |
+
raise ValueError("Invalid category: {}".format(category))
|
44 |
+
|
45 |
+
for point in data["points"]:
|
46 |
+
if point[2] in categories:
|
47 |
+
yield point[:2]
|
48 |
+
|
49 |
+
@classmethod
|
50 |
+
def get_alphabet(cls, categories):
|
51 |
+
"""
|
52 |
+
:return: set of chars in alphabet by categories list
|
53 |
+
:rtype: set
|
54 |
+
"""
|
55 |
+
alphabet = set()
|
56 |
+
for start, end in cls._get_ranges(categories):
|
57 |
+
chars = (chr(code) for code in range(start, end + 1))
|
58 |
+
alphabet.update(chars)
|
59 |
+
return alphabet
|
60 |
+
|
61 |
+
@classmethod
|
62 |
+
def detect(cls, char):
|
63 |
+
"""
|
64 |
+
:return: category
|
65 |
+
:rtype: str
|
66 |
+
"""
|
67 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
68 |
+
data = json.load(f)
|
69 |
+
|
70 |
+
# try detect category by unicodedata
|
71 |
+
try:
|
72 |
+
category = unicodedata.name(char).split()[0]
|
73 |
+
except (TypeError, ValueError):
|
74 |
+
# In Python2 unicodedata.name raise error for non-unicode chars
|
75 |
+
# Python3 raise ValueError for non-unicode characters
|
76 |
+
pass
|
77 |
+
else:
|
78 |
+
if category in data["aliases"]:
|
79 |
+
return category
|
80 |
+
|
81 |
+
# try detect category by ranges from JSON file.
|
82 |
+
code = ord(char)
|
83 |
+
for point in data["points"]:
|
84 |
+
if point[0] <= code <= point[1]:
|
85 |
+
return point[2]
|
86 |
+
|
87 |
+
@classmethod
|
88 |
+
def get_all(cls):
|
89 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
90 |
+
data = json.load(f)
|
91 |
+
return set(data["aliases"])
|
92 |
+
|
93 |
+
|
94 |
+
class Languages:
|
95 |
+
fpath = os.path.join(DATA_LOCATION, "languages.json")
|
96 |
+
|
97 |
+
@classmethod
|
98 |
+
def get_alphabet(cls, languages):
|
99 |
+
"""
|
100 |
+
:return: set of chars in alphabet by languages list
|
101 |
+
:rtype: set
|
102 |
+
"""
|
103 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
104 |
+
data = json.load(f)
|
105 |
+
alphabet = set()
|
106 |
+
for lang in languages:
|
107 |
+
if lang not in data:
|
108 |
+
raise ValueError("Invalid language code: {}".format(lang))
|
109 |
+
alphabet.update(data[lang])
|
110 |
+
return alphabet
|
111 |
+
|
112 |
+
@classmethod
|
113 |
+
def detect(cls, char):
|
114 |
+
"""
|
115 |
+
:return: set of languages which alphabet contains passed char.
|
116 |
+
:rtype: set
|
117 |
+
"""
|
118 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
119 |
+
data = json.load(f)
|
120 |
+
languages = set()
|
121 |
+
for lang, alphabet in data.items():
|
122 |
+
if char in alphabet:
|
123 |
+
languages.add(lang)
|
124 |
+
return languages
|
125 |
+
|
126 |
+
@classmethod
|
127 |
+
def get_all(cls):
|
128 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
129 |
+
data = json.load(f)
|
130 |
+
return set(data.keys())
|
131 |
+
|
132 |
+
|
133 |
+
class Homoglyphs:
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
categories=None,
|
137 |
+
languages=None,
|
138 |
+
alphabet=None,
|
139 |
+
strategy=STRATEGY_IGNORE,
|
140 |
+
ascii_strategy=STRATEGY_IGNORE,
|
141 |
+
ascii_range=ASCII_RANGE,
|
142 |
+
):
|
143 |
+
# strategies
|
144 |
+
if strategy not in (STRATEGY_LOAD, STRATEGY_IGNORE, STRATEGY_REMOVE):
|
145 |
+
raise ValueError("Invalid strategy")
|
146 |
+
self.strategy = strategy
|
147 |
+
self.ascii_strategy = ascii_strategy
|
148 |
+
self.ascii_range = ascii_range
|
149 |
+
|
150 |
+
# Homoglyphs must be initialized by any alphabet for correct work
|
151 |
+
if not categories and not languages and not alphabet:
|
152 |
+
categories = ("LATIN", "COMMON")
|
153 |
+
|
154 |
+
# cats and langs
|
155 |
+
self.categories = set(categories or [])
|
156 |
+
self.languages = set(languages or [])
|
157 |
+
|
158 |
+
# alphabet
|
159 |
+
self.alphabet = set(alphabet or [])
|
160 |
+
if self.categories:
|
161 |
+
alphabet = Categories.get_alphabet(self.categories)
|
162 |
+
self.alphabet.update(alphabet)
|
163 |
+
if self.languages:
|
164 |
+
alphabet = Languages.get_alphabet(self.languages)
|
165 |
+
self.alphabet.update(alphabet)
|
166 |
+
self.table = self.get_table(self.alphabet)
|
167 |
+
|
168 |
+
@staticmethod
|
169 |
+
def get_table(alphabet):
|
170 |
+
table = defaultdict(set)
|
171 |
+
with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
|
172 |
+
data = json.load(f)
|
173 |
+
for char in alphabet:
|
174 |
+
if char in data:
|
175 |
+
for homoglyph in data[char]:
|
176 |
+
if homoglyph in alphabet:
|
177 |
+
table[char].add(homoglyph)
|
178 |
+
return table
|
179 |
+
|
180 |
+
@staticmethod
|
181 |
+
def get_restricted_table(source_alphabet, target_alphabet):
|
182 |
+
table = defaultdict(set)
|
183 |
+
with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
|
184 |
+
data = json.load(f)
|
185 |
+
for char in source_alphabet:
|
186 |
+
if char in data:
|
187 |
+
for homoglyph in data[char]:
|
188 |
+
if homoglyph in target_alphabet:
|
189 |
+
table[char].add(homoglyph)
|
190 |
+
return table
|
191 |
+
|
192 |
+
@staticmethod
|
193 |
+
def uniq_and_sort(data):
|
194 |
+
result = list(set(data))
|
195 |
+
result.sort(key=lambda x: (-len(x), x))
|
196 |
+
return result
|
197 |
+
|
198 |
+
def _update_alphabet(self, char):
|
199 |
+
# try detect languages
|
200 |
+
langs = Languages.detect(char)
|
201 |
+
if langs:
|
202 |
+
self.languages.update(langs)
|
203 |
+
alphabet = Languages.get_alphabet(langs)
|
204 |
+
self.alphabet.update(alphabet)
|
205 |
+
else:
|
206 |
+
# try detect categories
|
207 |
+
category = Categories.detect(char)
|
208 |
+
if category is None:
|
209 |
+
return False
|
210 |
+
self.categories.add(category)
|
211 |
+
alphabet = Categories.get_alphabet([category])
|
212 |
+
self.alphabet.update(alphabet)
|
213 |
+
# update table for new alphabet
|
214 |
+
self.table = self.get_table(self.alphabet)
|
215 |
+
return True
|
216 |
+
|
217 |
+
def _get_char_variants(self, char):
|
218 |
+
if char not in self.alphabet:
|
219 |
+
if self.strategy == STRATEGY_LOAD:
|
220 |
+
if not self._update_alphabet(char):
|
221 |
+
return []
|
222 |
+
elif self.strategy == STRATEGY_IGNORE:
|
223 |
+
return [char]
|
224 |
+
elif self.strategy == STRATEGY_REMOVE:
|
225 |
+
return []
|
226 |
+
|
227 |
+
# find alternative chars for current char
|
228 |
+
alt_chars = self.table.get(char, set())
|
229 |
+
if alt_chars:
|
230 |
+
# find alternative chars for alternative chars for current char
|
231 |
+
alt_chars2 = [self.table.get(alt_char, set()) for alt_char in alt_chars]
|
232 |
+
# combine all alternatives
|
233 |
+
alt_chars.update(*alt_chars2)
|
234 |
+
# add current char to alternatives
|
235 |
+
alt_chars.add(char)
|
236 |
+
|
237 |
+
# uniq, sort and return
|
238 |
+
return self.uniq_and_sort(alt_chars)
|
239 |
+
|
240 |
+
def _get_combinations(self, text, ascii=False):
|
241 |
+
variations = []
|
242 |
+
for char in text:
|
243 |
+
alt_chars = self._get_char_variants(char)
|
244 |
+
|
245 |
+
if ascii:
|
246 |
+
alt_chars = [char for char in alt_chars if ord(char) in self.ascii_range]
|
247 |
+
if not alt_chars and self.ascii_strategy == STRATEGY_IGNORE:
|
248 |
+
return
|
249 |
+
|
250 |
+
if alt_chars:
|
251 |
+
variations.append(alt_chars)
|
252 |
+
if variations:
|
253 |
+
for variant in product(*variations):
|
254 |
+
yield "".join(variant)
|
255 |
+
|
256 |
+
def get_combinations(self, text):
|
257 |
+
return list(self._get_combinations(text))
|
258 |
+
|
259 |
+
def _to_ascii(self, text):
|
260 |
+
for variant in self._get_combinations(text, ascii=True):
|
261 |
+
if max(map(ord, variant)) in self.ascii_range:
|
262 |
+
yield variant
|
263 |
+
|
264 |
+
def to_ascii(self, text):
|
265 |
+
return self.uniq_and_sort(self._to_ascii(text))
|
lm-watermarking-main/normalizers.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Text-based normalizers, used to mitigate simple attacks against watermarking.
|
2 |
+
|
3 |
+
This implementation is unlikely to be a complete list of all possible exploits within the unicode standard,
|
4 |
+
it represents our best effort at the time of writing.
|
5 |
+
|
6 |
+
These normalizers can be used as stand-alone normalizers. They could be made to conform to HF tokenizers standard, but that would
|
7 |
+
require messing with the limited rust interface of tokenizers.NormalizedString
|
8 |
+
"""
|
9 |
+
from collections import defaultdict
|
10 |
+
from functools import cache
|
11 |
+
|
12 |
+
import re
|
13 |
+
import unicodedata
|
14 |
+
import homoglyphs as hg
|
15 |
+
|
16 |
+
|
17 |
+
def normalization_strategy_lookup(strategy_name: str) -> object:
|
18 |
+
if strategy_name == "unicode":
|
19 |
+
return UnicodeSanitizer()
|
20 |
+
elif strategy_name == "homoglyphs":
|
21 |
+
return HomoglyphCanonizer()
|
22 |
+
elif strategy_name == "truecase":
|
23 |
+
return TrueCaser()
|
24 |
+
|
25 |
+
|
26 |
+
class HomoglyphCanonizer:
|
27 |
+
"""Attempts to detect homoglyph attacks and find a consistent canon.
|
28 |
+
|
29 |
+
This function does so on a per-ISO-category level. Language-level would also be possible (see commented code).
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self):
|
33 |
+
self.homoglyphs = None
|
34 |
+
|
35 |
+
def __call__(self, homoglyphed_str: str) -> str:
|
36 |
+
# find canon:
|
37 |
+
target_category, all_categories = self._categorize_text(homoglyphed_str)
|
38 |
+
homoglyph_table = self._select_canon_category_and_load(target_category, all_categories)
|
39 |
+
return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str)
|
40 |
+
|
41 |
+
def _categorize_text(self, text: str) -> dict:
|
42 |
+
iso_categories = defaultdict(int)
|
43 |
+
# self.iso_languages = defaultdict(int)
|
44 |
+
|
45 |
+
for char in text:
|
46 |
+
iso_categories[hg.Categories.detect(char)] += 1
|
47 |
+
# for lang in hg.Languages.detect(char):
|
48 |
+
# self.iso_languages[lang] += 1
|
49 |
+
target_category = max(iso_categories, key=iso_categories.get)
|
50 |
+
all_categories = tuple(iso_categories)
|
51 |
+
return target_category, all_categories
|
52 |
+
|
53 |
+
@cache
|
54 |
+
def _select_canon_category_and_load(
|
55 |
+
self, target_category: str, all_categories: tuple[str]
|
56 |
+
) -> dict:
|
57 |
+
homoglyph_table = hg.Homoglyphs(
|
58 |
+
categories=(target_category, "COMMON")
|
59 |
+
) # alphabet loaded here from file
|
60 |
+
|
61 |
+
source_alphabet = hg.Categories.get_alphabet(all_categories)
|
62 |
+
restricted_table = homoglyph_table.get_restricted_table(
|
63 |
+
source_alphabet, homoglyph_table.alphabet
|
64 |
+
) # table loaded here from file
|
65 |
+
return restricted_table
|
66 |
+
|
67 |
+
def _sanitize_text(
|
68 |
+
self, target_category: str, homoglyph_table: dict, homoglyphed_str: str
|
69 |
+
) -> str:
|
70 |
+
sanitized_text = ""
|
71 |
+
for char in homoglyphed_str:
|
72 |
+
# langs = hg.Languages.detect(char)
|
73 |
+
cat = hg.Categories.detect(char)
|
74 |
+
if target_category in cat or "COMMON" in cat or len(cat) == 0:
|
75 |
+
sanitized_text += char
|
76 |
+
else:
|
77 |
+
sanitized_text += list(homoglyph_table[char])[0]
|
78 |
+
return sanitized_text
|
79 |
+
|
80 |
+
|
81 |
+
class UnicodeSanitizer:
|
82 |
+
"""Regex-based unicode sanitzer. Has different levels of granularity.
|
83 |
+
|
84 |
+
* ruleset="whitespaces" - attempts to remove only whitespace unicode characters
|
85 |
+
* ruleset="IDN.blacklist" - does its best to remove unusual unicode based on Network.IDN.blacklist characters
|
86 |
+
* ruleset="ascii" - brute-forces all text into ascii
|
87 |
+
|
88 |
+
This is unlikely to be a comprehensive list.
|
89 |
+
|
90 |
+
You can find a more comprehensive discussion at https://www.unicode.org/reports/tr36/
|
91 |
+
and https://www.unicode.org/faq/security.html
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(self, ruleset="whitespaces"):
|
95 |
+
if ruleset == "whitespaces":
|
96 |
+
"""Documentation:
|
97 |
+
\u00A0: Non-breaking space
|
98 |
+
\u1680: Ogham space mark
|
99 |
+
\u180E: Mongolian vowel separator
|
100 |
+
\u2000-\u200B: Various space characters, including en space, em space, thin space, hair space, zero-width space, and zero-width non-joiner
|
101 |
+
\u200C\u200D: Zero-width non-joiner and zero-width joiner
|
102 |
+
\u200E,\u200F: Left-to-right-mark, Right-to-left-mark
|
103 |
+
\u2060: Word joiner
|
104 |
+
\u2063: Invisible separator
|
105 |
+
\u202F: Narrow non-breaking space
|
106 |
+
\u205F: Medium mathematical space
|
107 |
+
\u3000: Ideographic space
|
108 |
+
\uFEFF: Zero-width non-breaking space
|
109 |
+
\uFFA0: Halfwidth hangul filler
|
110 |
+
\uFFF9\uFFFA\uFFFB: Interlinear annotation characters
|
111 |
+
\uFE00-\uFE0F: Variation selectors
|
112 |
+
\u202A-\u202F: Embedding characters
|
113 |
+
\u3164: Korean hangul filler.
|
114 |
+
|
115 |
+
Note that these characters are not always superfluous whitespace characters!
|
116 |
+
"""
|
117 |
+
|
118 |
+
self.pattern = re.compile(
|
119 |
+
r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB"
|
120 |
+
r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D"
|
121 |
+
r"\u202E\u202F]"
|
122 |
+
)
|
123 |
+
elif ruleset == "IDN.blacklist":
|
124 |
+
"""Documentation:
|
125 |
+
[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF]: Matches any whitespace characters in the Unicode character
|
126 |
+
set that are included in the IDN blacklist.
|
127 |
+
\uFFF9-\uFFFB: Matches characters that are not defined in Unicode but are used as language tags in various legacy encodings.
|
128 |
+
These characters are not allowed in domain names.
|
129 |
+
\uD800-\uDB7F: Matches the first part of a surrogate pair. Surrogate pairs are used to represent characters in the Unicode character
|
130 |
+
set that cannot be represented by a single 16-bit value. The first part of a surrogate pair is in the range U+D800 to U+DBFF,
|
131 |
+
and the second part is in the range U+DC00 to U+DFFF.
|
132 |
+
\uDB80-\uDBFF][\uDC00-\uDFFF]?: Matches the second part of a surrogate pair. The second part of a surrogate pair is in the range U+DC00
|
133 |
+
to U+DFFF, and is optional.
|
134 |
+
[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]: Matches certain invalid UTF-16 sequences which should not appear in IDNs.
|
135 |
+
"""
|
136 |
+
|
137 |
+
self.pattern = re.compile(
|
138 |
+
r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]"
|
139 |
+
r"[\uDC00-\uDFFF]?|[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]"
|
140 |
+
)
|
141 |
+
else:
|
142 |
+
"""Documentation:
|
143 |
+
This is a simple restriction to "no-unicode", using only ascii characters. Control characters are included.
|
144 |
+
"""
|
145 |
+
self.pattern = re.compile(r"[^\x00-\x7F]+")
|
146 |
+
|
147 |
+
def __call__(self, text: str) -> str:
|
148 |
+
text = unicodedata.normalize("NFC", text) # canon forms
|
149 |
+
text = self.pattern.sub(" ", text) # pattern match
|
150 |
+
text = re.sub(" +", " ", text) # collapse whitespaces
|
151 |
+
text = "".join(
|
152 |
+
c for c in text if unicodedata.category(c) != "Cc"
|
153 |
+
) # Remove any remaining non-printable characters
|
154 |
+
return text
|
155 |
+
|
156 |
+
|
157 |
+
class TrueCaser:
|
158 |
+
"""True-casing, is a capitalization normalization that returns text to its original capitalization.
|
159 |
+
|
160 |
+
This defends against attacks that wRIte TeXt lIkE spOngBoB.
|
161 |
+
|
162 |
+
Here, a simple POS-tagger is used.
|
163 |
+
"""
|
164 |
+
|
165 |
+
uppercase_pos = ["PROPN"] # Name POS tags that should be upper-cased
|
166 |
+
|
167 |
+
def __init__(self, backend="spacy"):
|
168 |
+
if backend == "spacy":
|
169 |
+
import spacy
|
170 |
+
|
171 |
+
self.nlp = spacy.load("en_core_web_sm")
|
172 |
+
self.normalize_fn = self._spacy_truecasing
|
173 |
+
else:
|
174 |
+
from nltk import pos_tag, word_tokenize # noqa
|
175 |
+
import nltk
|
176 |
+
|
177 |
+
nltk.download("punkt")
|
178 |
+
nltk.download("averaged_perceptron_tagger")
|
179 |
+
nltk.download("universal_tagset")
|
180 |
+
self.normalize_fn = self._nltk_truecasing
|
181 |
+
|
182 |
+
def __call__(self, random_capitalized_string: str) -> str:
|
183 |
+
truecased_str = self.normalize_fn(random_capitalized_string)
|
184 |
+
return truecased_str
|
185 |
+
|
186 |
+
def _spacy_truecasing(self, random_capitalized_string: str):
|
187 |
+
doc = self.nlp(random_capitalized_string.lower())
|
188 |
+
POS = self.uppercase_pos
|
189 |
+
truecased_str = "".join(
|
190 |
+
[
|
191 |
+
w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws
|
192 |
+
for w in doc
|
193 |
+
]
|
194 |
+
)
|
195 |
+
return truecased_str
|
196 |
+
|
197 |
+
def _nltk_truecasing(self, random_capitalized_string: str):
|
198 |
+
from nltk import pos_tag, word_tokenize
|
199 |
+
import nltk
|
200 |
+
|
201 |
+
nltk.download("punkt")
|
202 |
+
nltk.download("averaged_perceptron_tagger")
|
203 |
+
nltk.download("universal_tagset")
|
204 |
+
POS = ["NNP", "NNPS"]
|
205 |
+
|
206 |
+
tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower()))
|
207 |
+
truecased_str = " ".join([w.capitalize() if p in POS else w for (w, p) in tagged_text])
|
208 |
+
return truecased_str
|
lm-watermarking-main/pyproject.toml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["setuptools"]
|
3 |
+
build-backend = "setuptools.build_meta"
|
4 |
+
|
5 |
+
[tool.black]
|
6 |
+
line-length = 140
|
lm-watermarking-main/requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
nltk
|
3 |
+
scipy
|
4 |
+
torch
|
5 |
+
transformers
|
6 |
+
tokenizers
|
lm-watermarking-main/setup.cfg
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[metadata]
|
2 |
+
name = lm-watermarking
|
3 |
+
version = 0.1.0
|
4 |
+
author = Authors of 'A Watermark for Large Language Models'
|
5 |
+
author_email = [email protected]
|
6 |
+
url = https://github.com/jwkirchenbauer/lm-watermarking
|
7 |
+
description = Implementation of watermark algorithms for large language models.
|
8 |
+
long_description = file: README.md, LICENSE.md
|
9 |
+
long_description_content_type = text/markdown
|
10 |
+
license = Apache 2.0
|
11 |
+
license_file = LICENSE.md
|
12 |
+
platform = any
|
13 |
+
keywords = Machine Learning, NLP, Language Models, Watermark, Safety, Model Output Detection
|
14 |
+
classifiers =
|
15 |
+
Topic :: Security
|
16 |
+
License :: OSI Approved :: Apache 2.0
|
17 |
+
Operating System :: OS Independent
|
18 |
+
Programming Language :: Python
|
19 |
+
homepage = https://github.com/jwkirchenbauer/lm-watermarking
|
20 |
+
repository = https://github.com/jwkirchenbauer/lm-watermarking
|
21 |
+
documentation = https://arxiv.org/abs/2301.10226
|
22 |
+
|
23 |
+
[options]
|
24 |
+
zip_safe = False
|
25 |
+
include_package_data = True
|
26 |
+
python_requires = >= 3.9
|
27 |
+
packages = find:
|
28 |
+
|
29 |
+
setup_requires =
|
30 |
+
setuptools
|
31 |
+
|
32 |
+
install_requires =
|
33 |
+
nltk
|
34 |
+
scipy
|
35 |
+
torch
|
36 |
+
transformers
|
37 |
+
tokenizers
|
38 |
+
|
39 |
+
[tool.black]
|
40 |
+
line-length = 140
|
41 |
+
|
42 |
+
[check-manifest]
|
43 |
+
ignore =
|
44 |
+
.ipynb
|
45 |
+
.sh
|
46 |
+
|
47 |
+
#inspired by https://github.com/pytorch/pytorch/blob/master/.flake8
|
48 |
+
[flake8]
|
49 |
+
select = B,C,E,F,P,T4,W,B9
|
50 |
+
max-line-length = 140
|
51 |
+
extend-ignore = E203
|
52 |
+
|
53 |
+
ignore =
|
54 |
+
E203,E305,E402,E501,E721,E741,F821,F841,F999,W503,W504,C408,E302,W291,E303,
|
55 |
+
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
|
56 |
+
# to line this up with executable bit
|
57 |
+
EXE001,
|
58 |
+
# these ignores are from flake8-bugbear; please fix!
|
59 |
+
B007,B008,
|
60 |
+
# these ignores are from flake8-comprehensions; please fix!
|
61 |
+
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
|
62 |
+
#unignored: F403,F405,
|
63 |
+
D102,D103,D403 # for doc linting
|
64 |
+
|
65 |
+
exclude =
|
66 |
+
.git
|
67 |
+
__pycache__
|
68 |
+
log/*
|
lm-watermarking-main/watermark_processor.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Authors of "A Watermark for Large Language Models"
|
3 |
+
# available at https://arxiv.org/abs/2301.10226
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
from __future__ import annotations
|
18 |
+
import collections
|
19 |
+
from math import sqrt
|
20 |
+
|
21 |
+
import scipy.stats
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from torch import Tensor
|
25 |
+
from tokenizers import Tokenizer
|
26 |
+
from transformers import LogitsProcessor
|
27 |
+
|
28 |
+
from nltk.util import ngrams
|
29 |
+
|
30 |
+
from normalizers import normalization_strategy_lookup
|
31 |
+
|
32 |
+
|
33 |
+
class WatermarkBase:
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
vocab: list[int] = None,
|
37 |
+
gamma: float = 0.5,
|
38 |
+
delta: float = 2.0,
|
39 |
+
seeding_scheme: str = "simple_1", # mostly unused/always default
|
40 |
+
hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
|
41 |
+
select_green_tokens: bool = True,
|
42 |
+
):
|
43 |
+
|
44 |
+
# watermarking parameters
|
45 |
+
self.vocab = vocab
|
46 |
+
self.vocab_size = len(vocab)
|
47 |
+
self.gamma = gamma
|
48 |
+
self.delta = delta
|
49 |
+
self.seeding_scheme = seeding_scheme
|
50 |
+
self.rng = None
|
51 |
+
self.hash_key = hash_key
|
52 |
+
self.select_green_tokens = select_green_tokens
|
53 |
+
|
54 |
+
def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
|
55 |
+
# can optionally override the seeding scheme,
|
56 |
+
# but uses the instance attr by default
|
57 |
+
if seeding_scheme is None:
|
58 |
+
seeding_scheme = self.seeding_scheme
|
59 |
+
|
60 |
+
if seeding_scheme == "simple_1":
|
61 |
+
assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
|
62 |
+
prev_token = input_ids[-1].item()
|
63 |
+
self.rng.manual_seed(self.hash_key * prev_token)
|
64 |
+
else:
|
65 |
+
raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")
|
66 |
+
return
|
67 |
+
|
68 |
+
def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
|
69 |
+
# seed the rng using the previous tokens/prefix
|
70 |
+
# according to the seeding_scheme
|
71 |
+
self._seed_rng(input_ids)
|
72 |
+
|
73 |
+
greenlist_size = int(self.vocab_size * self.gamma)
|
74 |
+
vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
|
75 |
+
if self.select_green_tokens: # directly
|
76 |
+
greenlist_ids = vocab_permutation[:greenlist_size] # new
|
77 |
+
else: # select green via red
|
78 |
+
greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :] # legacy behavior
|
79 |
+
return greenlist_ids
|
80 |
+
|
81 |
+
|
82 |
+
class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
|
83 |
+
def __init__(self, *args, **kwargs):
|
84 |
+
super().__init__(*args, **kwargs)
|
85 |
+
|
86 |
+
def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
|
87 |
+
# TODO lets see if we can lose this loop
|
88 |
+
green_tokens_mask = torch.zeros_like(scores)
|
89 |
+
for b_idx in range(len(greenlist_token_ids)):
|
90 |
+
green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
|
91 |
+
final_mask = green_tokens_mask.bool()
|
92 |
+
return final_mask
|
93 |
+
|
94 |
+
def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
|
95 |
+
scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
|
96 |
+
return scores
|
97 |
+
|
98 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
99 |
+
|
100 |
+
# this is lazy to allow us to colocate on the watermarked model's device
|
101 |
+
if self.rng is None:
|
102 |
+
self.rng = torch.Generator(device=input_ids.device)
|
103 |
+
|
104 |
+
# NOTE, it would be nice to get rid of this batch loop, but currently,
|
105 |
+
# the seed and partition operations are not tensor/vectorized, thus
|
106 |
+
# each sequence in the batch needs to be treated separately.
|
107 |
+
batched_greenlist_ids = [None for _ in range(input_ids.shape[0])]
|
108 |
+
|
109 |
+
for b_idx in range(input_ids.shape[0]):
|
110 |
+
greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
|
111 |
+
batched_greenlist_ids[b_idx] = greenlist_ids
|
112 |
+
|
113 |
+
green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)
|
114 |
+
|
115 |
+
scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
|
116 |
+
return scores
|
117 |
+
|
118 |
+
|
119 |
+
class WatermarkDetector(WatermarkBase):
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
*args,
|
123 |
+
device: torch.device = None,
|
124 |
+
tokenizer: Tokenizer = None,
|
125 |
+
z_threshold: float = 4.0,
|
126 |
+
normalizers: list[str] = ["unicode"], # or also: ["unicode", "homoglyphs", "truecase"]
|
127 |
+
ignore_repeated_bigrams: bool = True,
|
128 |
+
**kwargs,
|
129 |
+
):
|
130 |
+
super().__init__(*args, **kwargs)
|
131 |
+
# also configure the metrics returned/preprocessing options
|
132 |
+
assert device, "Must pass device"
|
133 |
+
assert tokenizer, "Need an instance of the generating tokenizer to perform detection"
|
134 |
+
|
135 |
+
self.tokenizer = tokenizer
|
136 |
+
self.device = device
|
137 |
+
self.z_threshold = z_threshold
|
138 |
+
self.rng = torch.Generator(device=self.device)
|
139 |
+
|
140 |
+
if self.seeding_scheme == "simple_1":
|
141 |
+
self.min_prefix_len = 1
|
142 |
+
else:
|
143 |
+
raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")
|
144 |
+
|
145 |
+
self.normalizers = []
|
146 |
+
for normalization_strategy in normalizers:
|
147 |
+
self.normalizers.append(normalization_strategy_lookup(normalization_strategy))
|
148 |
+
|
149 |
+
self.ignore_repeated_bigrams = ignore_repeated_bigrams
|
150 |
+
if self.ignore_repeated_bigrams:
|
151 |
+
assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme."
|
152 |
+
|
153 |
+
def _compute_z_score(self, observed_count, T):
|
154 |
+
# count refers to number of green tokens, T is total number of tokens
|
155 |
+
expected_count = self.gamma
|
156 |
+
numer = observed_count - expected_count * T
|
157 |
+
denom = sqrt(T * expected_count * (1 - expected_count))
|
158 |
+
z = numer / denom
|
159 |
+
return z
|
160 |
+
|
161 |
+
def _compute_p_value(self, z):
|
162 |
+
p_value = scipy.stats.norm.sf(z)
|
163 |
+
return p_value
|
164 |
+
|
165 |
+
def _score_sequence(
|
166 |
+
self,
|
167 |
+
input_ids: Tensor,
|
168 |
+
return_num_tokens_scored: bool = True,
|
169 |
+
return_num_green_tokens: bool = True,
|
170 |
+
return_green_fraction: bool = True,
|
171 |
+
return_green_token_mask: bool = False,
|
172 |
+
return_z_score: bool = True,
|
173 |
+
return_p_value: bool = True,
|
174 |
+
):
|
175 |
+
if self.ignore_repeated_bigrams:
|
176 |
+
# Method that only counts a green/red hit once per unique bigram.
|
177 |
+
# New num total tokens scored (T) becomes the number unique bigrams.
|
178 |
+
# We iterate over all unqiue token bigrams in the input, computing the greenlist
|
179 |
+
# induced by the first token in each, and then checking whether the second
|
180 |
+
# token falls in that greenlist.
|
181 |
+
assert return_green_token_mask is False, "Can't return the green/red mask when ignoring repeats."
|
182 |
+
bigram_table = {}
|
183 |
+
token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
|
184 |
+
freq = collections.Counter(token_bigram_generator)
|
185 |
+
num_tokens_scored = len(freq.keys())
|
186 |
+
for idx, bigram in enumerate(freq.keys()):
|
187 |
+
prefix = torch.tensor([bigram[0]], device=self.device) # expects a 1-d prefix tensor on the randperm device
|
188 |
+
greenlist_ids = self._get_greenlist_ids(prefix)
|
189 |
+
bigram_table[bigram] = True if bigram[1] in greenlist_ids else False
|
190 |
+
green_token_count = sum(bigram_table.values())
|
191 |
+
else:
|
192 |
+
num_tokens_scored = len(input_ids) - self.min_prefix_len
|
193 |
+
if num_tokens_scored < 1:
|
194 |
+
raise ValueError(
|
195 |
+
(
|
196 |
+
f"Must have at least {1} token to score after "
|
197 |
+
f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."
|
198 |
+
)
|
199 |
+
)
|
200 |
+
# Standard method.
|
201 |
+
# Since we generally need at least 1 token (for the simplest scheme)
|
202 |
+
# we start the iteration over the token sequence with a minimum
|
203 |
+
# num tokens as the first prefix for the seeding scheme,
|
204 |
+
# and at each step, compute the greenlist induced by the
|
205 |
+
# current prefix and check if the current token falls in the greenlist.
|
206 |
+
green_token_count, green_token_mask = 0, []
|
207 |
+
for idx in range(self.min_prefix_len, len(input_ids)):
|
208 |
+
curr_token = input_ids[idx]
|
209 |
+
greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
|
210 |
+
if curr_token in greenlist_ids:
|
211 |
+
green_token_count += 1
|
212 |
+
green_token_mask.append(True)
|
213 |
+
else:
|
214 |
+
green_token_mask.append(False)
|
215 |
+
|
216 |
+
score_dict = dict()
|
217 |
+
if return_num_tokens_scored:
|
218 |
+
score_dict.update(dict(num_tokens_scored=num_tokens_scored))
|
219 |
+
if return_num_green_tokens:
|
220 |
+
score_dict.update(dict(num_green_tokens=green_token_count))
|
221 |
+
if return_green_fraction:
|
222 |
+
score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
|
223 |
+
if return_z_score:
|
224 |
+
score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
|
225 |
+
if return_p_value:
|
226 |
+
z_score = score_dict.get("z_score")
|
227 |
+
if z_score is None:
|
228 |
+
z_score = self._compute_z_score(green_token_count, num_tokens_scored)
|
229 |
+
score_dict.update(dict(p_value=self._compute_p_value(z_score)))
|
230 |
+
if return_green_token_mask:
|
231 |
+
score_dict.update(dict(green_token_mask=green_token_mask))
|
232 |
+
|
233 |
+
return score_dict
|
234 |
+
|
235 |
+
def detect(
|
236 |
+
self,
|
237 |
+
text: str = None,
|
238 |
+
tokenized_text: list[int] = None,
|
239 |
+
return_prediction: bool = True,
|
240 |
+
return_scores: bool = True,
|
241 |
+
z_threshold: float = None,
|
242 |
+
**kwargs,
|
243 |
+
) -> dict:
|
244 |
+
|
245 |
+
assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
|
246 |
+
if return_prediction:
|
247 |
+
kwargs["return_p_value"] = True # to return the "confidence":=1-p of positive detections
|
248 |
+
|
249 |
+
# run optional normalizers on text
|
250 |
+
for normalizer in self.normalizers:
|
251 |
+
text = normalizer(text)
|
252 |
+
if len(self.normalizers) > 0:
|
253 |
+
print(f"Text after normalization:\n\n{text}\n")
|
254 |
+
|
255 |
+
if tokenized_text is None:
|
256 |
+
assert self.tokenizer is not None, (
|
257 |
+
"Watermark detection on raw string ",
|
258 |
+
"requires an instance of the tokenizer ",
|
259 |
+
"that was used at generation time.",
|
260 |
+
)
|
261 |
+
tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
|
262 |
+
if tokenized_text[0] == self.tokenizer.bos_token_id:
|
263 |
+
tokenized_text = tokenized_text[1:]
|
264 |
+
else:
|
265 |
+
# try to remove the bos_tok at beginning if it's there
|
266 |
+
if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
|
267 |
+
tokenized_text = tokenized_text[1:]
|
268 |
+
|
269 |
+
# call score method
|
270 |
+
output_dict = {}
|
271 |
+
score_dict = self._score_sequence(tokenized_text, **kwargs)
|
272 |
+
if return_scores:
|
273 |
+
output_dict.update(score_dict)
|
274 |
+
# if passed return_prediction then perform the hypothesis test and return the outcome
|
275 |
+
if return_prediction:
|
276 |
+
z_threshold = z_threshold if z_threshold else self.z_threshold
|
277 |
+
assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
|
278 |
+
output_dict["prediction"] = score_dict["z_score"] > z_threshold
|
279 |
+
if output_dict["prediction"]:
|
280 |
+
output_dict["confidence"] = 1 - score_dict["p_value"]
|
281 |
+
|
282 |
+
return output_dict
|
lm-watermarking-main/watermark_reliability_release/PIPELINE.md
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Usage document for pipeline
|
2 |
+
|
3 |
+
6/7/23: Will be updated and built out as required.
|
4 |
+
|
5 |
+
## (1) **generate** a bunch of samples
|
6 |
+
|
7 |
+
The point of all this code is to construct pairwise examples
|
8 |
+
of human text, unwatermarked, and watermarked text in something
|
9 |
+
resembling an unbiased or IID manner, despite the difficulty of this ask.
|
10 |
+
|
11 |
+
The key functionality is _oversampling_. A series of arguments control how
|
12 |
+
the raw datasets samples are turned into prompts, and then, provided
|
13 |
+
the raw prompts pass some checks, the prompts are
|
14 |
+
fed to the model, and the number of tokens naturally generated under normal
|
15 |
+
decoding, as well as watermark decoding. If the generations match the given
|
16 |
+
(length) output filtering criteria, then the row "counts" as one of the `N`
|
17 |
+
requested samples.
|
18 |
+
|
19 |
+
Otherwise, the generations are stored, but the global counter of progress
|
20 |
+
towards `N`, is not incremented, and thus this "overhead" is the cost
|
21 |
+
of being very restrictive in desiring "square" (`N` x `T`) shaped table of samples
|
22 |
+
in that all three of the human text, unwatermarked, and watermarked output columns
|
23 |
+
always have the same tokenized length.
|
24 |
+
|
25 |
+
At evaluation time, by default, all the point estimates, means, and ROC and AUC calculations are performed
|
26 |
+
on the subset of rows that all have about the target length (i.e. a subset with shape ~ `N` x `T`).
|
27 |
+
|
28 |
+
The `generation_pipeline.py` call in `run_pipeline.sh` demonstrates the basic usage.
|
29 |
+
|
30 |
+
### Key arguments controlling the oversampling logic...
|
31 |
+
|
32 |
+
### 'Shape' Controls
|
33 |
+
|
34 |
+
- `max_new_tokens`: an upperbound, i.e. target length `T=200`
|
35 |
+
- `min_prompt_tokens` : prompt len lower bound such as 50
|
36 |
+
- `min_generations` : the number of 'good' samples we'd like, ie `N=500`
|
37 |
+
|
38 |
+
### Prompt construction strategy
|
39 |
+
|
40 |
+
- `input_truncation_strategy`
|
41 |
+
|
42 |
+
One in `["completion_length", "prompt_length"]`. If the former, slices the end
|
43 |
+
`max_new_tokens` off of the raw sample to create the 'prompt' with the leading prefix (which can have variable length), making the `max_new_tokens` removed, the `baseline_completion`, or gold output.
|
44 |
+
If the latter, selects the leading `min_prompt_tokens` off of the raw sample as the prompt,
|
45 |
+
leaving the remaining tokens (variable length) the `baseline_completion`.
|
46 |
+
|
47 |
+
### Filtering/oversampling criteria
|
48 |
+
|
49 |
+
- `input_filtering_strategy`: Can be one of `["completion_length", "prompt_length", "prompt_and_completion_length"]`.
|
50 |
+
In each case, if the relevant field doesn't meet the minimum criteria given by
|
51 |
+
`max_new_tokens` or `min_prompt_tokens` respectively, then the raw sample is thrown
|
52 |
+
away before ever even being fed to the model.
|
53 |
+
|
54 |
+
- `output_filtering_strategy`: Can be one in `["no_filter", "max_new_tokens"]`, if the former, then no output filtering
|
55 |
+
is performed after generations are sampled from the model. However, if `max_new_tokens`
|
56 |
+
then each both the unwatermarked and watermarked generations are checked to ensure that
|
57 |
+
they are at least `max_new_tokens` long.
|
58 |
+
|
59 |
+
This is a subtle way of trying to adaptively collect samples (online, from any dataset) such that eventually we end up with at least a subset that matches the squareness (`N` x `T`) criteria we desire, without _forcing_ this to happen on every sample
|
60 |
+
by turning off the EOS token which amounts to a potentially
|
61 |
+
pathological distribution shift in the unwatermarked and watermarked output distributions
|
62 |
+
which would potentially confound generality of results.
|
63 |
+
|
64 |
+
Other generation args descriptions are explained by their argparse defintions, but these in particular control the watermarking:
|
65 |
+
- `seeding_scheme`: the watermarking embedding scheme being used, such as `lefthash` (formerly `simple_1`) or `selfhash` (formerly `algorithm-3` in reference to previous paper)
|
66 |
+
- `gamma`: parameter controlling size of the green partition for watermarking
|
67 |
+
- `delta`: parameter controlling how much bias is added to the green token logits before sampling
|
68 |
+
|
69 |
+
---
|
70 |
+
|
71 |
+
## (2) Optionally, apply an **attack** transformation to weaken the watermark, or make detection harder (for non-watermarking methods as well).
|
72 |
+
|
73 |
+
We implement three types of attacks in this pipeline: `gpt`, `dipper`, and `copy-paste`.
|
74 |
+
The key parameters for each are as follows:
|
75 |
+
|
76 |
+
- `gpt`:
|
77 |
+
- `attack_model_name`: the OpenAI model variant to use
|
78 |
+
- `attack_prompt_id` : the index of the prompt to use, see `utils/prompts.json`
|
79 |
+
- `no_wm_attack`: whether to attack the un-watermarked generation column (`no_wm_output`).
|
80 |
+
Default is the watermarked generation (`w_wm_output`)
|
81 |
+
|
82 |
+
- `dipper`:
|
83 |
+
- `lex`: lexical diversity knob for the dipper model/method
|
84 |
+
- `order`: order diversity knob for the paraphrase attack
|
85 |
+
|
86 |
+
- `copy-paste`:
|
87 |
+
- `cp_attack_type`: k-t means `k` insertions of length `t`
|
88 |
+
- `cp_attack_num_insertions`: `k` spec'd as an integer
|
89 |
+
- `cp_attack_insertion_len`: `t` but generally spec'd as a percent of the full starting sequence length (i.e `25%`)
|
90 |
+
- `cp_attack_src_col` : the sequence we're taking the tokens "to be detected" from , i.e. "positive" examples for
|
91 |
+
the detector of interest. for watermarking this is `w_wm_output`
|
92 |
+
- `cp_attack_dst_col` : the sequence we treat as "negative" surrounding context for the detector of interest. for watermarking this is `no_wm_output`.
|
93 |
+
|
94 |
+
All parameters have an associated help string in their argparse definition.
|
95 |
+
|
96 |
+
The `attack_pipeline.py` call in `run_pipeline.sh` demonstrates the basic usage of the attack functionality.
|
97 |
+
|
98 |
+
---
|
99 |
+
|
100 |
+
## (3) Run **evaluation** and watermark detection
|
101 |
+
|
102 |
+
This batches the process of applying a combination of metric
|
103 |
+
functions to the dataset of generations (jsonl) and returns a
|
104 |
+
new dataset of generations (jsonl) just with extra columns for a bunch of metrics.
|
105 |
+
|
106 |
+
This is separated from the generation phase to allow a given set of
|
107 |
+
expensive generations to be reanalyzed in differnet ways with differnet metric
|
108 |
+
flavors as necessary.
|
109 |
+
|
110 |
+
The key parameters controlling metrics:
|
111 |
+
|
112 |
+
|
113 |
+
Key parameters and usage notes for detection:
|
114 |
+
- `evaluation_metrics`: a comma sep list of metrics to evaluate, such as `p-sp,repetition,diversity,z-score,windowed-z-score`
|
115 |
+
- `window_settings`: if running windowed detection specs the comma sep'd windowing strategies (such as `20,40,max`)
|
116 |
+
- `retrieval_technique`: if running retrieval detection, whether to use the `sim` or `bm25` strategy
|
117 |
+
|
118 |
+
All (other) parameters have a help string in their argparse definition.
|
119 |
+
|
120 |
+
The `evaluation_pipeline.py` call in `run_pipeline.sh` demonstrates the basic usage.
|
121 |
+
|
122 |
+
### Argument union and precedence
|
123 |
+
|
124 |
+
First, all arguments used at generation time (metadata file) are loaded by the
|
125 |
+
evaluation pipeline. Then the commandline args that were passed to the eval pipeline
|
126 |
+
are added via an update, or "overwriting union" operator, where all new args for
|
127 |
+
evaluation only are added to the current metadata object, but those that were
|
128 |
+
also present at generation time are _**overwritten**_ by those included in the
|
129 |
+
evaluation argparse.
|
130 |
+
|
131 |
+
If they match, then this is standard behavior. Overwriting shared arguments
|
132 |
+
is disabled via the `overwrite_args` flag by default, but can be allowed this way.
|
133 |
+
|
134 |
+
Additionally, the code writes the metrics file into the same directory as the
|
135 |
+
generations file if only `input_dir` is passed. However, for safety clarity and organization,
|
136 |
+
one can pass an output dir in which to write the new dataset with metrics, as well
|
137 |
+
as the evaluation metadata as demonstrated in the `run_pipeline.sh` example.
|
138 |
+
|
139 |
+
---
|
140 |
+
|
141 |
+
## (3.1) Retrieval and DetectGPT detection
|
142 |
+
|
143 |
+
### Creating **prefixes**:
|
144 |
+
|
145 |
+
**Retrieval** detection is implemented as a metric, i.e. it is run by the evaluation script. To perform retrieval detection on full examples, nothing extra is required. To run retrieval at T, you first must run `broadcast_token_prefixes.py` with the `save_per_prefix` argument as `False` and with a `prefix_stride` of choice, such as 50, with a clean generation or attacked generation directory (with `jsonl` and meta file inside) as input. This will create a version of the dataset (new `jsonl` file) that contains all of the original rows, duplicated and then sliced to each prefix length defined by iterating by `prefix_stride` in the sequence length dimension.
|
146 |
+
|
147 |
+
For ex, if you have a file with `N=500` rows of length about `T=200` each, then running this script with `prefix_stride=50` would create a new file with `N=2000` where the first `500` rows all have length `50`, the next `500` have length `100` etc. If a given row say length `119` is too short for prefix length `i`, say the 3rd slice size in this example, `150`, then in the third block, it would be marked as `None`. This is to avoid any prefix block expected to be totally comprising a certain prefix length from containing a bunch of sequnces that are shorter than expected which confounds the measurement.
|
148 |
+
|
149 |
+
Now for **DetectGPT** a separate script, `detectgpt/detectgpt_main.py`, must be run pointing at a clean generation or attacked generation `jsonl` file. Additionally, to run detectgpt @ T, similar prefixing logic must be used. However, it must be run with `save_per_prefix` as `True` this time, which then creates a set of new files, each containing all the rows of the input `jsonl` file but trucated to each prefix length as described above. Then each run of the detectgpt script produces a new `jsonl` file (of length `N=500` in the above example) with the detectgpt score column added. Then, the notebook `join_jsonl_prefix_files.ipynb` can be used to join all those separate jsonl files for each individual prefix into one full file (`N=2000`).
|
150 |
+
|
151 |
+
### Running **detection**
|
152 |
+
For Retrieval detection, all that is necessary is to run the evaluation script on the `jsonl` containing all the prefixes, and point estimates for the detection at each prefix length will be created by grouping by the prefix length column and reducing. Note, the retrieval method will load only the full sequences into the retrieval database (by loading only the longest sample for each original row, so just `500` sequences in our example), but will query, or perform detection using all of the different prefixes.
|
153 |
+
|
154 |
+
For DetectGPT, the evaluation script must also be run, but with the `evaluation_metrics=detectgpt` alone, and no other metrics. This is because most of the script is a no-op at this point as every row already contains a detectgpt score and they just need to be turned into ROC plots or AUC measurements. As with retrieval detection, these will be automatically grouped by prefix length and reduced.
|
lm-watermarking-main/watermark_reliability_release/README.md
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 💧2.0: [On the Reliability of Watermarks for Large Language Models](https://arxiv.org/abs/2306.04634)
|
2 |
+
|
3 |
+
This directory contains the codebase for reproducing the experiments in our [new 6/7/23 preprint](https://arxiv.org/abs/2306.04634).
|
4 |
+
|
5 |
+
### **NOTE**: this is a preliminary release, so please expect some small changes in the future as required.
|
6 |
+
|
7 |
+
---
|
8 |
+
|
9 |
+
The watermarking and watermark detection code itself is an extension of the `WatermarkLogitsProcessor` and `WatermarkDetector` classes released as part of the original work and contained in the root of the repository. Additional logic implementing a wider array of seeding schemes and alternate detection strategies is included and depended upon by the extended versions of the classes in this directory.
|
10 |
+
|
11 |
+
To facilitate the broader array of experiments required for this study, an extra pipeline abstraction was implemented to manage the "generation", paraphrase "attack", and "evaluation" or detection phases. The general setup is that data, i.e. sets of generated samples, is written and read by each stage as "json lines" files `*.jsonl` with associated metadata files `*.json` to keep track of parameter settings used at each stage.
|
12 |
+
|
13 |
+
A prose version of usage instructions for the pipeline is described in a separate markdown file here: [PIPELINE.md](PIPELINE.md)
|
14 |
+
|
15 |
+
## wandb
|
16 |
+
|
17 |
+
The pipeline scripts, and in particular, the evaluation stage where detection is run and generation quality metrics are computed, are configured to push results to weights and biases (wandb). The figures in the paper are produced by:
|
18 |
+
1. sketching out the charts in wandb using filters and tags
|
19 |
+
2. exporting/downloading the csv's of the data for each chart, and
|
20 |
+
3. loading them in a notebook to format plots as necessary.
|
21 |
+
|
22 |
+
Alternately, the evaluation stage also saves a jsonl file where every line is a set of generations and all associated metrics and detection scores computed for it. This can also be loaded and analyzed manually in pandas, though the ROC space analyzes and average@T series for some metrics will have to be recomputed.
|
23 |
+
|
24 |
+
## llama
|
25 |
+
|
26 |
+
In order to use the llama model, you need to bring-your-own-weights, and then covert them to the huggingface format.
|
27 |
+
|
lm-watermarking-main/watermark_reliability_release/alternative_prf_schemes.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implement other PRF functions, so, hashing schemes.
|
2 |
+
|
3 |
+
Can be hooked into existing WatermarkLogitsProcessor as modified base class WatermarkBase
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from itertools import combinations
|
8 |
+
from functools import cache
|
9 |
+
|
10 |
+
# Key properties of a hashing scheme
|
11 |
+
props = {
|
12 |
+
"prf_type": str, # string name of the underlying PRF mapping multiple token ids to a random seed
|
13 |
+
"context_width": int, # this is h in the paper, how many previous tokens should be considered for each PRF
|
14 |
+
"self_salt": bool, # Use the rules laid in robust-watermarking to use the token itself to seed and possibly reject its own list
|
15 |
+
"hash_key": int, # integer, large prime, used to move seed away from low-entrop bit sequences in PRF chosen above
|
16 |
+
}
|
17 |
+
|
18 |
+
|
19 |
+
def seeding_scheme_lookup(seeding_scheme: str):
|
20 |
+
if not isinstance(seeding_scheme, str):
|
21 |
+
raise ValueError("Seeding scheme should be a string summarizing the procedure.")
|
22 |
+
if seeding_scheme == "simple_1" or seeding_scheme == "lefthash":
|
23 |
+
# Default, simple bigram hash # alias for ff-additive_prf-1-False-15485863
|
24 |
+
prf_type = "additive_prf"
|
25 |
+
context_width = 1
|
26 |
+
self_salt = False
|
27 |
+
hash_key = 15485863
|
28 |
+
elif seeding_scheme == "algorithm-3" or seeding_scheme == "selfhash":
|
29 |
+
prf_type = "anchored_minhash_prf"
|
30 |
+
context_width = 4
|
31 |
+
self_salt = True
|
32 |
+
hash_key = 15485863
|
33 |
+
elif seeding_scheme == "skipgram":
|
34 |
+
prf_type = "skipgram_prf"
|
35 |
+
context_width = 5
|
36 |
+
self_salt = False
|
37 |
+
hash_key = 15485863
|
38 |
+
elif seeding_scheme.startswith(
|
39 |
+
"ff"
|
40 |
+
): # freeform seeding scheme API - only use for experimenting
|
41 |
+
# expects strings of the form ff-additive_prf-4-True-hash or ff-additive_prf-5-True (hash key is optional)
|
42 |
+
split_scheme = seeding_scheme.split("-")
|
43 |
+
prf_type = str(split_scheme[1])
|
44 |
+
context_width = int(split_scheme[2])
|
45 |
+
self_salt = split_scheme[3] == "True"
|
46 |
+
if len(split_scheme) == 5:
|
47 |
+
hash_key = int(split_scheme[4])
|
48 |
+
else:
|
49 |
+
hash_key = 15485863
|
50 |
+
else:
|
51 |
+
raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try 'simple_1'?")
|
52 |
+
|
53 |
+
assert prf_type in prf_lookup.keys()
|
54 |
+
return prf_type, context_width, self_salt, hash_key
|
55 |
+
|
56 |
+
|
57 |
+
def multiplicative_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
58 |
+
return salt_key * input_ids.prod().item()
|
59 |
+
|
60 |
+
|
61 |
+
def additive_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
62 |
+
return salt_key * input_ids.sum().item()
|
63 |
+
|
64 |
+
|
65 |
+
def minfunc_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
66 |
+
# not a great idea for non-random input ids as in text
|
67 |
+
return salt_key * input_ids.min().item()
|
68 |
+
|
69 |
+
|
70 |
+
def simple_skip_prf(input_ids: torch.LongTensor, salt_key: int, k=2) -> int:
|
71 |
+
# k is the skip distance
|
72 |
+
return hashint(salt_key * input_ids[::k]).prod().item()
|
73 |
+
|
74 |
+
|
75 |
+
def skipgram_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
76 |
+
# maximum distance skipgram within context
|
77 |
+
return hashint(salt_key * input_ids[0]).item()
|
78 |
+
|
79 |
+
|
80 |
+
def anchored_skipgram_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
|
81 |
+
# maximum distance skipgram within context
|
82 |
+
return (hashint(salt_key * input_ids[0]) * hashint(salt_key * input_ids[anchor])).item()
|
83 |
+
|
84 |
+
|
85 |
+
def minhash_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
86 |
+
# slightly less not the greatest idea for non-random input ids as in text
|
87 |
+
return hashint(salt_key * input_ids).min().item()
|
88 |
+
|
89 |
+
|
90 |
+
def anchored_minhash_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
|
91 |
+
# Anchor to one key to produce a min over pairs again
|
92 |
+
return (salt_key * hashint(input_ids) * hashint(input_ids[anchor])).min().item()
|
93 |
+
|
94 |
+
|
95 |
+
def minskipgram_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
|
96 |
+
# min over all skipgrams in context, k=2 is all pairs
|
97 |
+
skipgrams = torch.as_tensor(list(combinations(hashint(salt_key * input_ids), 2)))
|
98 |
+
return skipgrams.prod(dim=1).min().item()
|
99 |
+
|
100 |
+
|
101 |
+
def noncomm_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
|
102 |
+
key = torch.as_tensor(salt_key, dtype=torch.long)
|
103 |
+
for entry in input_ids:
|
104 |
+
key *= hashint(key * entry)
|
105 |
+
key %= 2**32
|
106 |
+
return key.item()
|
107 |
+
|
108 |
+
|
109 |
+
def position_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
|
110 |
+
return (
|
111 |
+
(salt_key * input_ids * torch.arange(1, len(input_ids) + 1, device=input_ids.device))
|
112 |
+
.sum()
|
113 |
+
.item()
|
114 |
+
)
|
115 |
+
|
116 |
+
|
117 |
+
prf_lookup = {
|
118 |
+
"multiplicative_prf": multiplicative_prf,
|
119 |
+
"additive_prf": additive_prf,
|
120 |
+
"minfunc_prf": minfunc_prf,
|
121 |
+
"simple_skip_prf": simple_skip_prf,
|
122 |
+
"skipgram_prf": skipgram_prf,
|
123 |
+
"anchored_skipgram_prf": anchored_skipgram_prf,
|
124 |
+
"minhash_prf": minhash_prf,
|
125 |
+
"anchored_minhash_prf": anchored_minhash_prf,
|
126 |
+
"minskipgram_prf": minskipgram_prf,
|
127 |
+
"noncomm_prf": noncomm_prf,
|
128 |
+
"position_prf": position_prf,
|
129 |
+
}
|
130 |
+
|
131 |
+
# Generate a global permute table once at startup
|
132 |
+
rng = torch.Generator(device=torch.device("cpu"))
|
133 |
+
rng.manual_seed(2971215073) # fib47 is prime
|
134 |
+
table_size = 1_000_003
|
135 |
+
fixed_table = torch.randperm(
|
136 |
+
1_000_003, device=torch.device("cpu"), generator=rng
|
137 |
+
) # actually faster than I thought
|
138 |
+
|
139 |
+
|
140 |
+
def hashint(integer_tensor: torch.LongTensor) -> torch.LongTensor:
|
141 |
+
"""Sane version, in the end we only need a small permutation table."""
|
142 |
+
return (
|
143 |
+
fixed_table[integer_tensor.cpu() % table_size] + 1
|
144 |
+
) # minor cheat here, this function always return CPU values
|
145 |
+
|
146 |
+
|
147 |
+
def _hashint_avalanche_tensor(integer_tensor: torch.LongTensor):
|
148 |
+
"""http://burtleburtle.net/bob/hash/integer.html, ported into pytorch, runs on tensors. Apparently a decent avalanche."""
|
149 |
+
i = integer_tensor.to(torch.int32).clone() # or torch.int16?
|
150 |
+
i -= i << 6
|
151 |
+
i ^= i >> 17
|
152 |
+
i -= i << 9
|
153 |
+
i ^= i << 4
|
154 |
+
i -= i << 3
|
155 |
+
i ^= i << 10
|
156 |
+
i ^= i >> 15
|
157 |
+
return i.to(torch.long)
|
158 |
+
|
159 |
+
|
160 |
+
@cache
|
161 |
+
def _hashint_avalanche_int(integer: int):
|
162 |
+
"""http://burtleburtle.net/bob/hash/integer.html, runs in base python, caches based on access.
|
163 |
+
Does this make sense for signed 64bit ints?"""
|
164 |
+
i = integer % (2**32)
|
165 |
+
i -= i << 6
|
166 |
+
i ^= i >> 17
|
167 |
+
i -= i << 9
|
168 |
+
i ^= i << 4
|
169 |
+
i -= i << 3
|
170 |
+
i ^= i << 10
|
171 |
+
i ^= i >> 15
|
172 |
+
return i
|
lm-watermarking-main/watermark_reliability_release/attack_pipeline.py
ADDED
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Authors of "A Watermark for Large Language Models"
|
3 |
+
# available at https://arxiv.org/abs/2301.10226
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import os
|
18 |
+
import argparse
|
19 |
+
from functools import partial
|
20 |
+
|
21 |
+
from tqdm import tqdm
|
22 |
+
import wandb
|
23 |
+
|
24 |
+
from datasets import Dataset
|
25 |
+
from utils.submitit import str2bool # better bool flag type for argparse
|
26 |
+
from utils.io import read_jsonlines, read_json, write_json, write_jsonlines
|
27 |
+
|
28 |
+
from utils.evaluation import NO_CHECK_ARGS, load_tokenizer
|
29 |
+
|
30 |
+
from utils.attack import (
|
31 |
+
SUPPORTED_ATTACK_METHODS,
|
32 |
+
gpt_attack,
|
33 |
+
dipper_attack,
|
34 |
+
tokenize_for_copy_paste,
|
35 |
+
copy_paste_attack,
|
36 |
+
scramble_attack,
|
37 |
+
)
|
38 |
+
|
39 |
+
print(f"Current huggingface cache dir: {os.environ['HF_HOME']}")
|
40 |
+
|
41 |
+
|
42 |
+
def main(args):
|
43 |
+
###########################################################################
|
44 |
+
# Create output dir if it doesn't exist, and warn if it contains an
|
45 |
+
# attacked generations file
|
46 |
+
###########################################################################
|
47 |
+
gen_table_attacked_path = f"{args.output_dir}/gen_table_attacked.jsonl"
|
48 |
+
attacked_meta_path = f"{args.output_dir}/gen_table_attacked_meta.json"
|
49 |
+
|
50 |
+
print(f"Output dir for this run: {args.output_dir}")
|
51 |
+
# notify if exists
|
52 |
+
if os.path.exists(args.output_dir):
|
53 |
+
print(f"Output dir for this run already exists!")
|
54 |
+
print(f"Contents: {sorted(os.listdir(args.output_dir))}")
|
55 |
+
# warn if metrics file exists
|
56 |
+
if os.path.exists(gen_table_attacked_path):
|
57 |
+
if not args.overwrite_output_file:
|
58 |
+
print(
|
59 |
+
f"WARNING: Exiting to avoid overwriting output file. "
|
60 |
+
f"Pass the '--overwrite_output_file' flag to ignore this check."
|
61 |
+
)
|
62 |
+
exit()
|
63 |
+
else:
|
64 |
+
print(
|
65 |
+
f"WARNING: Found existing generation files with metrics added at this output dir. "
|
66 |
+
f"Overwriting anyway :/"
|
67 |
+
)
|
68 |
+
else:
|
69 |
+
# create the output dir where run artifacts are stored
|
70 |
+
os.makedirs(args.output_dir)
|
71 |
+
|
72 |
+
###########################################################################
|
73 |
+
# Parse attack_method arg
|
74 |
+
###########################################################################
|
75 |
+
# check that attack method is supported
|
76 |
+
assert (
|
77 |
+
args.attack_method in SUPPORTED_ATTACK_METHODS
|
78 |
+
), f"Unsupported attack '{args.attack_method}'"
|
79 |
+
print(f"Attack method: {args.attack_method}")
|
80 |
+
|
81 |
+
###########################################################################
|
82 |
+
# Load generations
|
83 |
+
###########################################################################
|
84 |
+
print(f"Input dir for this run: {args.input_dir}")
|
85 |
+
print(f"Loading previously generated outputs for attacking ...")
|
86 |
+
gen_table_meta_path = f"{args.input_dir}/gen_table_meta.json"
|
87 |
+
gen_table_path = f"{args.input_dir}/gen_table.jsonl"
|
88 |
+
safe_gen_table_path = f"{args.input_dir}/gen_table_safe.jsonl"
|
89 |
+
|
90 |
+
assert os.path.exists(
|
91 |
+
gen_table_meta_path
|
92 |
+
), f"failed file check for prev generations metadata json file: {gen_table_meta_path}"
|
93 |
+
assert os.path.exists(
|
94 |
+
gen_table_path
|
95 |
+
), f"failed file check for prev generations jsonl file: {gen_table_path}"
|
96 |
+
assert not os.path.exists(safe_gen_table_path), (
|
97 |
+
f"failed for safety bc there is a secondary 'safe' marked file",
|
98 |
+
f" in this dir indicating a possible issue with the generation step. ",
|
99 |
+
)
|
100 |
+
|
101 |
+
cmdline_args = args.__dict__.copy()
|
102 |
+
prev_gen_table_meta = read_json(gen_table_meta_path)
|
103 |
+
|
104 |
+
joined_args = prev_gen_table_meta.copy()
|
105 |
+
joined_args.update(cmdline_args)
|
106 |
+
|
107 |
+
# check that the args used to generate the prev generations are the same as
|
108 |
+
# the current args, for the intersection of keys
|
109 |
+
if not args.overwrite_args:
|
110 |
+
for key in prev_gen_table_meta.keys():
|
111 |
+
if key in NO_CHECK_ARGS:
|
112 |
+
continue
|
113 |
+
assert joined_args[key] == prev_gen_table_meta[key], (
|
114 |
+
f"failed for safety bc after merging the prev metadata with "
|
115 |
+
f"the current cmdline args, values for '{key}' are not the same. "
|
116 |
+
f"in metadata: {prev_gen_table_meta[key]}, passed: {cmdline_args[key]}. "
|
117 |
+
f"Pass the '--overwrite_args' flag to ignore this check."
|
118 |
+
)
|
119 |
+
|
120 |
+
args = argparse.Namespace(**joined_args)
|
121 |
+
gen_table = [ex for ex in read_jsonlines(gen_table_path)]
|
122 |
+
gen_table_ds = Dataset.from_list(gen_table[: args.limit_rows])
|
123 |
+
|
124 |
+
###########################################################################
|
125 |
+
# Start logging, we wait to do this until after loading the generations
|
126 |
+
# so that we can log the args used to generate them unioned with the
|
127 |
+
# cmdline args
|
128 |
+
###########################################################################
|
129 |
+
# storing slurm info to allow auditing logfiles
|
130 |
+
# note this is set after the metadata check to ignore overwriting
|
131 |
+
args.SLURM_JOB_ID = os.getenv("SLURM_JOB_ID")
|
132 |
+
args.SLURM_ARRAY_JOB_ID = os.getenv("SLURM_ARRAY_JOB_ID")
|
133 |
+
args.SLURM_ARRAY_TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
|
134 |
+
|
135 |
+
if args.wandb:
|
136 |
+
# start a new wandb run to track this experiment, will send data to it
|
137 |
+
run = wandb.init(
|
138 |
+
# set the wandb project where this run will be logged
|
139 |
+
project=args.wandb_project,
|
140 |
+
entity=args.wandb_entity,
|
141 |
+
name=f"{args.run_name}",
|
142 |
+
# track hyperparameters and run metadata
|
143 |
+
config=args,
|
144 |
+
tags=args.wandb_tags,
|
145 |
+
)
|
146 |
+
|
147 |
+
###########################################################################
|
148 |
+
# GPT attack
|
149 |
+
###########################################################################
|
150 |
+
|
151 |
+
if args.attack_method == "gpt":
|
152 |
+
print("Running GPT attack")
|
153 |
+
import openai
|
154 |
+
|
155 |
+
openai.api_key = os.environ["OPENAI_API_KEY"]
|
156 |
+
prompt_pool = read_json("utils/prompts.json")["prompt_pool"]
|
157 |
+
prompt_pool = {int(k): v for k, v in prompt_pool.items()}
|
158 |
+
|
159 |
+
if args.attack_prompt is None:
|
160 |
+
attack_prompt = prompt_pool[args.attack_prompt_id]
|
161 |
+
args.attack_prompt = attack_prompt
|
162 |
+
|
163 |
+
print(f"Using attack prompt: {attack_prompt}")
|
164 |
+
|
165 |
+
gpt_attack_partial = partial(
|
166 |
+
gpt_attack,
|
167 |
+
attack_prompt=attack_prompt,
|
168 |
+
args=args,
|
169 |
+
)
|
170 |
+
# gen_table_attacked_ds = gen_table_ds.map(
|
171 |
+
# gpt_attack_partial, batched=False, num_proc=min(len(gen_table_ds), 16)
|
172 |
+
# )
|
173 |
+
gen_table_attacked_ds = gen_table_ds.map(gpt_attack_partial, batched=False)
|
174 |
+
|
175 |
+
###########################################################################
|
176 |
+
# DIPPER attack
|
177 |
+
###########################################################################
|
178 |
+
|
179 |
+
elif args.attack_method == "dipper":
|
180 |
+
print("Running DIPPER attack")
|
181 |
+
print(f"Using lexical diversity: {args.lex}, order diversity: {args.order}")
|
182 |
+
gen_table_attacked_ds = dipper_attack(
|
183 |
+
gen_table_ds, lex=args.lex, order=args.order, args=args
|
184 |
+
)
|
185 |
+
|
186 |
+
###########################################################################
|
187 |
+
# Scramble attack
|
188 |
+
###########################################################################
|
189 |
+
elif args.attack_method == "scramble":
|
190 |
+
# if no cp_attack_min_len specified, use args.max_new_tokens
|
191 |
+
if args.cp_attack_min_len == 0:
|
192 |
+
args.cp_attack_min_len = args.max_new_tokens
|
193 |
+
tokenizer = load_tokenizer(args)
|
194 |
+
scramble_attack_partial = partial(
|
195 |
+
scramble_attack,
|
196 |
+
tokenizer=tokenizer,
|
197 |
+
args=args,
|
198 |
+
)
|
199 |
+
gen_table_attacked_ds = gen_table_ds.map(scramble_attack_partial, batched=False)
|
200 |
+
###########################################################################
|
201 |
+
# Copy-paste attack
|
202 |
+
###########################################################################
|
203 |
+
elif args.attack_method == "copy-paste":
|
204 |
+
# if no cp_attack_min_len specified, use args.max_new_tokens
|
205 |
+
if args.cp_attack_min_len == 0:
|
206 |
+
args.cp_attack_min_len = args.max_new_tokens
|
207 |
+
|
208 |
+
# NOTE FIXME: the above arg indicates the filter condition by which
|
209 |
+
# some rows are skipped/not attacked/NOOP. Since the attacked col
|
210 |
+
# is set to the empty string, and length 0, the detection code
|
211 |
+
# including the baselines 🤞🏼 will ignore these rows one way or another
|
212 |
+
|
213 |
+
# convert cp_attack_insertion_len to int
|
214 |
+
if "%" in args.cp_attack_insertion_len:
|
215 |
+
original_len_str = args.cp_attack_insertion_len
|
216 |
+
# treat as a percent of 1 minus the length of the source col
|
217 |
+
# effectively how much of the source col "remains", accounting for
|
218 |
+
# the number of insertions that will be made to total this length
|
219 |
+
args.cp_attack_insertion_len = (
|
220 |
+
int((int(args.cp_attack_insertion_len[:-1]) / 100) * args.max_new_tokens)
|
221 |
+
// args.cp_attack_num_insertions
|
222 |
+
)
|
223 |
+
# check that this is not more than args.max_new_tokens total
|
224 |
+
assert (
|
225 |
+
args.cp_attack_insertion_len * args.cp_attack_num_insertions <= args.max_new_tokens
|
226 |
+
) and (
|
227 |
+
args.cp_attack_insertion_len * args.cp_attack_num_insertions > 0
|
228 |
+
), f"Invalid attack strength: {original_len_str} for {args.cp_attack_num_insertions} insertions."
|
229 |
+
|
230 |
+
args.cp_attack_effective_attack_percentage = (
|
231 |
+
1 - (int(original_len_str[:-1]) / 100)
|
232 |
+
) * 100
|
233 |
+
print(
|
234 |
+
f"Effective attack percentage is 1-{original_len_str}={args.cp_attack_effective_attack_percentage}% by "
|
235 |
+
f"copying {args.cp_attack_num_insertions} x {args.cp_attack_insertion_len} = {args.cp_attack_num_insertions * args.cp_attack_insertion_len} tokens "
|
236 |
+
f"from {args.cp_attack_src_col} to {args.cp_attack_dst_col} where T={args.max_new_tokens}"
|
237 |
+
)
|
238 |
+
else:
|
239 |
+
args.cp_attack_insertion_len = int(args.cp_attack_insertion_len)
|
240 |
+
args.cp_attack_effective_attack_percentage = (
|
241 |
+
1
|
242 |
+
- (
|
243 |
+
(args.cp_attack_insertion_len * args.cp_attack_num_insertions)
|
244 |
+
/ args.max_new_tokens
|
245 |
+
)
|
246 |
+
) * 100
|
247 |
+
print(
|
248 |
+
f"Effective attack percentage is {args.cp_attack_effective_attack_percentage}% by "
|
249 |
+
f"copying {args.cp_attack_num_insertions} x {args.cp_attack_insertion_len} = {args.cp_attack_num_insertions * args.cp_attack_insertion_len} tokens "
|
250 |
+
f"from {args.cp_attack_src_col} to {args.cp_attack_dst_col} where T={args.max_new_tokens}"
|
251 |
+
)
|
252 |
+
|
253 |
+
tokenizer = load_tokenizer(args)
|
254 |
+
tokenize_for_copy_paste_partial = partial(tokenize_for_copy_paste, tokenizer=tokenizer)
|
255 |
+
gen_table_tokd_ds = gen_table_ds.map(tokenize_for_copy_paste_partial, batched=False)
|
256 |
+
|
257 |
+
copy_paste_attack_partial = partial(copy_paste_attack, tokenizer=tokenizer, args=args)
|
258 |
+
gen_table_attacked_ds = gen_table_tokd_ds.map(copy_paste_attack_partial, batched=False)
|
259 |
+
###########################################################################
|
260 |
+
# Write the final dataset out to disk in jsonl format
|
261 |
+
# with the metrics added
|
262 |
+
###########################################################################
|
263 |
+
else:
|
264 |
+
raise ValueError(f"Invalid attack method: {args.attack_method}")
|
265 |
+
|
266 |
+
# write the metadata file, which is a union of the previous metadata
|
267 |
+
# and the current cmdline args
|
268 |
+
write_json(args.__dict__, attacked_meta_path, indent=4)
|
269 |
+
|
270 |
+
gen_table_attacked_lst = [ex for ex in gen_table_attacked_ds]
|
271 |
+
write_jsonlines(gen_table_attacked_lst, gen_table_attacked_path)
|
272 |
+
|
273 |
+
###########################################################################
|
274 |
+
# Log the data/series to wandb
|
275 |
+
###########################################################################
|
276 |
+
# log the metrics to wandb
|
277 |
+
if args.wandb:
|
278 |
+
# find cols that should be logged in a table
|
279 |
+
tabular_column_types = ["string", "bool"]
|
280 |
+
tabular_column_names = [
|
281 |
+
name
|
282 |
+
for name, _ in filter(
|
283 |
+
lambda tup: tup[1].dtype in tabular_column_types,
|
284 |
+
gen_table_attacked_ds.features.items(),
|
285 |
+
)
|
286 |
+
]
|
287 |
+
# the rest should be logged as series
|
288 |
+
series_column_names = [
|
289 |
+
name
|
290 |
+
for name, _ in filter(
|
291 |
+
lambda tup: tup[1].dtype not in tabular_column_types,
|
292 |
+
gen_table_attacked_ds.features.items(),
|
293 |
+
)
|
294 |
+
]
|
295 |
+
for metric_name in series_column_names:
|
296 |
+
# summarize series metrics as mean by default
|
297 |
+
wandb.define_metric(metric_name, summary="mean")
|
298 |
+
# log the raw series
|
299 |
+
for example in tqdm(
|
300 |
+
gen_table_attacked_ds.remove_columns(tabular_column_names),
|
301 |
+
desc="Logging series metrics to wandb",
|
302 |
+
):
|
303 |
+
run.log(example)
|
304 |
+
# log the raw tabular data
|
305 |
+
# but also include the dataset index as a column
|
306 |
+
series_column_names.remove("idx")
|
307 |
+
table = wandb.Table(
|
308 |
+
dataframe=gen_table_attacked_ds.remove_columns(series_column_names).to_pandas()
|
309 |
+
)
|
310 |
+
run.log({"output_table": table})
|
311 |
+
|
312 |
+
# finish the wandb run
|
313 |
+
run.finish()
|
314 |
+
|
315 |
+
return
|
316 |
+
|
317 |
+
|
318 |
+
if __name__ == "__main__":
|
319 |
+
parser = argparse.ArgumentParser(description="Run evaluation pipeline for watermark detection")
|
320 |
+
parser.add_argument(
|
321 |
+
"--attack_method",
|
322 |
+
type=str,
|
323 |
+
choices=SUPPORTED_ATTACK_METHODS,
|
324 |
+
default="gpt",
|
325 |
+
help="The attack method to use.",
|
326 |
+
)
|
327 |
+
parser.add_argument(
|
328 |
+
"--attack_model_name",
|
329 |
+
type=str,
|
330 |
+
default="gpt-3.5-turbo",
|
331 |
+
)
|
332 |
+
parser.add_argument(
|
333 |
+
"--attack_temperature",
|
334 |
+
type=float,
|
335 |
+
default=0.7,
|
336 |
+
)
|
337 |
+
parser.add_argument(
|
338 |
+
"--attack_max_tokens",
|
339 |
+
type=int,
|
340 |
+
default=1000,
|
341 |
+
)
|
342 |
+
parser.add_argument(
|
343 |
+
"--attack_prompt_id",
|
344 |
+
type=int,
|
345 |
+
default=4,
|
346 |
+
)
|
347 |
+
parser.add_argument(
|
348 |
+
"--attack_prompt",
|
349 |
+
type=str,
|
350 |
+
default=None,
|
351 |
+
help="Pass in the prompt to use for the attack. Is loaded by id from utils/prompts.json by default.",
|
352 |
+
)
|
353 |
+
parser.add_argument(
|
354 |
+
"--no_wm_attack",
|
355 |
+
type=str2bool,
|
356 |
+
default=False,
|
357 |
+
help="Whether to attack the no_wm_output column when running gpt or dipper.",
|
358 |
+
)
|
359 |
+
parser.add_argument(
|
360 |
+
"--overwrite_args",
|
361 |
+
type=str2bool,
|
362 |
+
default=False,
|
363 |
+
help="Whether to overwrite the shared args in the metadata file with the current, runtime args.",
|
364 |
+
)
|
365 |
+
parser.add_argument(
|
366 |
+
"--wandb",
|
367 |
+
type=str2bool,
|
368 |
+
default=False,
|
369 |
+
help="Whether to log to wandb.",
|
370 |
+
)
|
371 |
+
parser.add_argument(
|
372 |
+
"--wandb_project",
|
373 |
+
type=str,
|
374 |
+
default="lm-watermarking",
|
375 |
+
help="The name of the wandb project.",
|
376 |
+
)
|
377 |
+
parser.add_argument(
|
378 |
+
"--wandb_entity",
|
379 |
+
type=str,
|
380 |
+
default="jwkirchenbauer",
|
381 |
+
help="The wandb entity/user for the project.",
|
382 |
+
)
|
383 |
+
parser.add_argument(
|
384 |
+
"--wandb_tags",
|
385 |
+
type=str,
|
386 |
+
default="",
|
387 |
+
help="The comma separated list of tags to add to the wandb run.",
|
388 |
+
)
|
389 |
+
parser.add_argument(
|
390 |
+
"--run_name",
|
391 |
+
type=str,
|
392 |
+
default=None,
|
393 |
+
help="The unique name for the run.",
|
394 |
+
)
|
395 |
+
parser.add_argument(
|
396 |
+
"--input_dir",
|
397 |
+
type=str,
|
398 |
+
default="./input",
|
399 |
+
help="The directory containing the input files.",
|
400 |
+
)
|
401 |
+
parser.add_argument(
|
402 |
+
"--output_dir",
|
403 |
+
type=str,
|
404 |
+
default=None,
|
405 |
+
help=(
|
406 |
+
"The directory in which to write out the dataset after adding the metrics. "
|
407 |
+
"If not specified, will use the input_dir. Note, if the output_dir already "
|
408 |
+
"contains the metric-enriched file, it will be overwritten :/"
|
409 |
+
),
|
410 |
+
)
|
411 |
+
parser.add_argument(
|
412 |
+
"--overwrite_output_file",
|
413 |
+
type=str2bool,
|
414 |
+
default=False,
|
415 |
+
help="Whether to overwrite the output file if it already exists.",
|
416 |
+
)
|
417 |
+
parser.add_argument(
|
418 |
+
"--limit_rows",
|
419 |
+
type=int,
|
420 |
+
default=None,
|
421 |
+
help="The number of rows to limit the dataset to. Useful for debugging.",
|
422 |
+
)
|
423 |
+
parser.add_argument(
|
424 |
+
"--verbose",
|
425 |
+
type=str2bool,
|
426 |
+
default=False,
|
427 |
+
help="Whether to print verbose output of every attack.",
|
428 |
+
)
|
429 |
+
parser.add_argument(
|
430 |
+
"--lex",
|
431 |
+
type=int,
|
432 |
+
default=20,
|
433 |
+
help="Lexical diversity knob for the paraphrase attack.",
|
434 |
+
)
|
435 |
+
parser.add_argument(
|
436 |
+
"--order",
|
437 |
+
type=int,
|
438 |
+
default=0,
|
439 |
+
help="Order diversity knob for the paraphrase attack.",
|
440 |
+
)
|
441 |
+
parser.add_argument(
|
442 |
+
"--cp_attack_type",
|
443 |
+
type=str,
|
444 |
+
default="single-single",
|
445 |
+
choices=["single-single", "triple-single", "k-t"],
|
446 |
+
help="Type of copy-paste attack to be run.",
|
447 |
+
)
|
448 |
+
parser.add_argument(
|
449 |
+
"--cp_attack_min_len",
|
450 |
+
type=int,
|
451 |
+
default=0,
|
452 |
+
help="Minimum length of cols for the copy-paste attack to be run.",
|
453 |
+
)
|
454 |
+
parser.add_argument(
|
455 |
+
"--cp_attack_num_insertions",
|
456 |
+
type=int,
|
457 |
+
default=3,
|
458 |
+
help="Length of the insertion for the copy-paste attack.",
|
459 |
+
)
|
460 |
+
parser.add_argument(
|
461 |
+
"--cp_attack_insertion_len",
|
462 |
+
type=str,
|
463 |
+
default="20",
|
464 |
+
help=(
|
465 |
+
f"Length of the insertion for the copy-paste attack. "
|
466 |
+
f"Converts to int. Unless expressed as a percentage, "
|
467 |
+
f"in which case it refers to what percent of src is copied to dst, "
|
468 |
+
f"which is 1-attack strength as a percentage."
|
469 |
+
),
|
470 |
+
)
|
471 |
+
parser.add_argument(
|
472 |
+
"--cp_attack_src_col",
|
473 |
+
type=str,
|
474 |
+
default="w_wm_output",
|
475 |
+
help="Source column for the copy-paste attack.",
|
476 |
+
)
|
477 |
+
parser.add_argument(
|
478 |
+
"--cp_attack_dst_col",
|
479 |
+
type=str,
|
480 |
+
default="no_wm_output",
|
481 |
+
help="Destination column for the copy-paste attack.",
|
482 |
+
)
|
483 |
+
args = parser.parse_args()
|
484 |
+
|
485 |
+
###########################################################################
|
486 |
+
# Argument validation and conditional setting
|
487 |
+
###########################################################################
|
488 |
+
|
489 |
+
assert args.attack_method, "attack_method must be specified"
|
490 |
+
|
491 |
+
# if no output dir specified, use the input dir
|
492 |
+
if args.output_dir is None:
|
493 |
+
args.output_dir = args.input_dir
|
494 |
+
|
495 |
+
# check limit_rows
|
496 |
+
assert (args.limit_rows is None) or (
|
497 |
+
(args.limit_rows > 0) and isinstance(args.limit_rows, int)
|
498 |
+
), "limit_rows must be > 0 or None"
|
499 |
+
|
500 |
+
# split wandb tags
|
501 |
+
if args.wandb_tags != "":
|
502 |
+
args.wandb_tags = args.wandb_tags.split(",")
|
503 |
+
else:
|
504 |
+
args.wandb_tags = []
|
505 |
+
|
506 |
+
main(args)
|
lm-watermarking-main/watermark_reliability_release/broadcast_token_prefixes.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Authors of "A Watermark for Large Language Models"
|
3 |
+
# available at https://arxiv.org/abs/2301.10226
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
|
18 |
+
note = "Note: this script should be moved to/run from the same dir as the `utils` subdir lives in to work properly"
|
19 |
+
print(note)
|
20 |
+
|
21 |
+
import os
|
22 |
+
import argparse
|
23 |
+
from functools import partial
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
import wandb
|
27 |
+
import torch
|
28 |
+
import numpy as np
|
29 |
+
|
30 |
+
from datasets import Dataset, concatenate_datasets
|
31 |
+
|
32 |
+
from utils.submitit import str2bool # better bool flag type for argparse
|
33 |
+
from utils.io import read_jsonlines, read_json, write_json, write_jsonlines
|
34 |
+
from utils.evaluation import load_tokenizer, NO_CHECK_ARGS
|
35 |
+
from utils.generation import tokenize_only
|
36 |
+
|
37 |
+
print(f"Current huggingface cache dir: {os.environ['HF_HOME']}")
|
38 |
+
|
39 |
+
|
40 |
+
def main(args):
|
41 |
+
###########################################################################
|
42 |
+
# Load generations
|
43 |
+
###########################################################################
|
44 |
+
print(f"Input dir for this run: {args.input_dir}")
|
45 |
+
|
46 |
+
print(f"Loading previously generated outputs for evaluation via oracle model and metrics...")
|
47 |
+
|
48 |
+
# check for the "attacked version" of the gen table first
|
49 |
+
gen_table_meta_path = f"{args.input_dir}/gen_table_attacked_meta.json"
|
50 |
+
gen_table_path = f"{args.input_dir}/gen_table_attacked.jsonl"
|
51 |
+
safe_gen_table_path = f"{args.input_dir}/gen_table_attacked_safe.jsonl"
|
52 |
+
|
53 |
+
attack_variants_exist = [
|
54 |
+
os.path.exists(gen_table_meta_path),
|
55 |
+
os.path.exists(gen_table_path),
|
56 |
+
]
|
57 |
+
found_attacked_files = all(attack_variants_exist)
|
58 |
+
if not found_attacked_files:
|
59 |
+
gen_table_meta_path = f"{args.input_dir}/gen_table_meta.json"
|
60 |
+
gen_table_path = f"{args.input_dir}/gen_table.jsonl"
|
61 |
+
safe_gen_table_path = f"{args.input_dir}/gen_table_safe.jsonl"
|
62 |
+
|
63 |
+
assert os.path.exists(
|
64 |
+
gen_table_meta_path
|
65 |
+
), f"failed file check for prev generations metadata json file: {gen_table_meta_path}"
|
66 |
+
assert os.path.exists(
|
67 |
+
gen_table_path
|
68 |
+
), f"failed file check for prev generations jsonl file: {gen_table_path}"
|
69 |
+
|
70 |
+
assert not os.path.exists(safe_gen_table_path), (
|
71 |
+
f"failed for safety bc there is a secondary 'safe' marked file",
|
72 |
+
f" in this dir indicating a possible issue with the generation step. ",
|
73 |
+
)
|
74 |
+
|
75 |
+
cmdline_args = args.__dict__.copy()
|
76 |
+
prev_gen_table_meta = read_json(gen_table_meta_path)
|
77 |
+
|
78 |
+
joined_args = prev_gen_table_meta.copy()
|
79 |
+
for k, v in cmdline_args.items():
|
80 |
+
if v is not None or (not k in joined_args):
|
81 |
+
joined_args.update({k: v})
|
82 |
+
else:
|
83 |
+
print(
|
84 |
+
f"cmdline arg {k} is None, leaving it as the value found in the input metadata (or None): {prev_gen_table_meta.get(k)}"
|
85 |
+
)
|
86 |
+
|
87 |
+
# check that the args used to generate the prev generations are the same as
|
88 |
+
# the current args, for the intersection of keys
|
89 |
+
|
90 |
+
for key in prev_gen_table_meta.keys():
|
91 |
+
if key in NO_CHECK_ARGS:
|
92 |
+
continue
|
93 |
+
assert joined_args[key] == prev_gen_table_meta[key], (
|
94 |
+
f"failed for safety bc after merging the prev metadata with "
|
95 |
+
f"the current cmdline args, values for '{key}' are not the same. "
|
96 |
+
f"in metadata: {prev_gen_table_meta[key]}, passed: {cmdline_args[key]}. "
|
97 |
+
f"Pass the '--overwrite_args' flag to ignore this check."
|
98 |
+
)
|
99 |
+
|
100 |
+
args = argparse.Namespace(**joined_args)
|
101 |
+
gen_table = [ex for ex in read_jsonlines(gen_table_path)]
|
102 |
+
gen_table_ds = Dataset.from_list(gen_table[: args.limit_rows])
|
103 |
+
|
104 |
+
######## length filtering: only keeps the samples of exact N tokens ########
|
105 |
+
|
106 |
+
df = gen_table_ds.to_pandas()
|
107 |
+
original_len = len(df)
|
108 |
+
print(f"Origianl #samples: {original_len}")
|
109 |
+
if args.filter_length:
|
110 |
+
df = df[
|
111 |
+
(df["baseline_completion_length"] == args.max_new_tokens)
|
112 |
+
& (df["no_wm_output_length"] == args.max_new_tokens)
|
113 |
+
& (df["w_wm_output_length"] == args.max_new_tokens)
|
114 |
+
]
|
115 |
+
# TODO: filter length for the attacked output
|
116 |
+
print(f" after filtering token length: {len(df)}")
|
117 |
+
gen_table_ds = Dataset.from_pandas(df)
|
118 |
+
|
119 |
+
###########################################################################
|
120 |
+
# Prefix list logic
|
121 |
+
###########################################################################
|
122 |
+
from utils.generation import tokenize_and_truncate
|
123 |
+
|
124 |
+
print(f"Generating prefixes for the gen table...")
|
125 |
+
|
126 |
+
# load the tokenizer
|
127 |
+
tokenizer = load_tokenizer(args)
|
128 |
+
|
129 |
+
def generate_prefix(example, prefix_length=None, text_col_names=None, tokenizer=None):
|
130 |
+
assert prefix_length is not None, "prefix_length must be specified"
|
131 |
+
assert text_col_names is not None and isinstance(
|
132 |
+
text_col_names, list
|
133 |
+
), "text_col_names must be a list of column names"
|
134 |
+
|
135 |
+
# make a copy of the example
|
136 |
+
example = example.copy()
|
137 |
+
|
138 |
+
tokd_column_data = {}
|
139 |
+
for text_col_name in text_col_names:
|
140 |
+
try:
|
141 |
+
# check that the col exists
|
142 |
+
assert text_col_name in example, f"text_col_name '{text_col_name}' not in example"
|
143 |
+
# check whether the prefix is OOB for this example
|
144 |
+
# NOTE, this logic might not make perfect sense, but it avoids having prefixes that are ragged
|
145 |
+
# which is a better quality when measuring @ idx_T
|
146 |
+
|
147 |
+
# tokenize first because we can't rely on the length col existing
|
148 |
+
example = tokenize_only(
|
149 |
+
example,
|
150 |
+
input_col_name=text_col_name,
|
151 |
+
hf_model_name=args.model_name_or_path,
|
152 |
+
tokenizer=tokenizer,
|
153 |
+
model_max_length=args.model_max_length,
|
154 |
+
)
|
155 |
+
raw_inputs = example.pop("input_ids")
|
156 |
+
|
157 |
+
if not (prefix_length <= raw_inputs.shape[1]):
|
158 |
+
if args.verbose:
|
159 |
+
print(
|
160 |
+
f"Skipping prefix generation for col {text_col_name} because prefix_length"
|
161 |
+
f" {prefix_length} is OOB for this example (orig length={raw_inputs.shape[1]})."
|
162 |
+
)
|
163 |
+
continue
|
164 |
+
|
165 |
+
# else slice the inputs to the prefix length
|
166 |
+
inputs = raw_inputs[:, : prefix_length + 1]
|
167 |
+
prefix_len = inputs.shape[1]
|
168 |
+
|
169 |
+
# decode the prefix
|
170 |
+
decoded_prefix = tokenizer.decode(inputs[0], skip_special_tokens=True)
|
171 |
+
# store the prefix and it's length
|
172 |
+
tokd_column_data.update(
|
173 |
+
{
|
174 |
+
f"{text_col_name}": decoded_prefix,
|
175 |
+
f"{text_col_name}_length": prefix_len,
|
176 |
+
}
|
177 |
+
)
|
178 |
+
except Exception as e:
|
179 |
+
if args.verbose:
|
180 |
+
print(
|
181 |
+
f"Failed to generate prefix of len {prefix_length} for example idx={example['idx']}\n"
|
182 |
+
f"Should either be becuase the col doesnt exist, or the prefix is OOB for this col in this example."
|
183 |
+
)
|
184 |
+
print(f"Exception: {e}")
|
185 |
+
if text_col_name not in tokd_column_data:
|
186 |
+
tokd_column_data.update({f"{text_col_name}": None, f"{text_col_name}_length": None})
|
187 |
+
|
188 |
+
# add the prefix_len to the example
|
189 |
+
# then add the prefixes to the example
|
190 |
+
example.update({"prefix_length": prefix_length})
|
191 |
+
example.update(tokd_column_data)
|
192 |
+
return example
|
193 |
+
|
194 |
+
# if max_prefix_length is not specified, use the max length for the gen table
|
195 |
+
if args.max_prefix_length is None:
|
196 |
+
# args.max_prefix_length = args.model_max_length
|
197 |
+
args.max_prefix_length = args.max_new_tokens
|
198 |
+
|
199 |
+
# get the maximum length out of the ["baseline_completion_length", "no_wm_output_length", "w_wm_output_length", "w_wm_output_attacked_length"]
|
200 |
+
# found in the gen table
|
201 |
+
max_gen_table_output_length = max(
|
202 |
+
[
|
203 |
+
ex["baseline_completion_length"]
|
204 |
+
for ex in gen_table_ds
|
205 |
+
if "baseline_completion_length" in ex
|
206 |
+
]
|
207 |
+
+ [ex["no_wm_output_length"] for ex in gen_table_ds if "no_wm_output_length" in ex]
|
208 |
+
+ [ex["w_wm_output_length"] for ex in gen_table_ds if "w_wm_output_length" in ex]
|
209 |
+
+ [
|
210 |
+
ex["w_wm_output_attacked_length"]
|
211 |
+
for ex in gen_table_ds
|
212 |
+
if "w_wm_output_attacked_length" in ex
|
213 |
+
]
|
214 |
+
)
|
215 |
+
|
216 |
+
args.max_prefix_length = min(args.max_prefix_length, max_gen_table_output_length)
|
217 |
+
|
218 |
+
# round down to the nearest multiple of prefix_stride
|
219 |
+
last_multiple = args.max_prefix_length - (args.max_prefix_length % args.prefix_stride)
|
220 |
+
prefix_lengths = list(
|
221 |
+
range(args.prefix_stride, last_multiple + args.prefix_stride, args.prefix_stride)
|
222 |
+
)
|
223 |
+
# if missing the largest prefix length, add it
|
224 |
+
if prefix_lengths[-1] != args.max_prefix_length:
|
225 |
+
prefix_lengths.append(args.max_prefix_length)
|
226 |
+
|
227 |
+
if args.max_prefix_length > prefix_lengths[-1]:
|
228 |
+
print(
|
229 |
+
f"WARNING: max_prefix_length {args.max_prefix_length} is larger than the last prefix length {prefix_lengths[-1]} "
|
230 |
+
f"as computed by prefix_stride {args.prefix_stride} multiples up to the longest prefix length in the gen table: "
|
231 |
+
f"{max_gen_table_output_length}."
|
232 |
+
)
|
233 |
+
|
234 |
+
# store the prefix lengths
|
235 |
+
args.prefix_lengths = prefix_lengths
|
236 |
+
print(prefix_lengths)
|
237 |
+
|
238 |
+
###########################################################################
|
239 |
+
# Create output dir if it doesn't exist, and warn if it contains metric file
|
240 |
+
# we do this here because we need the prefix list
|
241 |
+
###########################################################################
|
242 |
+
# gen_table_prefixes_path = f"{args.output_dir}/gen_table_prefixes.jsonl"
|
243 |
+
# gen_table_prefixes_meta_path = f"{args.output_dir}/gen_table_prefixes_meta.json"
|
244 |
+
# making these the same as normal data so they can be used in the same way by eval
|
245 |
+
gen_table_prefixes_path = f"{args.output_dir}/gen_table.jsonl"
|
246 |
+
gen_table_prefixes_meta_path = f"{args.output_dir}/gen_table_meta.json"
|
247 |
+
|
248 |
+
if found_attacked_files:
|
249 |
+
gen_table_prefixes_path = f"{args.output_dir}/gen_table_attacked.jsonl"
|
250 |
+
gen_table_prefixes_meta_path = f"{args.output_dir}/gen_table_attacked_meta.json"
|
251 |
+
|
252 |
+
print(f"Output dir for this run: {args.output_dir}")
|
253 |
+
# notify if exists
|
254 |
+
if os.path.exists(args.output_dir):
|
255 |
+
print(f"Output dir for this run already exists!")
|
256 |
+
print(f"Contents: {sorted(os.listdir(args.output_dir))}")
|
257 |
+
# warn if metrics file exists
|
258 |
+
if args.save_per_prefix:
|
259 |
+
for prefix_len in prefix_lengths:
|
260 |
+
prefix_table_path = (
|
261 |
+
f"{gen_table_prefixes_path.replace('.jsonl','')}_{prefix_len}.jsonl"
|
262 |
+
)
|
263 |
+
if os.path.exists(prefix_table_path):
|
264 |
+
if not args.overwrite_output_file:
|
265 |
+
print(
|
266 |
+
f"WARNING: Exiting to avoid overwriting prefix output file. "
|
267 |
+
f"Pass the '--overwrite_output_file' flag to ignore this check."
|
268 |
+
)
|
269 |
+
exit()
|
270 |
+
else:
|
271 |
+
print(
|
272 |
+
f"WARNING: Found existing prefix files at this output dir. "
|
273 |
+
f"Overwriting anyway :/"
|
274 |
+
)
|
275 |
+
|
276 |
+
elif os.path.exists(gen_table_prefixes_path):
|
277 |
+
if not args.overwrite_output_file:
|
278 |
+
print(
|
279 |
+
f"WARNING: Exiting to avoid overwriting prefix output file. "
|
280 |
+
f"Pass the '--overwrite_output_file' flag to ignore this check."
|
281 |
+
)
|
282 |
+
exit()
|
283 |
+
else:
|
284 |
+
print(
|
285 |
+
f"WARNING: Found existing prefix files at this output dir. "
|
286 |
+
f"Overwriting anyway :/"
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
# create the output dir where run artifacts are stored
|
290 |
+
os.makedirs(args.output_dir)
|
291 |
+
|
292 |
+
###########################################################################
|
293 |
+
# Generate the prefixes
|
294 |
+
###########################################################################
|
295 |
+
|
296 |
+
prefix_tables = []
|
297 |
+
gen_table_ds_lst = [ex for ex in gen_table_ds]
|
298 |
+
|
299 |
+
# hacky check to see whether were working with attacked files
|
300 |
+
text_col_names = ["baseline_completion", "no_wm_output", "w_wm_output"]
|
301 |
+
if "w_wm_output_attacked" in gen_table_ds_lst[0]:
|
302 |
+
assert found_attacked_files, (
|
303 |
+
f"found 'w_wm_output_attacked' in the gen table, but apparently we didn't 'load attacked files'?."
|
304 |
+
f"Odd... please check whats going on in the input_dir."
|
305 |
+
)
|
306 |
+
|
307 |
+
text_col_names.append("w_wm_output_attacked")
|
308 |
+
|
309 |
+
for prefix_len in tqdm(prefix_lengths):
|
310 |
+
prefixes_partial = partial(
|
311 |
+
generate_prefix,
|
312 |
+
prefix_length=prefix_len,
|
313 |
+
tokenizer=tokenizer,
|
314 |
+
text_col_names=text_col_names,
|
315 |
+
)
|
316 |
+
gen_table_prefixes = [prefixes_partial(ex) for ex in gen_table_ds_lst]
|
317 |
+
|
318 |
+
# add the prefix dataset to the list of prefix tables
|
319 |
+
prefix_tables.append(Dataset.from_list(gen_table_prefixes))
|
320 |
+
|
321 |
+
# now concat the tables
|
322 |
+
gen_table_prefixes = concatenate_datasets(prefix_tables)
|
323 |
+
|
324 |
+
###########################################################################
|
325 |
+
# Write the metadata and final dataset out to disk in jsonl format
|
326 |
+
# (and optionally save the individual prefix shards)
|
327 |
+
###########################################################################
|
328 |
+
|
329 |
+
# write the metadata
|
330 |
+
write_json(args.__dict__, gen_table_prefixes_meta_path, indent=4)
|
331 |
+
|
332 |
+
# write the dataset
|
333 |
+
if not args.save_per_prefix:
|
334 |
+
write_jsonlines(gen_table_prefixes, gen_table_prefixes_path)
|
335 |
+
else:
|
336 |
+
# save the individual prefix shards
|
337 |
+
for prefix_len in prefix_lengths:
|
338 |
+
prefix_table = gen_table_prefixes.filter(lambda ex: ex["prefix_length"] == prefix_len)
|
339 |
+
prefix_table_path = f"{gen_table_prefixes_path.replace('.jsonl','')}_{prefix_len}.jsonl"
|
340 |
+
write_jsonlines(prefix_table, prefix_table_path)
|
341 |
+
|
342 |
+
|
343 |
+
if __name__ == "__main__":
|
344 |
+
parser = argparse.ArgumentParser(
|
345 |
+
description="Transform jsonl datasets into a broadcasted prefix version."
|
346 |
+
)
|
347 |
+
parser.add_argument(
|
348 |
+
"--model_name_or_path",
|
349 |
+
type=str,
|
350 |
+
help="use to load the tokenizer",
|
351 |
+
)
|
352 |
+
parser.add_argument(
|
353 |
+
"--prefix_stride",
|
354 |
+
type=int,
|
355 |
+
default=10,
|
356 |
+
help="The stride to use when generating prefixes.",
|
357 |
+
)
|
358 |
+
parser.add_argument(
|
359 |
+
"--max_prefix_length",
|
360 |
+
type=int,
|
361 |
+
default=None,
|
362 |
+
help="The maximum prefix length to use when generating prefixes.",
|
363 |
+
)
|
364 |
+
parser.add_argument(
|
365 |
+
"--model_max_length",
|
366 |
+
type=int,
|
367 |
+
default=2048,
|
368 |
+
)
|
369 |
+
parser.add_argument(
|
370 |
+
"--input_dir",
|
371 |
+
type=str,
|
372 |
+
default=None,
|
373 |
+
help="The directory containing the input files.",
|
374 |
+
)
|
375 |
+
parser.add_argument(
|
376 |
+
"--output_dir",
|
377 |
+
type=str,
|
378 |
+
default=None,
|
379 |
+
help=("The directory in which to write out the dataset after creating prefixes. "),
|
380 |
+
)
|
381 |
+
parser.add_argument(
|
382 |
+
"--save_per_prefix",
|
383 |
+
type=str2bool,
|
384 |
+
default=False,
|
385 |
+
help="Whether to save the individual shards of the dataset corresponding to each prefix length.",
|
386 |
+
)
|
387 |
+
parser.add_argument(
|
388 |
+
"--overwrite_output_file",
|
389 |
+
type=str2bool,
|
390 |
+
default=False,
|
391 |
+
help="Whether to overwrite the output file if it already exists.",
|
392 |
+
)
|
393 |
+
parser.add_argument(
|
394 |
+
"--limit_rows",
|
395 |
+
type=int,
|
396 |
+
default=None,
|
397 |
+
help="The number of rows to limit the dataset to. Useful for debugging.",
|
398 |
+
)
|
399 |
+
parser.add_argument(
|
400 |
+
"--verbose",
|
401 |
+
type=str2bool,
|
402 |
+
default=False,
|
403 |
+
help="Whether to print out the indexes for errors as the prefixes are generated.",
|
404 |
+
)
|
405 |
+
parser.add_argument(
|
406 |
+
"--filter_length",
|
407 |
+
action="store_true",
|
408 |
+
default=False,
|
409 |
+
)
|
410 |
+
parser.add_argument(
|
411 |
+
"--max_new_tokens",
|
412 |
+
type=int,
|
413 |
+
default=None,
|
414 |
+
)
|
415 |
+
args = parser.parse_args()
|
416 |
+
|
417 |
+
###########################################################################
|
418 |
+
# Argument validation and conditional setting
|
419 |
+
###########################################################################
|
420 |
+
|
421 |
+
# require output_dir to be specified and different from input_dir
|
422 |
+
assert args.input_dir is not None
|
423 |
+
assert args.output_dir is not None
|
424 |
+
assert args.input_dir != args.output_dir, "input_dir and output_dir must be different"
|
425 |
+
|
426 |
+
# check limit_rows
|
427 |
+
assert (args.limit_rows is None) or (
|
428 |
+
(args.limit_rows > 0) and isinstance(args.limit_rows, int)
|
429 |
+
), "limit_rows must be > 0 or None"
|
430 |
+
|
431 |
+
# check prefix_stride
|
432 |
+
assert (args.prefix_stride > 0) and isinstance(
|
433 |
+
args.prefix_stride, int
|
434 |
+
), "prefix_stride must be > 0"
|
435 |
+
|
436 |
+
main(args)
|
lm-watermarking-main/watermark_reliability_release/detectgpt/debug.sh
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
#SBATCH --partition tron
|
4 |
+
|
5 |
+
#SBATCH --gres=gpu:rtxa6000:1
|
6 |
+
|
7 |
+
#SBATCH --ntasks=4
|
8 |
+
|
9 |
+
#SBATCH --mem=32G
|
10 |
+
|
11 |
+
#SBATCH --account=nexus
|
12 |
+
|
13 |
+
#SBATCH --qos=default
|
14 |
+
|
15 |
+
#SBATCH --time=48:00:00
|
16 |
+
|
17 |
+
#SBATCH --array=0-1
|
18 |
+
|
19 |
+
#SBATCH --output=slurm_logs/%A_%a.out
|
20 |
+
|
21 |
+
#SBATCH --job-name=run-detect
|
22 |
+
|
23 |
+
source ~/.bashrc
|
24 |
+
conda activate watermarking-dev
|
25 |
+
|
26 |
+
OUTPUT_DIR=/cmlscratch/manlis/test/watermarking-root/input/new_runs
|
27 |
+
|
28 |
+
# model_name="facebook/opt-1.3b"
|
29 |
+
# data_path="/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_200_opt1_3b_evaluation/gen_table_w_metrics.jsonl"
|
30 |
+
|
31 |
+
model_name='facebook/opt-6.7b'
|
32 |
+
data_path='/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_1000_evaluation/gen_table_w_metrics.jsonl'
|
33 |
+
|
34 |
+
mask_model="t5-3b"
|
35 |
+
|
36 |
+
# token_len=200
|
37 |
+
chunk_size=32
|
38 |
+
pct=0.3
|
39 |
+
split="no_wm"
|
40 |
+
|
41 |
+
textlen=600
|
42 |
+
|
43 |
+
# python detectgpt_main.py \
|
44 |
+
# --n_perturbation_list="10,100" \
|
45 |
+
# --do_chunk \
|
46 |
+
# --base_model_name=${model_name} \
|
47 |
+
# --mask_filling_model_name=${mask_model} \
|
48 |
+
# --data_path=/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_${textlen}_evaluation/gen_table_w_metrics.jsonl \
|
49 |
+
# --token_len=${textlen} \
|
50 |
+
# --pct_words_masked=${pct} \
|
51 |
+
# --chunk_size=${chunk_size} \
|
52 |
+
# --data_split=${split};
|
53 |
+
|
54 |
+
declare -a commands
|
55 |
+
|
56 |
+
for textlen in 600 1000;
|
57 |
+
do
|
58 |
+
commands+=( "python detectgpt_main.py \
|
59 |
+
--n_perturbation_list="10,100" \
|
60 |
+
--do_chunk \
|
61 |
+
--base_model_name=${model_name} \
|
62 |
+
--mask_filling_model_name=${mask_model} \
|
63 |
+
--data_path=/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_${textlen}_evaluation/gen_table_w_metrics.jsonl \
|
64 |
+
--token_len=${textlen} \
|
65 |
+
--pct_words_masked=${pct} \
|
66 |
+
--chunk_size=${chunk_size} \
|
67 |
+
--data_split=${split};" )
|
68 |
+
|
69 |
+
done
|
70 |
+
|
71 |
+
bash -c "${commands[${SLURM_ARRAY_TASK_ID}]}"
|
72 |
+
|
73 |
+
# --data_path=/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_${textlen}_evaluation/gen_table_w_metrics.jsonl \
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
|
lm-watermarking-main/watermark_reliability_release/detectgpt/detectgpt_main.py
ADDED
@@ -0,0 +1,807 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Basic imports
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import re
|
5 |
+
import functools
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
from statistics import mean
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import pandas as pd
|
12 |
+
import torch
|
13 |
+
|
14 |
+
import matplotlib.pyplot as plt
|
15 |
+
from matplotlib import rc
|
16 |
+
|
17 |
+
rc("font", **{"family": "serif", "serif": ["Computer Modern"]})
|
18 |
+
rc("text", usetex=True)
|
19 |
+
from sklearn.metrics import roc_curve, precision_recall_curve, auc
|
20 |
+
|
21 |
+
import cmasher as cmr
|
22 |
+
|
23 |
+
# ### Load the processed dataset/frame
|
24 |
+
|
25 |
+
|
26 |
+
import sys
|
27 |
+
|
28 |
+
sys.path.insert(0, "..")
|
29 |
+
|
30 |
+
from datasets import Dataset
|
31 |
+
from utils.io import read_jsonlines, load_jsonlines
|
32 |
+
|
33 |
+
import transformers
|
34 |
+
|
35 |
+
# some file i/o helpers
|
36 |
+
from utils.io import write_jsonlines, write_json
|
37 |
+
|
38 |
+
INPUT_DIR = "/cmlscratch/manlis/test/watermarking-root/input"
|
39 |
+
OUTPUT_DIR = "/cmlscratch/manlis/test/watermarking-root/output"
|
40 |
+
|
41 |
+
|
42 |
+
# 15 colorblind-friendly colors
|
43 |
+
COLORS = [
|
44 |
+
"#0072B2",
|
45 |
+
"#009E73",
|
46 |
+
"#D55E00",
|
47 |
+
"#CC79A7",
|
48 |
+
"#F0E442",
|
49 |
+
"#56B4E9",
|
50 |
+
"#E69F00",
|
51 |
+
"#000000",
|
52 |
+
"#0072B2",
|
53 |
+
"#009E73",
|
54 |
+
"#D55E00",
|
55 |
+
"#CC79A7",
|
56 |
+
"#F0E442",
|
57 |
+
"#56B4E9",
|
58 |
+
"#E69F00",
|
59 |
+
]
|
60 |
+
|
61 |
+
|
62 |
+
def tokenize_and_mask(
|
63 |
+
text, span_length, pct, ceil_pct=False, buffer_size=1, mask_string="<<<mask>>>"
|
64 |
+
):
|
65 |
+
if isinstance(text, str):
|
66 |
+
tokens = text.split(" ")
|
67 |
+
else:
|
68 |
+
tokens = text
|
69 |
+
mask_string = mask_string
|
70 |
+
|
71 |
+
n_spans = pct * len(tokens) / (span_length + buffer_size * 2)
|
72 |
+
if ceil_pct:
|
73 |
+
n_spans = np.ceil(n_spans)
|
74 |
+
n_spans = int(n_spans)
|
75 |
+
|
76 |
+
n_masks = 0
|
77 |
+
while n_masks < n_spans:
|
78 |
+
start = np.random.randint(0, len(tokens) - span_length)
|
79 |
+
end = start + span_length
|
80 |
+
search_start = max(0, start - buffer_size)
|
81 |
+
search_end = min(len(tokens), end + buffer_size)
|
82 |
+
if mask_string not in tokens[search_start:search_end]:
|
83 |
+
tokens[start:end] = [mask_string]
|
84 |
+
n_masks += 1
|
85 |
+
|
86 |
+
# replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
|
87 |
+
num_filled = 0
|
88 |
+
for idx, token in enumerate(tokens):
|
89 |
+
if token == mask_string:
|
90 |
+
tokens[idx] = f"<extra_id_{num_filled}>"
|
91 |
+
num_filled += 1
|
92 |
+
assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"
|
93 |
+
text = " ".join(tokens)
|
94 |
+
return text
|
95 |
+
|
96 |
+
|
97 |
+
def tokenize_and_mask_glm(
|
98 |
+
text, span_length, pct, ceil_pct=False, buffer_size=1, mask_string="[MASK]"
|
99 |
+
):
|
100 |
+
tokens = text.split(" ")
|
101 |
+
mask_string = mask_string
|
102 |
+
|
103 |
+
n_spans = pct * len(tokens) / (span_length + buffer_size * 2)
|
104 |
+
if ceil_pct:
|
105 |
+
n_spans = np.ceil(n_spans)
|
106 |
+
n_spans = int(n_spans)
|
107 |
+
|
108 |
+
n_masks = 0
|
109 |
+
while n_masks < n_spans:
|
110 |
+
start = np.random.randint(0, len(tokens) - span_length)
|
111 |
+
end = start + span_length
|
112 |
+
search_start = max(0, start - buffer_size)
|
113 |
+
search_end = min(len(tokens), end + buffer_size)
|
114 |
+
if mask_string not in tokens[search_start:search_end]:
|
115 |
+
tokens[start:end] = [mask_string]
|
116 |
+
n_masks += 1
|
117 |
+
|
118 |
+
text = " ".join(tokens)
|
119 |
+
return text
|
120 |
+
|
121 |
+
|
122 |
+
def count_masks(texts):
|
123 |
+
return [len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts]
|
124 |
+
|
125 |
+
|
126 |
+
# replace each masked span with a sample from T5 mask_model
|
127 |
+
def replace_masks(texts):
|
128 |
+
n_expected = count_masks(texts)
|
129 |
+
stop_id = mask_tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0]
|
130 |
+
tokens = mask_tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to("cuda")
|
131 |
+
outputs = mask_model.generate(
|
132 |
+
**tokens,
|
133 |
+
max_length=mask_tokenizer.model_max_length,
|
134 |
+
do_sample=True,
|
135 |
+
top_p=1.0,
|
136 |
+
num_return_sequences=1,
|
137 |
+
eos_token_id=stop_id,
|
138 |
+
)
|
139 |
+
# outputs = mask_model.generate(**tokens, max_length=mask_tokenizer.model_max_length, do_sample=True, top_p=1.0, num_return_sequences=1, eos_token_id=stop_id)
|
140 |
+
return mask_tokenizer.batch_decode(outputs, skip_special_tokens=False)
|
141 |
+
|
142 |
+
|
143 |
+
def replace_masks_glm(texts):
|
144 |
+
# n_expected = [len([x for x in text.split() if x == '[MASK]']) for text in texts]
|
145 |
+
tokens = mask_tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
|
146 |
+
tokens = mask_tokenizer.build_inputs_for_generation(
|
147 |
+
tokens, max_gen_length=mask_tokenizer.model_max_length
|
148 |
+
).to("cuda")
|
149 |
+
outputs = mask_model.generate(
|
150 |
+
**tokens,
|
151 |
+
max_length=mask_tokenizer.model_max_length,
|
152 |
+
do_sample=True,
|
153 |
+
top_p=1.0,
|
154 |
+
num_return_sequences=1,
|
155 |
+
eos_token_id=mask_tokenizer.eop_token_id,
|
156 |
+
)
|
157 |
+
return mask_tokenizer.batch_decode(outputs, skip_special_tokens=False)
|
158 |
+
|
159 |
+
|
160 |
+
def extract_fills(texts):
|
161 |
+
# remove <pad> from beginning of each text
|
162 |
+
texts = [x.replace("<pad>", "").replace("</s>", "").strip() for x in texts]
|
163 |
+
|
164 |
+
# return the text in between each matched mask token
|
165 |
+
extracted_fills = [pattern.split(x)[1:-1] for x in texts]
|
166 |
+
|
167 |
+
# remove whitespace around each fill
|
168 |
+
extracted_fills = [[y.strip() for y in x] for x in extracted_fills]
|
169 |
+
|
170 |
+
return extracted_fills
|
171 |
+
|
172 |
+
|
173 |
+
def apply_extracted_fills(masked_texts, extracted_fills):
|
174 |
+
# split masked text into tokens, only splitting on spaces (not newlines)
|
175 |
+
tokens = [x.split(" ") for x in masked_texts]
|
176 |
+
|
177 |
+
n_expected = count_masks(masked_texts)
|
178 |
+
|
179 |
+
# replace each mask token with the corresponding fill
|
180 |
+
for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)):
|
181 |
+
if len(fills) < n:
|
182 |
+
tokens[idx] = []
|
183 |
+
else:
|
184 |
+
for fill_idx in range(n):
|
185 |
+
text[text.index(f"<extra_id_{fill_idx}>")] = fills[fill_idx]
|
186 |
+
|
187 |
+
# join tokens back into text
|
188 |
+
texts = [" ".join(x) for x in tokens]
|
189 |
+
return texts
|
190 |
+
|
191 |
+
|
192 |
+
def perturb_texts_(
|
193 |
+
texts, span_length, pct, ceil_pct=False, mask_filling_model_name="t5-3b", do_chunk=False
|
194 |
+
):
|
195 |
+
if "t5" in mask_filling_model_name:
|
196 |
+
if do_chunk:
|
197 |
+
texts = [x.split(" ") for x in texts]
|
198 |
+
## chunk long texts
|
199 |
+
if max([len(x) for x in texts]) > 600:
|
200 |
+
text_pieces = [
|
201 |
+
[t[: len(t) // 3] for t in texts],
|
202 |
+
[t[len(t) // 3 : 2 * len(t) // 3] for t in texts],
|
203 |
+
[t[2 * len(t) // 3 :] for t in texts],
|
204 |
+
]
|
205 |
+
else:
|
206 |
+
text_pieces = [[t[: len(t) // 2] for t in texts], [t[len(t) // 2 :] for t in texts]]
|
207 |
+
|
208 |
+
perturbed_pieces = []
|
209 |
+
for pieces in text_pieces:
|
210 |
+
masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in pieces]
|
211 |
+
raw_fills = replace_masks(masked_texts)
|
212 |
+
extracted_fills = extract_fills(raw_fills)
|
213 |
+
perturbed_pieces.append(apply_extracted_fills(masked_texts, extracted_fills))
|
214 |
+
## put the chunks together
|
215 |
+
perturbed_texts = []
|
216 |
+
for i in range(len(texts)):
|
217 |
+
perturbed_texts.append(" ".join([p[i] for p in perturbed_pieces]))
|
218 |
+
else:
|
219 |
+
masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts]
|
220 |
+
raw_fills = replace_masks(masked_texts)
|
221 |
+
extracted_fills = extract_fills(raw_fills)
|
222 |
+
perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
|
223 |
+
# elif 'glm' in mask_filling_model_name:
|
224 |
+
# masked_texts = [tokenize_and_mask_glm(x, span_length, pct, ceil_pct) for x in texts]
|
225 |
+
# raw_fills = replace_masks_glm(masked_texts)
|
226 |
+
# extracted_fills = extract_fills(raw_fills)
|
227 |
+
# perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
|
228 |
+
|
229 |
+
# Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again
|
230 |
+
attempts = 1
|
231 |
+
while "" in perturbed_texts:
|
232 |
+
idxs = [idx for idx, x in enumerate(perturbed_texts) if x == ""]
|
233 |
+
print(f"WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}].")
|
234 |
+
if do_chunk:
|
235 |
+
new_perturbed_pieces = []
|
236 |
+
for pieces in text_pieces:
|
237 |
+
masked_texts = [
|
238 |
+
tokenize_and_mask(x, span_length, pct, ceil_pct)
|
239 |
+
for idx, x in enumerate(pieces)
|
240 |
+
if idx in idxs
|
241 |
+
]
|
242 |
+
raw_fills = replace_masks(masked_texts)
|
243 |
+
extracted_fills = extract_fills(raw_fills)
|
244 |
+
new_perturbed_pieces.append(apply_extracted_fills(masked_texts, extracted_fills))
|
245 |
+
new_perturbed_texts = []
|
246 |
+
for i in range(len(texts)):
|
247 |
+
new_perturbed_texts.append(" ".join([p[i] for p in new_perturbed_pieces]))
|
248 |
+
else:
|
249 |
+
masked_texts = [
|
250 |
+
tokenize_and_mask(x, span_length, pct, ceil_pct)
|
251 |
+
for idx, x in enumerate(texts)
|
252 |
+
if idx in idxs
|
253 |
+
]
|
254 |
+
raw_fills = replace_masks(masked_texts)
|
255 |
+
extracted_fills = extract_fills(raw_fills)
|
256 |
+
new_perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
|
257 |
+
for idx, x in zip(idxs, new_perturbed_texts):
|
258 |
+
perturbed_texts[idx] = x
|
259 |
+
attempts += 1
|
260 |
+
|
261 |
+
return perturbed_texts
|
262 |
+
|
263 |
+
|
264 |
+
def perturb_texts(
|
265 |
+
texts, span_length, pct, mask_filling_model_name, ceil_pct=False, chunk_size=20, do_chunk=False
|
266 |
+
):
|
267 |
+
chunk_size = chunk_size
|
268 |
+
if "11b" in mask_filling_model_name:
|
269 |
+
chunk_size //= 2
|
270 |
+
|
271 |
+
outputs = []
|
272 |
+
for i in tqdm(range(0, len(texts), chunk_size), desc="Applying perturbations"):
|
273 |
+
outputs.extend(
|
274 |
+
perturb_texts_(
|
275 |
+
texts[i : i + chunk_size],
|
276 |
+
span_length,
|
277 |
+
pct,
|
278 |
+
ceil_pct=ceil_pct,
|
279 |
+
mask_filling_model_name=mask_filling_model_name,
|
280 |
+
do_chunk=do_chunk,
|
281 |
+
)
|
282 |
+
)
|
283 |
+
return outputs
|
284 |
+
|
285 |
+
|
286 |
+
# Get the log likelihood of each text under the base_model
|
287 |
+
def get_ll(text):
|
288 |
+
with torch.no_grad():
|
289 |
+
tokenized = base_tokenizer(text, return_tensors="pt").to("cuda")
|
290 |
+
labels = tokenized.input_ids
|
291 |
+
return -base_model(**tokenized, labels=labels).loss.item()
|
292 |
+
|
293 |
+
|
294 |
+
def get_lls(texts):
|
295 |
+
return [get_ll(text) for text in texts]
|
296 |
+
|
297 |
+
|
298 |
+
def get_perturbation_results(
|
299 |
+
span_length=10,
|
300 |
+
chunk_size=50,
|
301 |
+
n_perturbations=1,
|
302 |
+
n_perturbation_rounds=1,
|
303 |
+
pct_words_masked=0.3,
|
304 |
+
data_split="wm",
|
305 |
+
mask_filling_model_name="t5-3b",
|
306 |
+
do_chunk=False,
|
307 |
+
save_path="/cmlscratch/manlis/test/watermarking-root/output/detect-gpt",
|
308 |
+
):
|
309 |
+
## check if pre-computed results exist
|
310 |
+
if os.path.isfile(os.path.join(save_path, f"perturbed_raw_texts_{n_perturbations}.jsonl")):
|
311 |
+
results = load_jsonlines(
|
312 |
+
os.path.join(save_path, f"perturbed_raw_texts_{n_perturbations}.jsonl")
|
313 |
+
)
|
314 |
+
else:
|
315 |
+
base_model.cpu()
|
316 |
+
mask_model.cuda()
|
317 |
+
|
318 |
+
torch.manual_seed(0)
|
319 |
+
np.random.seed(0)
|
320 |
+
|
321 |
+
results = []
|
322 |
+
original_text = df["baseline_completion"]
|
323 |
+
if data_split == "wm":
|
324 |
+
sampled_text = df["w_wm_output"]
|
325 |
+
elif data_split == "no_wm":
|
326 |
+
sampled_text = df["no_wm_output"]
|
327 |
+
elif data_split == "no_wm_paraphrase":
|
328 |
+
sampled_text = df["w_wm_output_attacked"]
|
329 |
+
else:
|
330 |
+
raise NotImplementedError(f"Unknown split: {data_split}")
|
331 |
+
|
332 |
+
perturb_fn = functools.partial(
|
333 |
+
perturb_texts,
|
334 |
+
span_length=span_length,
|
335 |
+
pct=pct_words_masked,
|
336 |
+
mask_filling_model_name=mask_filling_model_name,
|
337 |
+
chunk_size=chunk_size,
|
338 |
+
do_chunk=do_chunk,
|
339 |
+
)
|
340 |
+
|
341 |
+
p_sampled_text = perturb_fn([x for x in sampled_text for _ in range(n_perturbations)])
|
342 |
+
p_original_text = perturb_fn([x for x in original_text for _ in range(n_perturbations)])
|
343 |
+
for _ in range(n_perturbation_rounds - 1):
|
344 |
+
try:
|
345 |
+
p_sampled_text, p_original_text = perturb_fn(p_sampled_text), perturb_fn(
|
346 |
+
p_original_text
|
347 |
+
)
|
348 |
+
except AssertionError:
|
349 |
+
break
|
350 |
+
|
351 |
+
assert (
|
352 |
+
len(p_sampled_text) == len(sampled_text) * n_perturbations
|
353 |
+
), f"Expected {len(sampled_text) * n_perturbations} perturbed samples, got {len(p_sampled_text)}"
|
354 |
+
assert (
|
355 |
+
len(p_original_text) == len(original_text) * n_perturbations
|
356 |
+
), f"Expected {len(original_text) * n_perturbations} perturbed samples, got {len(p_original_text)}"
|
357 |
+
|
358 |
+
for i, idx in enumerate(original_text.index):
|
359 |
+
results.append(
|
360 |
+
{
|
361 |
+
"original": original_text[idx],
|
362 |
+
"sampled": sampled_text[idx],
|
363 |
+
"perturbed_sampled": p_sampled_text[
|
364 |
+
i * n_perturbations : (i + 1) * n_perturbations
|
365 |
+
],
|
366 |
+
"perturbed_original": p_original_text[
|
367 |
+
i * n_perturbations : (i + 1) * n_perturbations
|
368 |
+
],
|
369 |
+
}
|
370 |
+
)
|
371 |
+
|
372 |
+
## save perturbed samples in case job got preempted
|
373 |
+
write_jsonlines(
|
374 |
+
results, os.path.join(save_path, f"perturbed_raw_texts_{n_perturbations}.jsonl")
|
375 |
+
)
|
376 |
+
|
377 |
+
mask_model.cpu()
|
378 |
+
base_model.cuda()
|
379 |
+
|
380 |
+
for res in tqdm(results, desc="Computing log likelihoods"):
|
381 |
+
p_sampled_ll = get_lls(res["perturbed_sampled"])
|
382 |
+
p_original_ll = get_lls(res["perturbed_original"])
|
383 |
+
res["original_ll"] = get_ll(res["original"])
|
384 |
+
res["sampled_ll"] = get_ll(res["sampled"])
|
385 |
+
res["all_perturbed_sampled_ll"] = p_sampled_ll
|
386 |
+
res["all_perturbed_original_ll"] = p_original_ll
|
387 |
+
res["perturbed_sampled_ll"] = np.mean(p_sampled_ll)
|
388 |
+
res["perturbed_original_ll"] = np.mean(p_original_ll)
|
389 |
+
res["perturbed_sampled_ll_std"] = np.std(p_sampled_ll) if len(p_sampled_ll) > 1 else 1
|
390 |
+
res["perturbed_original_ll_std"] = np.std(p_original_ll) if len(p_original_ll) > 1 else 1
|
391 |
+
|
392 |
+
return results
|
393 |
+
|
394 |
+
|
395 |
+
def get_roc_metrics(real_preds, sample_preds):
|
396 |
+
fpr, tpr, _ = roc_curve(
|
397 |
+
[0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds
|
398 |
+
)
|
399 |
+
roc_auc = auc(fpr, tpr)
|
400 |
+
return fpr.tolist(), tpr.tolist(), float(roc_auc)
|
401 |
+
|
402 |
+
|
403 |
+
def get_precision_recall_metrics(real_preds, sample_preds):
|
404 |
+
precision, recall, _ = precision_recall_curve(
|
405 |
+
[0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds
|
406 |
+
)
|
407 |
+
pr_auc = auc(recall, precision)
|
408 |
+
return precision.tolist(), recall.tolist(), float(pr_auc)
|
409 |
+
|
410 |
+
|
411 |
+
def run_perturbation_experiment(
|
412 |
+
results, criterion, span_length=10, n_perturbations=1, pct_words_masked=0.3, n_samples=500
|
413 |
+
):
|
414 |
+
# compute diffs with perturbed
|
415 |
+
predictions = {"real": [], "samples": []}
|
416 |
+
for res in results:
|
417 |
+
if criterion == "d":
|
418 |
+
predictions["real"].append(res["original_ll"] - res["perturbed_original_ll"])
|
419 |
+
predictions["samples"].append(res["sampled_ll"] - res["perturbed_sampled_ll"])
|
420 |
+
elif criterion == "z":
|
421 |
+
if res["perturbed_original_ll_std"] == 0:
|
422 |
+
res["perturbed_original_ll_std"] = 1
|
423 |
+
print("WARNING: std of perturbed original is 0, setting to 1")
|
424 |
+
print(
|
425 |
+
f"Number of unique perturbed original texts: {len(set(res['perturbed_original']))}"
|
426 |
+
)
|
427 |
+
print(f"Original text: {res['original']}")
|
428 |
+
if res["perturbed_sampled_ll_std"] == 0:
|
429 |
+
res["perturbed_sampled_ll_std"] = 1
|
430 |
+
print("WARNING: std of perturbed sampled is 0, setting to 1")
|
431 |
+
print(
|
432 |
+
f"Number of unique perturbed sampled texts: {len(set(res['perturbed_sampled']))}"
|
433 |
+
)
|
434 |
+
print(f"Sampled text: {res['sampled']}")
|
435 |
+
predictions["real"].append(
|
436 |
+
(res["original_ll"] - res["perturbed_original_ll"])
|
437 |
+
/ res["perturbed_original_ll_std"]
|
438 |
+
)
|
439 |
+
predictions["samples"].append(
|
440 |
+
(res["sampled_ll"] - res["perturbed_sampled_ll"]) / res["perturbed_sampled_ll_std"]
|
441 |
+
)
|
442 |
+
|
443 |
+
fpr, tpr, roc_auc = get_roc_metrics(predictions["real"], predictions["samples"])
|
444 |
+
p, r, pr_auc = get_precision_recall_metrics(predictions["real"], predictions["samples"])
|
445 |
+
name = f"perturbation_{n_perturbations}_{criterion}"
|
446 |
+
print(f"{name} ROC AUC: {roc_auc}, PR AUC: {pr_auc}")
|
447 |
+
return {
|
448 |
+
"name": name,
|
449 |
+
"predictions": predictions,
|
450 |
+
"info": {
|
451 |
+
"pct_words_masked": pct_words_masked,
|
452 |
+
"span_length": span_length,
|
453 |
+
"n_perturbations": n_perturbations,
|
454 |
+
"n_samples": n_samples,
|
455 |
+
},
|
456 |
+
"raw_results": results,
|
457 |
+
"metrics": {
|
458 |
+
"roc_auc": roc_auc,
|
459 |
+
"fpr": fpr,
|
460 |
+
"tpr": tpr,
|
461 |
+
},
|
462 |
+
"pr_metrics": {
|
463 |
+
"pr_auc": pr_auc,
|
464 |
+
"precision": p,
|
465 |
+
"recall": r,
|
466 |
+
},
|
467 |
+
"loss": 1 - pr_auc,
|
468 |
+
}
|
469 |
+
|
470 |
+
|
471 |
+
## DetectGPT Running: get perturnation results
|
472 |
+
import json
|
473 |
+
|
474 |
+
|
475 |
+
# save the ROC curve for each experiment, given a list of output dictionaries, one for each experiment, using colorblind-friendly colors
|
476 |
+
def save_roc_curves(experiments, save_folder, args):
|
477 |
+
# first, clear plt
|
478 |
+
plt.clf()
|
479 |
+
|
480 |
+
for experiment, color in zip(experiments, COLORS):
|
481 |
+
metrics = experiment["metrics"]
|
482 |
+
plt.plot(
|
483 |
+
metrics["fpr"],
|
484 |
+
metrics["tpr"],
|
485 |
+
label=f"{experiment['name']}, roc_auc={metrics['roc_auc']:.3f}",
|
486 |
+
color=color,
|
487 |
+
)
|
488 |
+
# print roc_auc for this experiment
|
489 |
+
print(f"{experiment['name']} roc_auc: {metrics['roc_auc']:.3f}")
|
490 |
+
plt.plot([0, 1], [0, 1], color="black", lw=2, linestyle="--")
|
491 |
+
plt.xlim([0.0, 1.0])
|
492 |
+
plt.ylim([0.0, 1.05])
|
493 |
+
plt.xlabel("False Positive Rate")
|
494 |
+
plt.ylabel("True Positive Rate")
|
495 |
+
plt.title(f"ROC Curves ({args.base_model_name} - {args.mask_filling_model_name})")
|
496 |
+
plt.legend(loc="lower right", fontsize=6)
|
497 |
+
plt.savefig(f"{save_folder}/roc_curves.png")
|
498 |
+
|
499 |
+
|
500 |
+
def save_roc_curves_w_ztest(experiments, zscore_sample, zscore_original, save_folder, args):
|
501 |
+
# first, clear plt
|
502 |
+
plt.clf()
|
503 |
+
|
504 |
+
## make ztest ROC curve
|
505 |
+
positive_preds = np.array(zscore_sample)
|
506 |
+
negative_preds = np.array(zscore_original)
|
507 |
+
positive_labels = np.ones_like(positive_preds, dtype=int)
|
508 |
+
negative_labels = np.zeros_like(negative_preds, dtype=int)
|
509 |
+
|
510 |
+
all_preds = np.concatenate((positive_preds, negative_preds))
|
511 |
+
all_labels = np.concatenate((positive_labels, negative_labels))
|
512 |
+
|
513 |
+
tpr_z, fpr_z, _ = roc_curve(all_labels, all_preds)
|
514 |
+
roc_auc_z = auc(tpr_z, fpr_z)
|
515 |
+
plt.plot(tpr_z, fpr_z, label=f"z-score test, roc_auc={roc_auc_z:.3f}")
|
516 |
+
print(f"ztest roc_auc: {roc_auc_z:.3f}")
|
517 |
+
|
518 |
+
for experiment, color in zip(experiments, COLORS):
|
519 |
+
metrics = experiment["metrics"]
|
520 |
+
plt.plot(
|
521 |
+
metrics["fpr"],
|
522 |
+
metrics["tpr"],
|
523 |
+
label=f"{experiment['name']}, roc_auc={metrics['roc_auc']:.3f}",
|
524 |
+
color=color,
|
525 |
+
)
|
526 |
+
# print roc_auc for this experiment
|
527 |
+
print(f"{experiment['name']} roc_auc: {metrics['roc_auc']:.3f}")
|
528 |
+
plt.plot([0, 1], [0, 1], color="black", lw=2, linestyle="--")
|
529 |
+
plt.xlim([0.0, 1.0])
|
530 |
+
plt.ylim([0.0, 1.05])
|
531 |
+
plt.xlabel("False Positive Rate")
|
532 |
+
plt.ylabel("True Positive Rate")
|
533 |
+
plt.title(f"ROC Curves ({args.base_model_name} - {args.mask_filling_model_name})")
|
534 |
+
plt.legend(loc="lower right", fontsize=6)
|
535 |
+
plt.savefig(f"{save_folder}/roc_curves_w_ztests.png")
|
536 |
+
|
537 |
+
|
538 |
+
# save the histogram of log likelihoods in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed
|
539 |
+
def save_ll_histograms(experiments, save_folder):
|
540 |
+
# first, clear plt
|
541 |
+
plt.clf()
|
542 |
+
|
543 |
+
for experiment in experiments:
|
544 |
+
try:
|
545 |
+
results = experiment["raw_results"]
|
546 |
+
# plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
|
547 |
+
plt.figure(figsize=(20, 6))
|
548 |
+
plt.subplot(1, 2, 1)
|
549 |
+
plt.hist([r["sampled_ll"] for r in results], alpha=0.5, bins="auto", label="sampled")
|
550 |
+
plt.hist(
|
551 |
+
[r["perturbed_sampled_ll"] for r in results],
|
552 |
+
alpha=0.5,
|
553 |
+
bins="auto",
|
554 |
+
label="perturbed sampled",
|
555 |
+
)
|
556 |
+
plt.xlabel("log likelihood")
|
557 |
+
plt.ylabel("count")
|
558 |
+
plt.legend(loc="upper right")
|
559 |
+
plt.subplot(1, 2, 2)
|
560 |
+
plt.hist([r["original_ll"] for r in results], alpha=0.5, bins="auto", label="original")
|
561 |
+
plt.hist(
|
562 |
+
[r["perturbed_original_ll"] for r in results],
|
563 |
+
alpha=0.5,
|
564 |
+
bins="auto",
|
565 |
+
label="perturbed original",
|
566 |
+
)
|
567 |
+
plt.xlabel("log likelihood")
|
568 |
+
plt.ylabel("count")
|
569 |
+
plt.legend(loc="upper right")
|
570 |
+
plt.savefig(f"{save_folder}/ll_histograms_{experiment['name']}.png")
|
571 |
+
except:
|
572 |
+
pass
|
573 |
+
|
574 |
+
|
575 |
+
# save the histograms of log likelihood ratios in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed
|
576 |
+
def save_llr_histograms(experiments, save_folder):
|
577 |
+
# first, clear plt
|
578 |
+
plt.clf()
|
579 |
+
|
580 |
+
for experiment in experiments:
|
581 |
+
try:
|
582 |
+
results = experiment["raw_results"]
|
583 |
+
# plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
|
584 |
+
plt.figure(figsize=(20, 6))
|
585 |
+
plt.subplot(1, 2, 1)
|
586 |
+
|
587 |
+
# compute the log likelihood ratio for each result
|
588 |
+
for r in results:
|
589 |
+
r["sampled_llr"] = r["sampled_ll"] - r["perturbed_sampled_ll"]
|
590 |
+
r["original_llr"] = r["original_ll"] - r["perturbed_original_ll"]
|
591 |
+
|
592 |
+
plt.hist([r["sampled_llr"] for r in results], alpha=0.5, bins="auto", label="sampled")
|
593 |
+
plt.hist([r["original_llr"] for r in results], alpha=0.5, bins="auto", label="original")
|
594 |
+
plt.xlabel("log likelihood ratio")
|
595 |
+
plt.ylabel("count")
|
596 |
+
plt.legend(loc="upper right")
|
597 |
+
plt.savefig(f"{save_folder}/llr_histograms_{experiment['name']}.png")
|
598 |
+
except:
|
599 |
+
pass
|
600 |
+
|
601 |
+
|
602 |
+
if __name__ == "__main__":
|
603 |
+
parser = argparse.ArgumentParser(
|
604 |
+
description="Run detect-gpt with watermarked and baseline generations"
|
605 |
+
)
|
606 |
+
parser.add_argument(
|
607 |
+
"--base_model_name",
|
608 |
+
type=str,
|
609 |
+
default="facebook/opt-1.3b",
|
610 |
+
help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
|
611 |
+
)
|
612 |
+
parser.add_argument(
|
613 |
+
"--token_len",
|
614 |
+
type=int,
|
615 |
+
default=200,
|
616 |
+
)
|
617 |
+
parser.add_argument(
|
618 |
+
"--n_samples",
|
619 |
+
type=int,
|
620 |
+
default=500,
|
621 |
+
)
|
622 |
+
parser.add_argument(
|
623 |
+
"--chunk_size",
|
624 |
+
type=int,
|
625 |
+
default=32,
|
626 |
+
)
|
627 |
+
parser.add_argument(
|
628 |
+
"--data_path",
|
629 |
+
type=str,
|
630 |
+
)
|
631 |
+
parser.add_argument("--data_split", type=str, default="wm")
|
632 |
+
parser.add_argument(
|
633 |
+
"--mask_filling_model_name",
|
634 |
+
type=str,
|
635 |
+
default="t5-3b",
|
636 |
+
)
|
637 |
+
parser.add_argument("--n_positions", type=int, default=512)
|
638 |
+
parser.add_argument(
|
639 |
+
"--pct_words_masked",
|
640 |
+
type=float,
|
641 |
+
default=0.3,
|
642 |
+
)
|
643 |
+
parser.add_argument(
|
644 |
+
"--do_chunk",
|
645 |
+
action="store_true",
|
646 |
+
default=False,
|
647 |
+
)
|
648 |
+
parser.add_argument("--filter", type=str, default=None)
|
649 |
+
parser.add_argument("--mask_top_p", type=float, default=1.0)
|
650 |
+
parser.add_argument("--n_perturbation_list", type=str, default="1,10,100")
|
651 |
+
parser.add_argument("--n_perturbation_rounds", type=int, default=1)
|
652 |
+
parser.add_argument("--span_length", type=int, default=2)
|
653 |
+
parser.add_argument("--buffer_size", type=int, default=1)
|
654 |
+
|
655 |
+
args = parser.parse_args()
|
656 |
+
if args.token_len > 300:
|
657 |
+
args.do_chunk = True
|
658 |
+
## load data
|
659 |
+
list_of_dict = load_jsonlines(args.data_path)
|
660 |
+
raw_data = Dataset.from_list(list_of_dict)
|
661 |
+
df = raw_data.to_pandas()
|
662 |
+
## drop samples that are too short
|
663 |
+
original_len = len(df)
|
664 |
+
print(f"Origianl #samples: {original_len}")
|
665 |
+
if args.filter == "length":
|
666 |
+
df = df[
|
667 |
+
(df["baseline_completion_length"] == args.token_len)
|
668 |
+
& (df["no_wm_output_length"] == args.token_len)
|
669 |
+
& (df["w_wm_output_length"] == args.token_len)
|
670 |
+
]
|
671 |
+
print(f" after filtering token length: {len(df)}")
|
672 |
+
if args.filter == "null":
|
673 |
+
try:
|
674 |
+
df = df[
|
675 |
+
(df["w_wm_output_length"].notnull())
|
676 |
+
& (df["w_wm_output_attacked_length"].notnull())
|
677 |
+
& ~(df["w_wm_output_length"] == "")
|
678 |
+
& ~(df["w_wm_output_attacked_length"] == 0)
|
679 |
+
]
|
680 |
+
print(f" after filtering token length: {len(df)}")
|
681 |
+
except:
|
682 |
+
print(
|
683 |
+
"failed to filter null entries, probably because the file does not contain column 'w_wm_output_attacked_length'. "
|
684 |
+
)
|
685 |
+
args.n_samples = len(df)
|
686 |
+
|
687 |
+
## load models
|
688 |
+
int8_kwargs = {}
|
689 |
+
half_kwargs = {}
|
690 |
+
if (
|
691 |
+
"glm" not in args.mask_filling_model_name
|
692 |
+
): # GLM uses an OP that's not supported in BFloat16: "triu_tril_cuda_template" not implemented for 'BFloat16'
|
693 |
+
half_kwargs = dict(torch_dtype=torch.bfloat16)
|
694 |
+
else:
|
695 |
+
half_kwargs = dict(torch_dtype=torch.float16)
|
696 |
+
|
697 |
+
## load the base model (for generation) and base tokenizer
|
698 |
+
optional_tok_kwargs = {}
|
699 |
+
if "facebook/opt-" in args.base_model_name:
|
700 |
+
print("Using non-fast tokenizer for OPT")
|
701 |
+
optional_tok_kwargs["fast"] = False
|
702 |
+
base_model = transformers.AutoModelForCausalLM.from_pretrained(
|
703 |
+
args.base_model_name, **half_kwargs
|
704 |
+
)
|
705 |
+
base_model.eval()
|
706 |
+
|
707 |
+
####### load base tokenizer ########
|
708 |
+
if "llama" in args.base_model_name:
|
709 |
+
from transformers import LlamaTokenizer
|
710 |
+
|
711 |
+
base_tokenizer = LlamaTokenizer.from_pretrained(
|
712 |
+
args.base_model_name, padding_side="left", **optional_tok_kwargs
|
713 |
+
)
|
714 |
+
else:
|
715 |
+
base_tokenizer = transformers.AutoTokenizer.from_pretrained(
|
716 |
+
args.base_model_name, padding_side="left", **optional_tok_kwargs
|
717 |
+
)
|
718 |
+
base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
|
719 |
+
|
720 |
+
print(f"Loading mask filling model {args.mask_filling_model_name}...")
|
721 |
+
mask_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
|
722 |
+
args.mask_filling_model_name,
|
723 |
+
**int8_kwargs,
|
724 |
+
**half_kwargs,
|
725 |
+
trust_remote_code="glm" in args.mask_filling_model_name,
|
726 |
+
)
|
727 |
+
mask_model.eval()
|
728 |
+
|
729 |
+
## mask model max length
|
730 |
+
try:
|
731 |
+
if "glm" in args.mask_filling_model_name:
|
732 |
+
n_positions = mask_model.config.max_sequence_length
|
733 |
+
else:
|
734 |
+
n_positions = mask_model.config.n_positions
|
735 |
+
except AttributeError:
|
736 |
+
n_positions = 512
|
737 |
+
|
738 |
+
# if n_positions < args.token_len:
|
739 |
+
# raise ValueError(f"Mask model cannot handle input longer then {n_positions}. Input token length: {args.token_len}")
|
740 |
+
# preproc_tokenizer = transformers.AutoTokenizer.from_pretrained('t5-small', model_max_length=n_positions)
|
741 |
+
mask_tokenizer = transformers.AutoTokenizer.from_pretrained(
|
742 |
+
args.mask_filling_model_name,
|
743 |
+
model_max_length=n_positions,
|
744 |
+
trust_remote_code="glm" in args.mask_filling_model_name,
|
745 |
+
)
|
746 |
+
mask_model.cpu()
|
747 |
+
|
748 |
+
# perturbing text ops
|
749 |
+
# define regex to match all <extra_id_*> tokens, where * is an integer
|
750 |
+
pattern = re.compile(r"<extra_id_\d+>")
|
751 |
+
|
752 |
+
SAVE_FOLDER = f'{OUTPUT_DIR}/detect-gpt/{os.path.basename(os.path.dirname(args.data_path))}-{args.data_split}-mask{args.mask_filling_model_name}/maskpct{args.pct_words_masked}-{os.path.basename(args.data_path).split(".")[0]}-ns{args.n_samples}'
|
753 |
+
os.makedirs(SAVE_FOLDER, exist_ok=True)
|
754 |
+
|
755 |
+
outputs = []
|
756 |
+
n_perturbation_list = [int(x) for x in args.n_perturbation_list.split(",")]
|
757 |
+
for n_perturbations in n_perturbation_list:
|
758 |
+
perturbation_results = get_perturbation_results(
|
759 |
+
args.span_length,
|
760 |
+
args.chunk_size,
|
761 |
+
n_perturbations,
|
762 |
+
args.n_perturbation_rounds,
|
763 |
+
args.pct_words_masked,
|
764 |
+
args.data_split,
|
765 |
+
args.mask_filling_model_name,
|
766 |
+
args.do_chunk,
|
767 |
+
save_path=SAVE_FOLDER,
|
768 |
+
)
|
769 |
+
for perturbation_mode in ["d", "z"]:
|
770 |
+
output = run_perturbation_experiment(
|
771 |
+
perturbation_results,
|
772 |
+
perturbation_mode,
|
773 |
+
span_length=args.span_length,
|
774 |
+
n_perturbations=n_perturbations,
|
775 |
+
pct_words_masked=args.pct_words_masked,
|
776 |
+
n_samples=args.n_samples,
|
777 |
+
)
|
778 |
+
outputs.append(output)
|
779 |
+
## write columns to the input df
|
780 |
+
df[
|
781 |
+
f"baseline_completion_detectgpt_score_{n_perturbations}_{perturbation_mode}"
|
782 |
+
] = output["predictions"]["real"]
|
783 |
+
df[f"no_wm_output_detectgpt_score_{n_perturbations}_{perturbation_mode}"] = output[
|
784 |
+
"predictions"
|
785 |
+
]["samples"]
|
786 |
+
with open(
|
787 |
+
os.path.join(
|
788 |
+
SAVE_FOLDER, f"perturbation_{n_perturbations}_{perturbation_mode}_results.json"
|
789 |
+
),
|
790 |
+
"w",
|
791 |
+
) as f:
|
792 |
+
json.dump(output, f)
|
793 |
+
|
794 |
+
## save the updated input df
|
795 |
+
with open(os.path.join(SAVE_FOLDER, os.path.basename(args.data_path)), "w") as f:
|
796 |
+
print(df.to_json(orient="records", lines=True), file=f, flush=False, end="")
|
797 |
+
## save meta file
|
798 |
+
gen_table_meta = args.__dict__
|
799 |
+
write_json(gen_table_meta, os.path.join(SAVE_FOLDER, "gen_table_meta.json"), indent=4)
|
800 |
+
|
801 |
+
### plot curves and histograms
|
802 |
+
|
803 |
+
save_roc_curves(outputs, SAVE_FOLDER, args)
|
804 |
+
# save_roc_curves_w_ztest(outputs, df["w_wm_output_z_score"],
|
805 |
+
# df["baseline_completion_z_score"], SAVE_FOLDER, args)
|
806 |
+
save_ll_histograms(outputs, SAVE_FOLDER)
|
807 |
+
save_llr_histograms(outputs, SAVE_FOLDER)
|
lm-watermarking-main/watermark_reliability_release/detectgpt/make_plot.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Basic imports
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
import re
|
6 |
+
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from matplotlib import rc
|
9 |
+
rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
|
10 |
+
rc('text', usetex=True)
|
11 |
+
from sklearn.metrics import roc_curve, precision_recall_curve, auc
|
12 |
+
|
13 |
+
|
14 |
+
import sys
|
15 |
+
sys.path.insert(0, "..")
|
16 |
+
|
17 |
+
from datasets import Dataset
|
18 |
+
from utils.io import read_jsonlines, load_jsonlines
|
19 |
+
|
20 |
+
import transformers
|
21 |
+
|
22 |
+
from detectgpt_main import save_roc_curves_w_ztest
|
23 |
+
|
24 |
+
|
25 |
+
INPUT_DIR = "/cmlscratch/manlis/test/watermarking-root/input"
|
26 |
+
OUTPUT_DIR = "/cmlscratch/manlis/test/watermarking-root/output"
|
27 |
+
|
28 |
+
# 15 colorblind-friendly colors
|
29 |
+
COLORS = ["#0072B2", "#009E73", "#D55E00", "#CC79A7", "#F0E442",
|
30 |
+
"#56B4E9", "#E69F00", "#000000", "#0072B2", "#009E73",
|
31 |
+
"#D55E00", "#CC79A7", "#F0E442", "#56B4E9", "#E69F00"]
|
32 |
+
|
33 |
+
|
34 |
+
if __name__=="__main__":
|
35 |
+
parser = argparse.ArgumentParser(
|
36 |
+
description="Run detect-gpt with watermarked and baseline generations"
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
"--base_model_name",
|
40 |
+
type=str,
|
41 |
+
default="facebook/opt-1.3b",
|
42 |
+
help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
|
43 |
+
)
|
44 |
+
parser.add_argument(
|
45 |
+
"--data_name",
|
46 |
+
type=str,
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--token_len",
|
50 |
+
type=int,
|
51 |
+
default=200,
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
"--n_samples",
|
55 |
+
type=int,
|
56 |
+
default=500,
|
57 |
+
)
|
58 |
+
parser.add_argument(
|
59 |
+
"--chunk_size",
|
60 |
+
type=int,
|
61 |
+
default=500,
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--data_path",
|
65 |
+
type=str,
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--data_split",
|
69 |
+
type=str,
|
70 |
+
default="wm"
|
71 |
+
)
|
72 |
+
parser.add_argument(
|
73 |
+
"--mask_filling_model_name",
|
74 |
+
type=str,
|
75 |
+
default="t5-3b",
|
76 |
+
)
|
77 |
+
parser.add_argument('--n_positions', type=int, default=512)
|
78 |
+
parser.add_argument(
|
79 |
+
"--pct_words_masked",
|
80 |
+
type=float,
|
81 |
+
default=0.3,
|
82 |
+
)
|
83 |
+
parser.add_argument('--mask_top_p', type=float, default=1.0)
|
84 |
+
parser.add_argument('--n_perturbation_list', type=str, default="1,10,100")
|
85 |
+
parser.add_argument('--n_perturbation_rounds', type=int, default=1)
|
86 |
+
parser.add_argument('--span_length', type=int, default=2)
|
87 |
+
parser.add_argument('--buffer_size', type=int, default=1)
|
88 |
+
|
89 |
+
args = parser.parse_args()
|
90 |
+
|
91 |
+
## load data
|
92 |
+
list_of_dict = load_jsonlines(args.data_path)
|
93 |
+
raw_data = Dataset.from_list(list_of_dict)
|
94 |
+
df = raw_data.to_pandas()
|
95 |
+
## drop samples that are too short
|
96 |
+
original_len = len(df)
|
97 |
+
df = df[(df["baseline_completion_length"] == args.token_len) \
|
98 |
+
& (df["no_wm_num_tokens_generated"] == args.token_len) \
|
99 |
+
& (df["w_wm_num_tokens_generated"] == args.token_len) ]
|
100 |
+
print(f"Origianl #samples: {original_len}, after filtering token length: {len(df)}")
|
101 |
+
args.n_samples = len(df)
|
102 |
+
|
103 |
+
# perturbing text ops
|
104 |
+
# define regex to match all <extra_id_*> tokens, where * is an integer
|
105 |
+
pattern = re.compile(r"<extra_id_\d+>")
|
106 |
+
|
107 |
+
SAVE_FOLDER = f'{OUTPUT_DIR}/detect-gpt/{args.data_name}-{args.data_split}-mask{args.mask_filling_model_name}/maskpct{args.pct_words_masked}-ns{args.n_samples}'
|
108 |
+
os.makedirs(SAVE_FOLDER, exist_ok=True)
|
109 |
+
|
110 |
+
outputs = []
|
111 |
+
|
112 |
+
n_perturbation_list = [int(x) for x in args.n_perturbation_list.split(",")]
|
113 |
+
for n_perturbations in n_perturbation_list:
|
114 |
+
for perturbation_mode in ['d', 'z']:
|
115 |
+
with open(os.path.join(SAVE_FOLDER, f"perturbation_{n_perturbations}_{perturbation_mode}_results.json"), "r") as f:
|
116 |
+
output = json.load(f)
|
117 |
+
outputs.append(output)
|
118 |
+
|
119 |
+
|
120 |
+
### plot curves and histograms
|
121 |
+
|
122 |
+
save_roc_curves_w_ztest(outputs, df["w_wm_output_z_score"],
|
123 |
+
df["baseline_completion_z_score"], SAVE_FOLDER, args,
|
124 |
+
)
|
lm-watermarking-main/watermark_reliability_release/detectgpt/plot.sh
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
#SBATCH --partition tron
|
4 |
+
|
5 |
+
#SBATCH --gres=gpu:rtxa5000:1
|
6 |
+
|
7 |
+
#SBATCH --ntasks=4
|
8 |
+
|
9 |
+
#SBATCH --mem=32G
|
10 |
+
|
11 |
+
#SBATCH --account=nexus
|
12 |
+
|
13 |
+
#SBATCH --qos=default
|
14 |
+
|
15 |
+
#SBATCH --time=48:00:00
|
16 |
+
|
17 |
+
#SBATCH --array=0
|
18 |
+
|
19 |
+
#SBATCH --output=slurm_logs/%A_%a.out
|
20 |
+
|
21 |
+
#SBATCH --job-name=gen-small
|
22 |
+
|
23 |
+
source ~/.bashrc
|
24 |
+
conda activate watermarking-dev
|
25 |
+
|
26 |
+
OUTPUT_DIR=/cmlscratch/manlis/test/watermarking-root/input/new_runs
|
27 |
+
|
28 |
+
model_name="facebook/opt-1.3b"
|
29 |
+
data_name="test_len_200_opt1_3b"
|
30 |
+
token_len=200
|
31 |
+
chunk_size=32
|
32 |
+
data_path="/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_200_opt1_3b_evaluation/gen_table_w_metrics.jsonl"
|
33 |
+
split="no_wm"
|
34 |
+
|
35 |
+
|
36 |
+
python make_plot.py \
|
37 |
+
--n_perturbation_list="1,10,100" \
|
38 |
+
--base_model_name=${model_name} \
|
39 |
+
--data_name=${data_name} \
|
40 |
+
--data_path=${data_path} \
|
41 |
+
--token_len=${token_len} \
|
42 |
+
--chunk_size=${chunk_size} \
|
43 |
+
--data_split=${split};
|
44 |
+
|
45 |
+
|
46 |
+
|
lm-watermarking-main/watermark_reliability_release/detectgpt/run_detectgpt.sh
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
#SBATCH --partition scavenger
|
4 |
+
|
5 |
+
#SBATCH --gres=gpu:rtxa6000:1
|
6 |
+
|
7 |
+
#SBATCH --ntasks=4
|
8 |
+
|
9 |
+
#SBATCH --mem=32G
|
10 |
+
|
11 |
+
#SBATCH --account=scavenger
|
12 |
+
|
13 |
+
#SBATCH --qos=scavenger
|
14 |
+
|
15 |
+
#SBATCH --time=24:00:00
|
16 |
+
|
17 |
+
#SBATCH --array=0-2
|
18 |
+
|
19 |
+
#SBATCH --output=slurm_logs/no_wm_attack_%A_%a.out
|
20 |
+
|
21 |
+
#SBATCH --job-name=run-detect
|
22 |
+
|
23 |
+
source ~/.bashrc
|
24 |
+
conda activate watermarking-dev
|
25 |
+
|
26 |
+
# OUTPUT_DIR=/cmlscratch/manlis/test/watermarking-root/input/new_runs
|
27 |
+
|
28 |
+
# model_name="facebook/opt-1.3b"
|
29 |
+
# data_path="/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_200_opt1_3b_evaluation/gen_table_w_metrics.jsonl"
|
30 |
+
|
31 |
+
# model_name='facebook/opt-6.7b'
|
32 |
+
# data_path='/cmlscratch/manlis/test/watermarking-root/input/core_simple_1_50_200_gen/gen_table.jsonl'
|
33 |
+
# data_path=input/core_simple_1_200_1000_gen_prefixes/gen_table_prefixes_200.jsonl
|
34 |
+
model_name='/cmlscratch/manlis/test/watermarking-root/local_model/llama-7b-base'
|
35 |
+
|
36 |
+
mask_model="t5-3b"
|
37 |
+
|
38 |
+
# token_len=200
|
39 |
+
chunk_size=32 # can run 32 when textlen=200
|
40 |
+
pct=0.3
|
41 |
+
# split="no_wm"
|
42 |
+
split='no_wm_paraphrase'
|
43 |
+
|
44 |
+
declare -a commands
|
45 |
+
|
46 |
+
|
47 |
+
# for textlen in 50 100 200;
|
48 |
+
for textlen in 50 100 200;
|
49 |
+
do
|
50 |
+
commands+=( "python detectgpt_main.py \
|
51 |
+
--n_perturbation_list='100' \
|
52 |
+
--do_chunk \
|
53 |
+
--base_model_name=${model_name} \
|
54 |
+
--mask_filling_model_name=${mask_model} \
|
55 |
+
--filter='null' \
|
56 |
+
--data_path=/cmlscratch/manlis/test/watermarking-root/input/core_simple_1_200_1000_no_wm_gpt_p4_prefixes/gen_table_prefixes_${textlen}.jsonl \
|
57 |
+
--token_len=${textlen} \
|
58 |
+
--pct_words_masked=${pct} \
|
59 |
+
--chunk_size=${chunk_size} \
|
60 |
+
--data_split=${split};" )
|
61 |
+
done
|
62 |
+
|
63 |
+
bash -c "${commands[${SLURM_ARRAY_TASK_ID}]}"
|
lm-watermarking-main/watermark_reliability_release/evaluation_pipeline.py
ADDED
@@ -0,0 +1,1330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Authors of "A Watermark for Large Language Models"
|
3 |
+
# available at https://arxiv.org/abs/2301.10226
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
from types import NoneType
|
18 |
+
|
19 |
+
from typing import Union
|
20 |
+
import os
|
21 |
+
import argparse
|
22 |
+
from functools import partial
|
23 |
+
from tqdm import tqdm
|
24 |
+
|
25 |
+
import wandb
|
26 |
+
import torch
|
27 |
+
import numpy as np
|
28 |
+
import sklearn.metrics as metrics
|
29 |
+
|
30 |
+
from datasets import Dataset, Sequence
|
31 |
+
from transformers import DataCollatorWithPadding
|
32 |
+
|
33 |
+
from utils.submitit import str2bool # better bool flag type for argparse
|
34 |
+
from utils.io import read_jsonlines, read_json, write_json, write_jsonlines
|
35 |
+
from utils.notebooks import filter_text_col_length, infer_length_column
|
36 |
+
|
37 |
+
from utils.evaluation import (
|
38 |
+
SUPPORTED_METRICS,
|
39 |
+
NO_CHECK_ARGS,
|
40 |
+
ROC_TEST_STAT_SUFFIXES,
|
41 |
+
FILTER_BY_COLUMNS,
|
42 |
+
conditional_no_check_args,
|
43 |
+
load_oracle_model,
|
44 |
+
evaluate_ppl,
|
45 |
+
load_detector,
|
46 |
+
compute_z_scores,
|
47 |
+
compute_windowed_z_scores,
|
48 |
+
compute_run_len_chsqrd_stats,
|
49 |
+
compute_repetition_diversity,
|
50 |
+
compute_p_sp,
|
51 |
+
compute_coherence,
|
52 |
+
compute_mauve,
|
53 |
+
compute_detect_retrieval,
|
54 |
+
load_tokenizer,
|
55 |
+
concat_rows,
|
56 |
+
)
|
57 |
+
|
58 |
+
print(f"Current huggingface cache dir: {os.environ['HF_HOME']}")
|
59 |
+
|
60 |
+
from datasets import disable_caching
|
61 |
+
|
62 |
+
disable_caching()
|
63 |
+
|
64 |
+
|
65 |
+
def main(args):
|
66 |
+
###########################################################################
|
67 |
+
# Create output dir if it doesn't exist, and warn if it contains metric file
|
68 |
+
###########################################################################
|
69 |
+
gen_table_w_metrics_path = f"{args.output_dir}/gen_table_w_metrics.jsonl"
|
70 |
+
metrics_meta_path = f"{args.output_dir}/gen_table_w_metrics_meta.json"
|
71 |
+
|
72 |
+
print(f"Output dir for this run: {args.output_dir}")
|
73 |
+
# notify if exists
|
74 |
+
if os.path.exists(args.output_dir):
|
75 |
+
print(f"Output dir for this run already exists!")
|
76 |
+
print(f"Contents: {sorted(os.listdir(args.output_dir))}")
|
77 |
+
# warn if metrics file exists
|
78 |
+
if os.path.exists(gen_table_w_metrics_path):
|
79 |
+
if not args.overwrite_output_file:
|
80 |
+
print(
|
81 |
+
f"WARNING: Exiting to avoid overwriting output file. "
|
82 |
+
f"Pass the '--overwrite_output_file' flag to ignore this check."
|
83 |
+
)
|
84 |
+
exit()
|
85 |
+
else:
|
86 |
+
print(
|
87 |
+
f"WARNING: Found existing generation files with metrics added at this output dir. "
|
88 |
+
f"Overwriting anyway :/"
|
89 |
+
)
|
90 |
+
else:
|
91 |
+
# create the output dir where run artifacts are stored
|
92 |
+
os.makedirs(args.output_dir)
|
93 |
+
|
94 |
+
###########################################################################
|
95 |
+
# Parse metrics to log - ppl, zscore, etc
|
96 |
+
###########################################################################
|
97 |
+
|
98 |
+
# check that all metrics are supported
|
99 |
+
metric_support = [metric in SUPPORTED_METRICS for metric in args.evaluation_metrics]
|
100 |
+
assert all(metric_support), (
|
101 |
+
f"Unsupported metric '{args.evaluation_metrics[metric_support.index(False)]}' in"
|
102 |
+
f" {args.evaluation_metrics}. Supported metrics are: {SUPPORTED_METRICS}"
|
103 |
+
)
|
104 |
+
# Hack check that if prefix_lengths exists then the method must be
|
105 |
+
# detect-retrieval (for now) because other methods don't support the
|
106 |
+
# sparse dataset with Nones all over the place
|
107 |
+
if "prefix_lengths" in args.__dict__:
|
108 |
+
# assert args.evaluation_metrics == [
|
109 |
+
# "detect-retrieval"
|
110 |
+
# ], f"Currently, only the detect-retrieval metric supports the prefix_lengths column. "
|
111 |
+
print(
|
112 |
+
f"WARNING: Found prefix_lengths column assuming that this is either retireval or detectgpt"
|
113 |
+
)
|
114 |
+
|
115 |
+
print(f"Evaluation metrics to compute: {args.evaluation_metrics}")
|
116 |
+
|
117 |
+
###########################################################################
|
118 |
+
# Load generations
|
119 |
+
###########################################################################
|
120 |
+
print(f"Input dir for this run: {args.input_dir}")
|
121 |
+
print(f"Loading previously generated outputs for evaluation via oracle model and metrics...")
|
122 |
+
|
123 |
+
# check for the "attacked version" of the gen table first
|
124 |
+
gen_table_meta_path = f"{args.input_dir}/gen_table_attacked_meta.json"
|
125 |
+
gen_table_path = f"{args.input_dir}/gen_table_attacked.jsonl"
|
126 |
+
safe_gen_table_path = f"{args.input_dir}/gen_table_attacked_safe.jsonl"
|
127 |
+
loaded_attacked = True
|
128 |
+
|
129 |
+
attack_variants_exist = [
|
130 |
+
os.path.exists(gen_table_meta_path),
|
131 |
+
os.path.exists(gen_table_path),
|
132 |
+
]
|
133 |
+
if not all(attack_variants_exist):
|
134 |
+
loaded_attacked = False
|
135 |
+
gen_table_meta_path = f"{args.input_dir}/gen_table_meta.json"
|
136 |
+
gen_table_path = f"{args.input_dir}/gen_table.jsonl"
|
137 |
+
safe_gen_table_path = f"{args.input_dir}/gen_table_safe.jsonl"
|
138 |
+
|
139 |
+
assert os.path.exists(
|
140 |
+
gen_table_meta_path
|
141 |
+
), f"failed file check for prev generations metadata json file: {gen_table_meta_path}"
|
142 |
+
assert os.path.exists(
|
143 |
+
gen_table_path
|
144 |
+
), f"failed file check for prev generations jsonl file: {gen_table_path}"
|
145 |
+
|
146 |
+
assert not os.path.exists(safe_gen_table_path), (
|
147 |
+
f"failed for safety bc there is a secondary 'safe' marked file",
|
148 |
+
f" in this dir indicating a possible issue with the generation step. ",
|
149 |
+
)
|
150 |
+
|
151 |
+
cmdline_args = args.__dict__.copy()
|
152 |
+
prev_gen_table_meta = read_json(gen_table_meta_path)
|
153 |
+
|
154 |
+
joined_args = prev_gen_table_meta.copy()
|
155 |
+
for k, v in cmdline_args.items():
|
156 |
+
if v is not None:
|
157 |
+
joined_args.update({k: v})
|
158 |
+
else:
|
159 |
+
print(
|
160 |
+
f"cmdline arg {k} is None, leaving it as the value found in the input metadata: {prev_gen_table_meta[k]}"
|
161 |
+
)
|
162 |
+
|
163 |
+
# check that the args used to generate the prev generations are the same as
|
164 |
+
# the current args, for the intersection of keys
|
165 |
+
if not args.overwrite_args:
|
166 |
+
# update the no check args based on the current state of args
|
167 |
+
current_no_check_args = conditional_no_check_args(
|
168 |
+
NO_CHECK_ARGS, args.evaluation_metrics, args
|
169 |
+
)
|
170 |
+
|
171 |
+
for key in prev_gen_table_meta.keys():
|
172 |
+
if key in current_no_check_args:
|
173 |
+
continue
|
174 |
+
assert joined_args[key] == prev_gen_table_meta[key], (
|
175 |
+
f"failed for safety bc after merging the prev metadata with "
|
176 |
+
f"the current cmdline args, values for '{key}' are not the same. "
|
177 |
+
f"in metadata: {prev_gen_table_meta[key]}, passed: {cmdline_args[key]}. "
|
178 |
+
f"Pass the '--overwrite_args' flag to ignore this check."
|
179 |
+
)
|
180 |
+
|
181 |
+
args = argparse.Namespace(**joined_args)
|
182 |
+
gen_table = [ex for ex in read_jsonlines(gen_table_path)]
|
183 |
+
if args.limit_rows == -1:
|
184 |
+
gen_table_ds = Dataset.from_list(gen_table)
|
185 |
+
else:
|
186 |
+
gen_table_ds = Dataset.from_list(gen_table[: args.limit_rows])
|
187 |
+
|
188 |
+
###########################################################################
|
189 |
+
# Extract the seeding scheme fine grained parameters
|
190 |
+
###########################################################################
|
191 |
+
from utils.evaluation import scheme_hparam_extractor
|
192 |
+
|
193 |
+
args.__dict__.update(scheme_hparam_extractor(args.seeding_scheme))
|
194 |
+
|
195 |
+
print(f"seeding_scheme: {args.seeding_scheme}")
|
196 |
+
print(f"prf_type: {args.prf_type}")
|
197 |
+
print(f"anchored: {args.anchored}")
|
198 |
+
print(f"context_width: {args.context_width}")
|
199 |
+
print(f"self_salt: {args.self_salt}")
|
200 |
+
|
201 |
+
###########################################################################
|
202 |
+
# Concat logic for multiple generations
|
203 |
+
###########################################################################
|
204 |
+
|
205 |
+
if args.concat_rows != 0:
|
206 |
+
assert isinstance(args.concat_rows, int), f"Invalid concat_rows arg: {args.concat_rows}. "
|
207 |
+
|
208 |
+
# set to all rows if -1
|
209 |
+
if args.concat_rows == -1:
|
210 |
+
args.concat_rows = len(gen_table_ds)
|
211 |
+
|
212 |
+
if args.shuffle_before_concat:
|
213 |
+
print(f"Shuffling the gen table before concatenating every {args.concat_rows} rows...")
|
214 |
+
gen_table_ds = gen_table_ds.shuffle()
|
215 |
+
|
216 |
+
print(f"Concatenating every {args.concat_rows} rows of the gen table...")
|
217 |
+
|
218 |
+
# we concat all cols in OUTPUT_TEXT_COLUMN_NAMES
|
219 |
+
# and update the length col to reflect the new length
|
220 |
+
# which means we need to tokenize the new text temporarily
|
221 |
+
# to get the new length
|
222 |
+
|
223 |
+
tokenizer = load_tokenizer(args)
|
224 |
+
|
225 |
+
concat_partial = partial(concat_rows, tokenizer=tokenizer, args=args)
|
226 |
+
|
227 |
+
# manually write a btach loop bc hf doesnt support returning fewer rows than input
|
228 |
+
concatenated_rows = []
|
229 |
+
for i in tqdm(range(0, len(gen_table_ds), args.concat_rows)):
|
230 |
+
batch = gen_table_ds[i : i + args.concat_rows]
|
231 |
+
concatenated_rows.append(concat_partial(batch))
|
232 |
+
gen_table_concated_ds = Dataset.from_list(concatenated_rows)
|
233 |
+
|
234 |
+
# overwrite the args.max_new_tokens to reflect the implicit new target length T
|
235 |
+
# which is concat_rows * max_new_tokens
|
236 |
+
args.max_new_tokens = args.concat_rows * args.max_new_tokens
|
237 |
+
|
238 |
+
# write the dataset out in the same filename as the original
|
239 |
+
# but check that the input dir is different from the output dir
|
240 |
+
assert (
|
241 |
+
args.input_dir != args.output_dir
|
242 |
+
), f"Input dir and output dir must be different to write out the result of concat rows."
|
243 |
+
|
244 |
+
if loaded_attacked:
|
245 |
+
concat_meta_path = f"{args.output_dir}/gen_table_attacked_meta.json"
|
246 |
+
concat_gen_table_path = f"{args.output_dir}/gen_table_attacked.jsonl"
|
247 |
+
else:
|
248 |
+
concat_meta_path = f"{args.output_dir}/gen_table_meta.json"
|
249 |
+
concat_gen_table_path = f"{args.output_dir}/gen_table.jsonl"
|
250 |
+
|
251 |
+
write_json(args.__dict__, concat_meta_path, indent=4)
|
252 |
+
gen_table_concated_lst = [ex for ex in gen_table_concated_ds]
|
253 |
+
write_jsonlines(gen_table_concated_lst, concat_gen_table_path)
|
254 |
+
else:
|
255 |
+
gen_table_concated_ds = gen_table_ds
|
256 |
+
|
257 |
+
###########################################################################
|
258 |
+
# Additional args setup
|
259 |
+
###########################################################################
|
260 |
+
# if target_T is not specified, use max_new_tokens (which will be in the reloaded gen metadata)
|
261 |
+
# and potentially overwritten by the concat logic above
|
262 |
+
if args.target_T == 0:
|
263 |
+
args.target_T = args.max_new_tokens
|
264 |
+
|
265 |
+
# storing slurm info to allow auditing logfiles
|
266 |
+
# note this is set after the metadata check to ignore overwriting
|
267 |
+
args.SLURM_JOB_ID = os.getenv("SLURM_JOB_ID")
|
268 |
+
args.SLURM_ARRAY_JOB_ID = os.getenv("SLURM_ARRAY_JOB_ID")
|
269 |
+
args.SLURM_ARRAY_TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
|
270 |
+
|
271 |
+
###########################################################################
|
272 |
+
# Start logging, we wait to do this until after loading the generations
|
273 |
+
# so that we can log the args used to generate them unioned with the
|
274 |
+
# cmdline args
|
275 |
+
###########################################################################
|
276 |
+
if args.wandb:
|
277 |
+
# start a new wandb run to track this experiment, will send data to it
|
278 |
+
run = wandb.init(
|
279 |
+
# set the wandb project where this run will be logged
|
280 |
+
project=args.wandb_project,
|
281 |
+
entity=args.wandb_entity,
|
282 |
+
name=f"{args.run_name}",
|
283 |
+
# track hyperparameters and run metadata
|
284 |
+
config=args,
|
285 |
+
tags=args.wandb_tags,
|
286 |
+
)
|
287 |
+
|
288 |
+
###########################################################################
|
289 |
+
# Perplexity (PPL) evaluation
|
290 |
+
# NOTE: basically requires a model on gpu, or is extremely slow
|
291 |
+
###########################################################################
|
292 |
+
if "ppl" in args.evaluation_metrics:
|
293 |
+
assert args.oracle_model_name_or_path, "PPL metric requires oracle model."
|
294 |
+
|
295 |
+
# Load the oracle model for PPL measurement
|
296 |
+
oracle_model, oracle_tokenizer, _ = load_oracle_model(args)
|
297 |
+
|
298 |
+
# construct the collator
|
299 |
+
data_collator = DataCollatorWithPadding(
|
300 |
+
tokenizer=oracle_tokenizer, padding=True, pad_to_multiple_of=8
|
301 |
+
)
|
302 |
+
|
303 |
+
# construct fluency/ppl partial
|
304 |
+
evaluate_ppl_partial = partial(
|
305 |
+
evaluate_ppl,
|
306 |
+
oracle_model_name=args.oracle_model_name_or_path,
|
307 |
+
oracle_model=oracle_model,
|
308 |
+
oracle_tokenizer=oracle_tokenizer,
|
309 |
+
data_collator=data_collator,
|
310 |
+
)
|
311 |
+
|
312 |
+
print(f"Computing metrics on model generations: {gen_table_concated_ds}")
|
313 |
+
|
314 |
+
gen_table_w_ppl_ds = gen_table_concated_ds.map(
|
315 |
+
evaluate_ppl_partial,
|
316 |
+
batched=True,
|
317 |
+
batch_size=args.ppl_batch_size,
|
318 |
+
load_from_cache_file=False,
|
319 |
+
keep_in_memory=True,
|
320 |
+
)
|
321 |
+
|
322 |
+
# clear the model just for fun
|
323 |
+
oracle_model = oracle_model.to(torch.device("cpu"))
|
324 |
+
del oracle_model
|
325 |
+
else:
|
326 |
+
gen_table_w_ppl_ds = gen_table_concated_ds
|
327 |
+
|
328 |
+
###########################################################################
|
329 |
+
# Cheap to load, and required for all detectors so load it first
|
330 |
+
watermark_detector = load_detector(args)
|
331 |
+
|
332 |
+
# Map setup for all dataset operations:
|
333 |
+
map_setup = dict(batched=False, load_from_cache_file=False)
|
334 |
+
###########################################################################
|
335 |
+
# z-score evaluation
|
336 |
+
# NOTE: requires a gpu because if original source of watermark randomness,
|
337 |
+
# RNG, is gpu based, then detector should be on gpu as well
|
338 |
+
###########################################################################
|
339 |
+
if "z-score" in args.evaluation_metrics:
|
340 |
+
# set up the partial
|
341 |
+
compute_z_scores_partial = partial(
|
342 |
+
compute_z_scores,
|
343 |
+
watermark_detector=watermark_detector,
|
344 |
+
args=args,
|
345 |
+
)
|
346 |
+
|
347 |
+
gen_table_w_zscore_ds = gen_table_w_ppl_ds.map(
|
348 |
+
compute_z_scores_partial, **map_setup, desc="Computing z-scores"
|
349 |
+
)
|
350 |
+
|
351 |
+
else:
|
352 |
+
gen_table_w_zscore_ds = gen_table_w_ppl_ds
|
353 |
+
|
354 |
+
###########################################################################
|
355 |
+
# Windowed z-score evaluation
|
356 |
+
###########################################################################
|
357 |
+
|
358 |
+
if "windowed-z-score" in args.evaluation_metrics:
|
359 |
+
# set up the windowed partial
|
360 |
+
compute_windowed_z_scores_partial = partial(
|
361 |
+
compute_windowed_z_scores,
|
362 |
+
watermark_detector=watermark_detector,
|
363 |
+
args=args,
|
364 |
+
)
|
365 |
+
|
366 |
+
gen_table_w_windowed_zscore_ds = gen_table_w_zscore_ds.map(
|
367 |
+
compute_windowed_z_scores_partial, **map_setup, desc="Computing windowed z-scores"
|
368 |
+
)
|
369 |
+
else:
|
370 |
+
gen_table_w_windowed_zscore_ds = gen_table_w_zscore_ds
|
371 |
+
|
372 |
+
###########################################################################
|
373 |
+
# run-len-chisqrd evaluation
|
374 |
+
###########################################################################
|
375 |
+
if "run-len-chisqrd" in args.evaluation_metrics:
|
376 |
+
assert "w_wm_output_green_token_mask" in gen_table_w_windowed_zscore_ds.column_names, (
|
377 |
+
f"Currently, run-len-chisqrd metric requires the green token masks to be computed previously "
|
378 |
+
f"by one of the z-score metrics."
|
379 |
+
)
|
380 |
+
# this ^ is unused currently, but we will need it to remove the assert condition above
|
381 |
+
|
382 |
+
# set up the run len chisqrd partial
|
383 |
+
compute_run_len_chisqrd_partial = partial(
|
384 |
+
compute_run_len_chsqrd_stats,
|
385 |
+
watermark_detector=watermark_detector,
|
386 |
+
args=args,
|
387 |
+
)
|
388 |
+
|
389 |
+
gen_table_w_run_len_chisqrd_ds = gen_table_w_windowed_zscore_ds.map(
|
390 |
+
compute_run_len_chisqrd_partial, **map_setup, desc="Computing runlength tests"
|
391 |
+
)
|
392 |
+
else:
|
393 |
+
gen_table_w_run_len_chisqrd_ds = gen_table_w_windowed_zscore_ds
|
394 |
+
|
395 |
+
###########################################################################
|
396 |
+
# Diversity and Repetition evaluation
|
397 |
+
###########################################################################
|
398 |
+
|
399 |
+
if "repetition" in args.evaluation_metrics or "diversity" in args.evaluation_metrics:
|
400 |
+
# set up the partial
|
401 |
+
compute_repetition_partial = partial(
|
402 |
+
compute_repetition_diversity,
|
403 |
+
include_repetition=("repetition" in args.evaluation_metrics),
|
404 |
+
include_diversity=("diversity" in args.evaluation_metrics),
|
405 |
+
)
|
406 |
+
|
407 |
+
gen_table_w_repetition_ds = gen_table_w_run_len_chisqrd_ds.map(
|
408 |
+
compute_repetition_partial, **map_setup, desc="Computing text repetition and diversity"
|
409 |
+
)
|
410 |
+
else:
|
411 |
+
gen_table_w_repetition_ds = gen_table_w_run_len_chisqrd_ds
|
412 |
+
|
413 |
+
###########################################################################
|
414 |
+
# P-SP evaluation
|
415 |
+
###########################################################################
|
416 |
+
|
417 |
+
if "p-sp" in args.evaluation_metrics:
|
418 |
+
print(f"Loading the P-SP model and computing P-SP")
|
419 |
+
gen_table_w_p_sp_ds = compute_p_sp(gen_table_w_repetition_ds)
|
420 |
+
else:
|
421 |
+
gen_table_w_p_sp_ds = gen_table_w_repetition_ds
|
422 |
+
|
423 |
+
###########################################################################
|
424 |
+
# Coherence evaluation
|
425 |
+
###########################################################################
|
426 |
+
|
427 |
+
if "coherence" in args.evaluation_metrics:
|
428 |
+
print(f"Computing coherence")
|
429 |
+
gen_table_w_coherence_ds = compute_coherence(gen_table_w_p_sp_ds)
|
430 |
+
else:
|
431 |
+
gen_table_w_coherence_ds = gen_table_w_p_sp_ds
|
432 |
+
|
433 |
+
###########################################################################
|
434 |
+
# Mauve evaluation
|
435 |
+
###########################################################################
|
436 |
+
|
437 |
+
if "mauve" in args.evaluation_metrics:
|
438 |
+
print(f"Computing mauve")
|
439 |
+
gen_table_w_mauve_ds = compute_mauve(gen_table_w_coherence_ds)
|
440 |
+
else:
|
441 |
+
gen_table_w_mauve_ds = gen_table_w_coherence_ds
|
442 |
+
|
443 |
+
###########################################################################
|
444 |
+
# Retrieval detection
|
445 |
+
###########################################################################
|
446 |
+
|
447 |
+
if "detect-retrieval" in args.evaluation_metrics:
|
448 |
+
print(f"Computing detect retrieval")
|
449 |
+
gen_table_w_detect_retrieval_ds = compute_detect_retrieval(gen_table_w_mauve_ds, args=args)
|
450 |
+
else:
|
451 |
+
gen_table_w_detect_retrieval_ds = gen_table_w_mauve_ds
|
452 |
+
|
453 |
+
if "prefix_length" in gen_table_w_detect_retrieval_ds.features:
|
454 |
+
if "no_wm_output_retrieval_score" in gen_table_w_detect_retrieval_ds.features:
|
455 |
+
print("Avg scores at each prefix length for no_wm_output:")
|
456 |
+
print(
|
457 |
+
gen_table_w_detect_retrieval_ds.to_pandas()
|
458 |
+
.groupby("prefix_length")["no_wm_output_retrieval_score"]
|
459 |
+
.describe()
|
460 |
+
)
|
461 |
+
if "w_wm_output_retrieval_score" in gen_table_w_detect_retrieval_ds.features:
|
462 |
+
print("Avg scores at each prefix length for w_wm_output:")
|
463 |
+
print(
|
464 |
+
gen_table_w_detect_retrieval_ds.to_pandas()
|
465 |
+
.groupby("prefix_length")["w_wm_output_retrieval_score"]
|
466 |
+
.describe()
|
467 |
+
)
|
468 |
+
if "w_wm_output_attacked_retrieval_score" in gen_table_w_detect_retrieval_ds.features:
|
469 |
+
print("Avg scores at each prefix length for no_wm_output_attacked:")
|
470 |
+
print(
|
471 |
+
gen_table_w_detect_retrieval_ds.to_pandas()
|
472 |
+
.groupby("prefix_length")["w_wm_output_attacked_retrieval_score"]
|
473 |
+
.describe()
|
474 |
+
)
|
475 |
+
|
476 |
+
###########################################################################
|
477 |
+
# Detectgpt detection
|
478 |
+
###########################################################################
|
479 |
+
if "detectgpt" in args.evaluation_metrics:
|
480 |
+
assert args.evaluation_metrics == ["detectgpt"], (
|
481 |
+
f"Detectgpt must be run separately from other metrics. "
|
482 |
+
f"Found: {args.evaluation_metrics}. "
|
483 |
+
)
|
484 |
+
# check that the right score column exists
|
485 |
+
assert any(
|
486 |
+
["detectgpt_score" in col for col in gen_table_w_detect_retrieval_ds.column_names]
|
487 |
+
), (
|
488 |
+
f"Detectgpt metric requires the detectgpt_score column to be computed previously "
|
489 |
+
f"but no such cols exist in this file."
|
490 |
+
)
|
491 |
+
print(
|
492 |
+
f"Evaluating detectgpt by simply computing ROC-AUC metrics on the scores that already exist"
|
493 |
+
)
|
494 |
+
gen_table_w_metrics_ds = gen_table_w_detect_retrieval_ds
|
495 |
+
|
496 |
+
# if we loaded an attack file, since detect gpt only outputs a baseline score col
|
497 |
+
# and a no_wm_output score col (which is implcitly the attack col if the file was attacked)
|
498 |
+
# we need to add the attacked score col to the dataset, and remove the no_wm score col
|
499 |
+
if loaded_attacked:
|
500 |
+
for suff in ["100_d", "100_z"]:
|
501 |
+
gen_table_w_metrics_ds = gen_table_w_metrics_ds.add_column(
|
502 |
+
f"w_wm_output_attacked_detectgpt_score_{suff}",
|
503 |
+
gen_table_w_metrics_ds[f"no_wm_output_detectgpt_score_{suff}"],
|
504 |
+
)
|
505 |
+
gen_table_w_metrics_ds = gen_table_w_metrics_ds.remove_columns(
|
506 |
+
[f"no_wm_output_detectgpt_score_{suff}"]
|
507 |
+
)
|
508 |
+
else:
|
509 |
+
###########################################################################
|
510 |
+
# Write the final dataset out to disk in jsonl format
|
511 |
+
# with the metrics added
|
512 |
+
###########################################################################
|
513 |
+
|
514 |
+
# last applied metric, NOTE which will of course change as more are added
|
515 |
+
gen_table_w_metrics_ds = gen_table_w_detect_retrieval_ds
|
516 |
+
|
517 |
+
# write the metadata file, which is a union of the previous metadata
|
518 |
+
# and the current cmdline args
|
519 |
+
write_json(args.__dict__, metrics_meta_path, indent=4)
|
520 |
+
|
521 |
+
gen_table_w_metrics_lst = [ex for ex in gen_table_w_metrics_ds]
|
522 |
+
write_jsonlines(gen_table_w_metrics_lst, gen_table_w_metrics_path)
|
523 |
+
|
524 |
+
###########################################################################
|
525 |
+
# Log the metric series to wandb
|
526 |
+
###########################################################################
|
527 |
+
# log the metrics to wandb
|
528 |
+
if args.wandb:
|
529 |
+
# find cols that should be logged in a table
|
530 |
+
tabular_column_types = ["string", "bool"]
|
531 |
+
tabular_column_names = [
|
532 |
+
name
|
533 |
+
for name, _ in filter(
|
534 |
+
lambda tup: tup[1].dtype in tabular_column_types,
|
535 |
+
gen_table_w_metrics_ds.features.items(),
|
536 |
+
)
|
537 |
+
]
|
538 |
+
# the rest should be logged as series
|
539 |
+
series_column_names = [
|
540 |
+
name
|
541 |
+
for name, _ in filter(
|
542 |
+
lambda tup: tup[1].dtype not in tabular_column_types,
|
543 |
+
gen_table_w_metrics_ds.features.items(),
|
544 |
+
)
|
545 |
+
]
|
546 |
+
|
547 |
+
for metric_name in series_column_names:
|
548 |
+
# summarize series metrics as mean by default
|
549 |
+
wandb.define_metric(metric_name, summary="mean")
|
550 |
+
|
551 |
+
if args.log_raw_series:
|
552 |
+
# log the raw series
|
553 |
+
for example in tqdm(
|
554 |
+
gen_table_w_metrics_ds.remove_columns(tabular_column_names),
|
555 |
+
desc="Logging series metrics to wandb",
|
556 |
+
):
|
557 |
+
run.log(example)
|
558 |
+
|
559 |
+
if args.log_raw_tabular:
|
560 |
+
# log the raw tabular data
|
561 |
+
# but also include the dataset index as a column
|
562 |
+
series_column_names.remove("idx")
|
563 |
+
table = wandb.Table(
|
564 |
+
dataframe=gen_table_w_metrics_ds.remove_columns(series_column_names).to_pandas()
|
565 |
+
)
|
566 |
+
run.log({"output_table": table})
|
567 |
+
|
568 |
+
###########################################################################
|
569 |
+
# Filter rows, then log means to wandb
|
570 |
+
###########################################################################
|
571 |
+
assert (
|
572 |
+
args.target_T - args.lower_tolerance_T
|
573 |
+
) >= 0, "target_T - lower_tolerance_T must be >= 0"
|
574 |
+
|
575 |
+
target_T = args.target_T
|
576 |
+
lower_tolerance = args.lower_tolerance_T
|
577 |
+
upper_tolerance = args.upper_tolerance_T
|
578 |
+
filtered_table = gen_table_w_metrics_ds.to_pandas() # explictly convert lists
|
579 |
+
|
580 |
+
for col in args.filter_by_columns:
|
581 |
+
length_col_name = infer_length_column(col, filtered_table, args=args)
|
582 |
+
filtered_table = filter_text_col_length(
|
583 |
+
filtered_table,
|
584 |
+
text_col_name=length_col_name,
|
585 |
+
count_suffix="",
|
586 |
+
upper_T=target_T + upper_tolerance,
|
587 |
+
lower_T=target_T - lower_tolerance,
|
588 |
+
)
|
589 |
+
|
590 |
+
# Save filtered mean values:
|
591 |
+
for metric_name in series_column_names:
|
592 |
+
filtered_name = f"f_{target_T}p{upper_tolerance}m{lower_tolerance}_{metric_name}"
|
593 |
+
try:
|
594 |
+
run.summary[f"{filtered_name}_mean"] = filtered_table[metric_name].mean()
|
595 |
+
run.summary[f"{filtered_name}_std"] = filtered_table[metric_name].std()
|
596 |
+
except TypeError:
|
597 |
+
two_dim_mean = filtered_table[metric_name].apply(np.mean).mean()
|
598 |
+
|
599 |
+
###########################################################################
|
600 |
+
# Compute ROC-AUC and send to wandb
|
601 |
+
###########################################################################
|
602 |
+
try:
|
603 |
+
test_stats = args.roc_test_stat
|
604 |
+
if isinstance(test_stats, str):
|
605 |
+
test_stats = [test_stats]
|
606 |
+
for test_stat in test_stats:
|
607 |
+
for attacked in [True, False]:
|
608 |
+
try:
|
609 |
+
roc_auc, fpr, tpr, thresholds, tpr_at_X_fpr = _roc_metrics_for_wandb(
|
610 |
+
filtered_table, test_stat, attacked=attacked
|
611 |
+
)
|
612 |
+
run.summary[
|
613 |
+
f"{'attacked_' if attacked else ''}{test_stat}_roc_auc"
|
614 |
+
] = roc_auc
|
615 |
+
run.summary[
|
616 |
+
f"{'attacked_' if attacked else ''}{test_stat}_tpr_at_X_fpr"
|
617 |
+
] = tpr_at_X_fpr
|
618 |
+
|
619 |
+
# for tp, fp, thr in tqdm(
|
620 |
+
# zip(tpr, fpr, thresholds), desc="Logging ROC curve"
|
621 |
+
# ):
|
622 |
+
# run.log(
|
623 |
+
# {
|
624 |
+
# f"{'attacked_' if attacked else ''}{test_stat}_fpr": fp,
|
625 |
+
# f"{'attacked_' if attacked else ''}{test_stat}_tpr": tp,
|
626 |
+
# f"{'attacked_' if attacked else ''}thr": thr,
|
627 |
+
# }
|
628 |
+
# )
|
629 |
+
data = [[x, y] for (x, y) in zip(fpr, tpr)]
|
630 |
+
table = wandb.Table(data=data, columns=["fpr", "tpr"])
|
631 |
+
run.log(
|
632 |
+
{
|
633 |
+
f"{'attacked_' if attacked else ''}{test_stat}": wandb.plot.line(
|
634 |
+
table,
|
635 |
+
"fpr",
|
636 |
+
"tpr",
|
637 |
+
title=f"ROC ({test_stat}{',attacked' if attacked else ',clean'})",
|
638 |
+
)
|
639 |
+
}
|
640 |
+
)
|
641 |
+
print(f"Successfully logged ROC-AUC metrics for {test_stat}.")
|
642 |
+
|
643 |
+
except Exception as e:
|
644 |
+
if args.verbose:
|
645 |
+
print(e)
|
646 |
+
print(
|
647 |
+
f"Failed to log ROC-AUC metrics for {'attacked output' if attacked else ''} {test_stat}."
|
648 |
+
f"Metric probably was not computed and or attack col not present."
|
649 |
+
)
|
650 |
+
except Exception as e:
|
651 |
+
if args.verbose:
|
652 |
+
print(f"Exception: {e}")
|
653 |
+
print(
|
654 |
+
f"Failed to log ROC-AUC metrics. ",
|
655 |
+
f"Make sure the test statistic required for detection ({test_stat}) has been computed!",
|
656 |
+
)
|
657 |
+
|
658 |
+
################################################################################
|
659 |
+
# NOTE we do that ^^^ basic ROC logic first because it's faster
|
660 |
+
# as well as the manual prefix lengths at T logic bc that's also faster
|
661 |
+
################################################################################
|
662 |
+
|
663 |
+
# Handle z @ T but for the retrieval and detectgpt scores that are evaluated
|
664 |
+
# manually at each prefix length. Use groupby to compute the mean and std
|
665 |
+
# for each prefix length for any of the feats that have retrieval_score in them,
|
666 |
+
# then log those pairs to wandb.
|
667 |
+
at_T_df = gen_table_w_metrics_ds.to_pandas()
|
668 |
+
|
669 |
+
for name, feat in gen_table_w_metrics_ds.features.items():
|
670 |
+
if "retrieval_score" in name and "prefix_length" in at_T_df.columns:
|
671 |
+
# compute the mean and std for each prefix length
|
672 |
+
# and log those pairs to wandb
|
673 |
+
df_view = at_T_df.groupby("prefix_length")[name].describe()[["mean", "std"]]
|
674 |
+
T_indices = df_view.index
|
675 |
+
|
676 |
+
# for idx, (mean, std) in df_view.iterrows():
|
677 |
+
# run.log(data={f"{name}_mean": mean, f"{name}_std": std, "idx_T": idx})
|
678 |
+
# log this triple as a table instead like the ROC curve above
|
679 |
+
# where the first two are plotted and the third is the x axis
|
680 |
+
data = [[x, y, z] for x, (y, z) in df_view.iterrows()]
|
681 |
+
table = wandb.Table(data=data, columns=["idx_T", "mean", "std"])
|
682 |
+
# compute stderr from std
|
683 |
+
table.add_column(
|
684 |
+
"stderr",
|
685 |
+
[
|
686 |
+
std / np.sqrt(len(at_T_df[at_T_df["prefix_length"] == idx]))
|
687 |
+
for idx, std in zip(T_indices, df_view["std"])
|
688 |
+
],
|
689 |
+
)
|
690 |
+
# first log mean
|
691 |
+
run.log({f"{name}": wandb.plot.line(table, "idx_T", "mean", title=f"{name} mean")})
|
692 |
+
# then log std err
|
693 |
+
run.log(
|
694 |
+
{
|
695 |
+
f"{name}_stderr": wandb.plot.line(
|
696 |
+
table, "idx_T", "stderr", title=f"{name} stderr"
|
697 |
+
)
|
698 |
+
}
|
699 |
+
)
|
700 |
+
|
701 |
+
# also compute an AUC at each prefix len idx by treating the name col as the positives
|
702 |
+
# and the baseline_completion_retrieval_score as the negatives
|
703 |
+
# then log those pairs to wandb
|
704 |
+
if name != "baseline_completion_retrieval_score":
|
705 |
+
pos_negs_at_T = at_T_df.groupby("prefix_length")[
|
706 |
+
[name, "baseline_completion_retrieval_score"]
|
707 |
+
]
|
708 |
+
# auc_at_T = []
|
709 |
+
# tpr_at_X_fpr = []
|
710 |
+
all_aucs, all_tpr_at_X_fpr = [], []
|
711 |
+
for idx, sub_df in pos_negs_at_T:
|
712 |
+
pos = sub_df[name]
|
713 |
+
neg = sub_df["baseline_completion_retrieval_score"]
|
714 |
+
# convert to arrays and remove nans
|
715 |
+
pos = pos.to_numpy()[~np.isnan(pos.to_numpy())]
|
716 |
+
neg = neg.to_numpy()[~np.isnan(neg.to_numpy())]
|
717 |
+
|
718 |
+
fpr, tpr, thresholds = metrics.roc_curve(
|
719 |
+
np.concatenate([np.ones_like(pos), np.zeros_like(neg)]), # labels
|
720 |
+
np.concatenate([pos, neg]), # scores
|
721 |
+
pos_label=1,
|
722 |
+
)
|
723 |
+
auc = metrics.auc(fpr, tpr)
|
724 |
+
try:
|
725 |
+
tpr_at_X_fpr = tpr[np.where(fpr < 1e-3)[0][-1]]
|
726 |
+
except IndexError:
|
727 |
+
tpr_at_X_fpr = float("NaN")
|
728 |
+
all_aucs.append(auc)
|
729 |
+
all_tpr_at_X_fpr.append(tpr_at_X_fpr)
|
730 |
+
|
731 |
+
# run.log(data={f"{name}_auc_at_T": auc, "idx_T": idx})
|
732 |
+
# log this triple as a table instead like the AUC and tpr at X fpr below
|
733 |
+
# where the first two are plotted and the third is the x axis
|
734 |
+
data = [
|
735 |
+
[x, y, z] for x, (y, z) in zip(T_indices, zip(all_aucs, all_tpr_at_X_fpr))
|
736 |
+
]
|
737 |
+
table = wandb.Table(data=data, columns=["idx_T", "aucs", "tpr_at"])
|
738 |
+
run.log(
|
739 |
+
{
|
740 |
+
f"{name}_aucs": wandb.plot.line(
|
741 |
+
table, "idx_T", "aucs", title=f"{name} aucs"
|
742 |
+
)
|
743 |
+
}
|
744 |
+
)
|
745 |
+
run.log(
|
746 |
+
{
|
747 |
+
f"{name}_tpr_at": wandb.plot.line(
|
748 |
+
table, "idx_T", "tpr_at", title=f"{name} tpr_at"
|
749 |
+
)
|
750 |
+
}
|
751 |
+
)
|
752 |
+
|
753 |
+
elif "detectgpt_score" in name and "prefix_length" in at_T_df.columns:
|
754 |
+
# this covers detectgpt_score_100_d and variants
|
755 |
+
# compute the mean and std for each prefix length
|
756 |
+
# and log those pairs to wandb
|
757 |
+
df_view = at_T_df.groupby("prefix_length")[name].describe()[["mean", "std"]]
|
758 |
+
T_indices = df_view.index
|
759 |
+
|
760 |
+
# for idx, (mean, std) in df_view.iterrows():
|
761 |
+
# run.log(data={f"{name}_mean": mean, f"{name}_std": std, "idx_T": idx})
|
762 |
+
# log this triple as a table instead like the ROC curve above
|
763 |
+
# where the first two are plotted and the third is the x axis
|
764 |
+
data = [[x, y, z] for x, (y, z) in df_view.iterrows()]
|
765 |
+
table = wandb.Table(data=data, columns=["idx_T", "mean", "std"])
|
766 |
+
|
767 |
+
# compute stderr from std
|
768 |
+
table.add_column(
|
769 |
+
"stderr",
|
770 |
+
[
|
771 |
+
std / np.sqrt(len(at_T_df[at_T_df["prefix_length"] == idx]))
|
772 |
+
for idx, std in zip(T_indices, df_view["std"])
|
773 |
+
],
|
774 |
+
)
|
775 |
+
# first log mean
|
776 |
+
run.log({f"{name}": wandb.plot.line(table, "idx_T", "mean", title=f"{name} mean")})
|
777 |
+
# then log std err
|
778 |
+
run.log(
|
779 |
+
{
|
780 |
+
f"{name}_stderr": wandb.plot.line(
|
781 |
+
table, "idx_T", "stderr", title=f"{name} stderr"
|
782 |
+
)
|
783 |
+
}
|
784 |
+
)
|
785 |
+
|
786 |
+
# also compute an AUC at each prefix len idx by treating the name col as the positives
|
787 |
+
# and the baseline_completion_retrieval_score as the negatives
|
788 |
+
# then log those pairs to wandb
|
789 |
+
if "baseline_completion_detectgpt_score" not in name:
|
790 |
+
# check which suffix this is in ["_100_d", "_100_z"]
|
791 |
+
# and use that to set the baseline/falst col
|
792 |
+
if name.endswith("_100_d"):
|
793 |
+
baseline_col = "baseline_completion_detectgpt_score_100_d"
|
794 |
+
elif name.endswith("_100_z"):
|
795 |
+
baseline_col = "baseline_completion_detectgpt_score_100_z"
|
796 |
+
pos_negs_at_T = at_T_df.groupby("prefix_length")[[name, baseline_col]]
|
797 |
+
# auc_at_T = []
|
798 |
+
# tpr_at_X_fpr = []
|
799 |
+
all_aucs, all_tpr_at_X_fpr = [], []
|
800 |
+
for idx, sub_df in pos_negs_at_T:
|
801 |
+
pos = sub_df[name]
|
802 |
+
neg = sub_df[baseline_col]
|
803 |
+
# convert to arrays and remove nans
|
804 |
+
pos = pos.to_numpy()[~np.isnan(pos.to_numpy())]
|
805 |
+
neg = neg.to_numpy()[~np.isnan(neg.to_numpy())]
|
806 |
+
|
807 |
+
fpr, tpr, thresholds = metrics.roc_curve(
|
808 |
+
np.concatenate([np.ones_like(pos), np.zeros_like(neg)]), # labels
|
809 |
+
np.concatenate([pos, neg]), # scores
|
810 |
+
pos_label=1,
|
811 |
+
)
|
812 |
+
auc = metrics.auc(fpr, tpr)
|
813 |
+
try:
|
814 |
+
tpr_at_X_fpr = tpr[np.where(fpr < 1e-3)[0][-1]]
|
815 |
+
except IndexError:
|
816 |
+
tpr_at_X_fpr = float("NaN")
|
817 |
+
all_aucs.append(auc)
|
818 |
+
all_tpr_at_X_fpr.append(tpr_at_X_fpr)
|
819 |
+
|
820 |
+
# run.log(data={f"{name}_auc_at_T": auc, "idx_T": idx})
|
821 |
+
# log this triple as a table instead like the AUC and tpr at X fpr below
|
822 |
+
# where the first two are plotted and the third is the x axis
|
823 |
+
data = [
|
824 |
+
[x, y, z] for x, (y, z) in zip(T_indices, zip(all_aucs, all_tpr_at_X_fpr))
|
825 |
+
]
|
826 |
+
table = wandb.Table(data=data, columns=["idx_T", "aucs", "tpr_at"])
|
827 |
+
run.log(
|
828 |
+
{
|
829 |
+
f"{name}_aucs": wandb.plot.line(
|
830 |
+
table, "idx_T", "aucs", title=f"{name} aucs"
|
831 |
+
)
|
832 |
+
}
|
833 |
+
)
|
834 |
+
run.log(
|
835 |
+
{
|
836 |
+
f"{name}_tpr_at": wandb.plot.line(
|
837 |
+
table, "idx_T", "tpr_at", title=f"{name} tpr_at"
|
838 |
+
)
|
839 |
+
}
|
840 |
+
)
|
841 |
+
|
842 |
+
###########################################################################
|
843 |
+
# Compute our @ T detection metrics and send to wandb
|
844 |
+
###########################################################################
|
845 |
+
|
846 |
+
# Merge z_at_T and other sequence metrics so they can be shown in wandb:
|
847 |
+
for name, feat in gen_table_w_metrics_ds.features.items():
|
848 |
+
if isinstance(feat, Sequence):
|
849 |
+
max_feat_seq_len = max([len(l) for l in gen_table_w_metrics_ds[name]])
|
850 |
+
merging_seq = np.zeros(max_feat_seq_len)
|
851 |
+
counts = np.zeros(max_feat_seq_len)
|
852 |
+
proto_variance = np.zeros(max_feat_seq_len)
|
853 |
+
for entry in gen_table_w_metrics_ds[name]:
|
854 |
+
len_seq = len(entry)
|
855 |
+
delta = entry * counts[:len_seq] - merging_seq[:len_seq]
|
856 |
+
# Accumulate ragged sum over entries:
|
857 |
+
counts[:len_seq] += 1
|
858 |
+
merging_seq[:len_seq] += entry[: len(merging_seq)]
|
859 |
+
# Compute ragged, running variance via Welford:
|
860 |
+
gamma = entry * counts[:len_seq] - merging_seq[:len_seq]
|
861 |
+
proto_variance[:len_seq] += (delta / counts[:len_seq]) * (
|
862 |
+
gamma / counts[:len_seq]
|
863 |
+
)
|
864 |
+
|
865 |
+
mask = counts != 0
|
866 |
+
averaged_seq = merging_seq.copy()
|
867 |
+
averaged_seq[mask] /= counts
|
868 |
+
averaged_seq[~mask] = float("NaN")
|
869 |
+
|
870 |
+
seq_stderr = proto_variance.copy()
|
871 |
+
seq_stderr[counts > 1] = np.sqrt(
|
872 |
+
proto_variance[counts > 1] / (counts[counts > 1] - 1)
|
873 |
+
) / np.sqrt(counts[counts > 1])
|
874 |
+
seq_stderr[counts <= 1] = float("NaN")
|
875 |
+
# for idx, (avg, stderr) in enumerate(zip(averaged_seq[mask], seq_stderr[mask])):
|
876 |
+
# run.log(data={f"{name}_avg": avg, f"{name}_stderr": stderr, "idx_T": idx})
|
877 |
+
# log this triple as a table instead like the ROC curve above
|
878 |
+
# where the first two are plotted and the third is the x axis
|
879 |
+
data = [
|
880 |
+
[x, y, z]
|
881 |
+
for (x, y, z) in zip(
|
882 |
+
averaged_seq[mask], seq_stderr[mask], range(len(averaged_seq[mask]))
|
883 |
+
)
|
884 |
+
]
|
885 |
+
table = wandb.Table(data=data, columns=["avg", "stderr", "idx_T"])
|
886 |
+
|
887 |
+
# first plot avg
|
888 |
+
run.log({f"{name}": wandb.plot.line(table, "idx_T", "avg", title=f"{name} avg")})
|
889 |
+
# then plot stderr
|
890 |
+
run.log(
|
891 |
+
{
|
892 |
+
f"{name}_stderr": wandb.plot.line(
|
893 |
+
table, "idx_T", "stderr", title=f"{name} stderr"
|
894 |
+
)
|
895 |
+
}
|
896 |
+
)
|
897 |
+
|
898 |
+
# Compute AUC_at_T
|
899 |
+
# For now we'll just do a dumb loop over scipy.roc_curve, but this could be batched
|
900 |
+
test_stats = args.roc_test_stat
|
901 |
+
if isinstance(test_stats, str):
|
902 |
+
test_stats = [test_stats]
|
903 |
+
|
904 |
+
for test_stat in test_stats:
|
905 |
+
for attacked in [True, False]:
|
906 |
+
base_col = f"baseline_completion_{test_stat}_at_T"
|
907 |
+
w_wm_col = f"w_wm_output{'_attacked' if attacked else ''}_{test_stat}_at_T"
|
908 |
+
name = f"w_wm{'_attacked' if attacked else ''}_{test_stat}_at_T"
|
909 |
+
|
910 |
+
if w_wm_col in gen_table_w_metrics_ds.features.keys(): # metric was computed
|
911 |
+
print(f"Computing AUC at T for {name}.")
|
912 |
+
max_length = min(
|
913 |
+
max([len(l) for l in gen_table_w_metrics_ds[base_col]]),
|
914 |
+
max([len(l) for l in gen_table_w_metrics_ds[w_wm_col]]),
|
915 |
+
)
|
916 |
+
|
917 |
+
all_aucs, all_tpr_at_X_fpr = [], []
|
918 |
+
for T in range(1, max_length):
|
919 |
+
w_wm_stats = np.array(
|
920 |
+
[t[T] for t in gen_table_w_metrics_ds[w_wm_col] if len(t) > T]
|
921 |
+
)
|
922 |
+
|
923 |
+
baseline_stats = np.array(
|
924 |
+
[t[T] for t in gen_table_w_metrics_ds[base_col] if len(t) > T]
|
925 |
+
)[: len(w_wm_stats)]
|
926 |
+
all_scores = np.concatenate([baseline_stats, w_wm_stats])
|
927 |
+
|
928 |
+
baseline_labels = np.zeros_like(baseline_stats)
|
929 |
+
attacked_labels = np.ones_like(w_wm_stats)
|
930 |
+
all_labels = np.concatenate([baseline_labels, attacked_labels])
|
931 |
+
|
932 |
+
if len(np.unique(all_labels)) < 2:
|
933 |
+
roc_auc = float("NaN")
|
934 |
+
tpr_at_X_fpr = float("NaN")
|
935 |
+
else:
|
936 |
+
fpr, tpr, thresholds = metrics.roc_curve(
|
937 |
+
all_labels, all_scores, pos_label=1
|
938 |
+
)
|
939 |
+
roc_auc = metrics.auc(fpr, tpr)
|
940 |
+
try:
|
941 |
+
tpr_at_X_fpr = tpr[np.where(fpr < 1e-3)[0][-1]]
|
942 |
+
except IndexError:
|
943 |
+
tpr_at_X_fpr = float("NaN")
|
944 |
+
|
945 |
+
all_aucs.append(roc_auc)
|
946 |
+
all_tpr_at_X_fpr.append(tpr_at_X_fpr)
|
947 |
+
# for idx, (aucs, tpr_at) in enumerate(zip(all_aucs, all_tpr_at_X_fpr)):
|
948 |
+
# run.log(data={f"{name}_aucs": aucs, f"{name}_tpr_at": tpr_at, "idx_T": idx})
|
949 |
+
# log these two separately using a table
|
950 |
+
data = [
|
951 |
+
[x, y, z]
|
952 |
+
for (x, y, z) in zip(all_aucs, all_tpr_at_X_fpr, range(len(all_aucs)))
|
953 |
+
]
|
954 |
+
table = wandb.Table(data=data, columns=["aucs", "tpr_at", "idx_T"])
|
955 |
+
run.log(
|
956 |
+
{
|
957 |
+
f"{name}_aucs": wandb.plot.line(
|
958 |
+
table, "idx_T", "aucs", title=f"{name} aucs"
|
959 |
+
)
|
960 |
+
}
|
961 |
+
)
|
962 |
+
run.log(
|
963 |
+
{
|
964 |
+
f"{name}_tpr_at": wandb.plot.line(
|
965 |
+
table, "idx_T", "tpr_at", title=f"{name} tpr_at"
|
966 |
+
)
|
967 |
+
}
|
968 |
+
)
|
969 |
+
|
970 |
+
# finish the wandb run
|
971 |
+
run.finish()
|
972 |
+
|
973 |
+
return
|
974 |
+
|
975 |
+
|
976 |
+
def _roc_metrics_for_wandb(
|
977 |
+
gen_table_ds, test_stat="z_score", prefix="", attacked=False, remove_nan=True
|
978 |
+
):
|
979 |
+
# In theory, we actually should be filtering the attacked column too, but we know these
|
980 |
+
# end up very short sometimes. So, to make sure the logic works, we just
|
981 |
+
# filter for any rows where the test metrics are NaN and note the damage
|
982 |
+
|
983 |
+
baseline_col_name = f"{prefix}baseline_completion_{test_stat}"
|
984 |
+
if "retrieval" in test_stat:
|
985 |
+
if attacked:
|
986 |
+
w_wm_col_name = f"{prefix}w_wm_output_attacked_retrieval_score"
|
987 |
+
else:
|
988 |
+
w_wm_col_name = f"{prefix}{args.retrieval_db_column}_retrieval_score"
|
989 |
+
elif "detectgpt" in test_stat:
|
990 |
+
if attacked:
|
991 |
+
w_wm_col_name = f"{prefix}w_wm_output_attacked_{test_stat}"
|
992 |
+
else:
|
993 |
+
w_wm_col_name = f"{prefix}no_wm_output_{test_stat}"
|
994 |
+
else:
|
995 |
+
w_wm_col_name = f"{prefix}w_wm_output{'_attacked' if attacked else ''}_{test_stat}"
|
996 |
+
|
997 |
+
# drop nans in either column
|
998 |
+
if remove_nan:
|
999 |
+
orig_length = len(gen_table_ds)
|
1000 |
+
gen_table_ds = gen_table_ds.dropna(subset=[baseline_col_name, w_wm_col_name])
|
1001 |
+
if orig_length != len(gen_table_ds):
|
1002 |
+
print(
|
1003 |
+
f"NOTE: During ROC calculation, dropped {orig_length - len(gen_table_ds)} rows due to NaNs in {baseline_col_name} or {w_wm_col_name}"
|
1004 |
+
)
|
1005 |
+
|
1006 |
+
baseline_stats = gen_table_ds[baseline_col_name].values
|
1007 |
+
w_wm_stats = gen_table_ds[w_wm_col_name].values
|
1008 |
+
all_scores = np.concatenate([baseline_stats, w_wm_stats])
|
1009 |
+
|
1010 |
+
baseline_labels = np.zeros_like(baseline_stats)
|
1011 |
+
attacked_labels = np.ones_like(w_wm_stats)
|
1012 |
+
all_labels = np.concatenate([baseline_labels, attacked_labels])
|
1013 |
+
|
1014 |
+
fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)
|
1015 |
+
roc_auc = metrics.auc(fpr, tpr)
|
1016 |
+
try:
|
1017 |
+
tpr_at_X_fpr = tpr[np.where(fpr < 1e-3)[0][-1]]
|
1018 |
+
except IndexError:
|
1019 |
+
tpr_at_X_fpr = float("NaN")
|
1020 |
+
return roc_auc, fpr, tpr, thresholds, tpr_at_X_fpr
|
1021 |
+
|
1022 |
+
|
1023 |
+
if __name__ == "__main__":
|
1024 |
+
parser = argparse.ArgumentParser(description="Run evaluation pipeline for watermark detection")
|
1025 |
+
parser.add_argument(
|
1026 |
+
"--evaluation_metrics",
|
1027 |
+
type=str,
|
1028 |
+
default="all",
|
1029 |
+
help="Comma separated list of columns to remove from the dataset before generation.",
|
1030 |
+
)
|
1031 |
+
parser.add_argument(
|
1032 |
+
"--compute_scores_at_T",
|
1033 |
+
type=str2bool,
|
1034 |
+
default=True,
|
1035 |
+
help="Whether to compute (applicable) metrics at each T index in the output/text columns.",
|
1036 |
+
)
|
1037 |
+
parser.add_argument(
|
1038 |
+
"--overwrite_args",
|
1039 |
+
type=str2bool,
|
1040 |
+
default=False,
|
1041 |
+
help="Whether to overwrite the shared args in the metadata file with the current, runtime args.",
|
1042 |
+
)
|
1043 |
+
parser.add_argument(
|
1044 |
+
"--oracle_model_name_or_path",
|
1045 |
+
type=str,
|
1046 |
+
default="facebook/opt-6.7b",
|
1047 |
+
help="Oracle model, path to pretrained model or model identifier from huggingface.co/models.",
|
1048 |
+
)
|
1049 |
+
parser.add_argument(
|
1050 |
+
"--load_fp16",
|
1051 |
+
type=str2bool,
|
1052 |
+
default=None,
|
1053 |
+
help=(
|
1054 |
+
"Whether to run model (for ppl) in float16 precsion, note, will overwrite error as a reminder that "
|
1055 |
+
"generation was run in other mode, even though there's no hard requirement that these match."
|
1056 |
+
),
|
1057 |
+
)
|
1058 |
+
parser.add_argument(
|
1059 |
+
"--ppl_batch_size",
|
1060 |
+
type=int,
|
1061 |
+
default=1,
|
1062 |
+
help="Batch size for ppl eval.",
|
1063 |
+
)
|
1064 |
+
parser.add_argument(
|
1065 |
+
"--seeding_scheme",
|
1066 |
+
type=Union[str, NoneType],
|
1067 |
+
default=None,
|
1068 |
+
help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
|
1069 |
+
)
|
1070 |
+
parser.add_argument(
|
1071 |
+
"--gamma",
|
1072 |
+
type=Union[float, NoneType],
|
1073 |
+
default=None,
|
1074 |
+
help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
|
1075 |
+
)
|
1076 |
+
parser.add_argument(
|
1077 |
+
"--normalizers",
|
1078 |
+
type=Union[str, NoneType],
|
1079 |
+
default=None,
|
1080 |
+
help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.",
|
1081 |
+
)
|
1082 |
+
parser.add_argument(
|
1083 |
+
"--ignore_repeated_ngrams",
|
1084 |
+
type=str2bool,
|
1085 |
+
default=False,
|
1086 |
+
help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.",
|
1087 |
+
)
|
1088 |
+
parser.add_argument(
|
1089 |
+
"--detection_z_threshold",
|
1090 |
+
type=float,
|
1091 |
+
default=4.0,
|
1092 |
+
help="The test statistic threshold for the detection hypothesis test.",
|
1093 |
+
)
|
1094 |
+
parser.add_argument(
|
1095 |
+
"--return_green_token_mask",
|
1096 |
+
type=str2bool,
|
1097 |
+
default=True,
|
1098 |
+
help="Whether to return the mask marking which tokens are green from the watermark detector.",
|
1099 |
+
)
|
1100 |
+
parser.add_argument(
|
1101 |
+
"--window_settings",
|
1102 |
+
type=str,
|
1103 |
+
default="20,40,max", # can also be "20" or "20,40,max"
|
1104 |
+
help="Comma separated list of window sizes to use for watermark detection. Only used if 'windowed-z-score' is in the evaluation metrics list.",
|
1105 |
+
)
|
1106 |
+
parser.add_argument(
|
1107 |
+
"--run_len_chisqrd_variant",
|
1108 |
+
type=str,
|
1109 |
+
default="F_succ_T_runs",
|
1110 |
+
choices=["F_succ_T_runs", "T_and_F_runs"],
|
1111 |
+
help="The variant of the run length test to use for watermark detection.",
|
1112 |
+
)
|
1113 |
+
parser.add_argument(
|
1114 |
+
"--run_len_chisqrd_bin_spec",
|
1115 |
+
type=str,
|
1116 |
+
default="max_plus_1",
|
1117 |
+
choices=["max", "max_plus_1"],
|
1118 |
+
help="The binning specification to use for the run length test.",
|
1119 |
+
)
|
1120 |
+
parser.add_argument(
|
1121 |
+
"--run_len_chisqrd_mask_zeros",
|
1122 |
+
type=str2bool,
|
1123 |
+
default=True,
|
1124 |
+
help="Whether to mask zeros in the run length test.",
|
1125 |
+
)
|
1126 |
+
parser.add_argument(
|
1127 |
+
"--run_len_chisqrd_mask_leading_bins",
|
1128 |
+
type=int,
|
1129 |
+
default=0,
|
1130 |
+
help="The number of leading bins to mask in the run length test.",
|
1131 |
+
)
|
1132 |
+
parser.add_argument(
|
1133 |
+
"--run_len_chisqrd_lambda",
|
1134 |
+
type=str,
|
1135 |
+
default="pearson",
|
1136 |
+
choices=["pearson", "g_test", "cressie_read"],
|
1137 |
+
help="The lambda_ param to use for the run length test.",
|
1138 |
+
)
|
1139 |
+
parser.add_argument(
|
1140 |
+
"--retrieval_technique",
|
1141 |
+
type=str,
|
1142 |
+
default="bm25",
|
1143 |
+
choices=["bm25", "sim"],
|
1144 |
+
help="The retrieval technique to use for retrieval detection.",
|
1145 |
+
)
|
1146 |
+
parser.add_argument(
|
1147 |
+
"--retrieval_db_column",
|
1148 |
+
type=str,
|
1149 |
+
default="no_wm_output",
|
1150 |
+
choices=["w_wm_output", "no_wm_output"],
|
1151 |
+
help="The column to populate the db/index with use for retrieval detection.",
|
1152 |
+
)
|
1153 |
+
parser.add_argument(
|
1154 |
+
"--retrieval_db_load_all_prefixes",
|
1155 |
+
type=str2bool,
|
1156 |
+
default=False,
|
1157 |
+
help="Whether to load all prefixes into the retrieval db, or just the longest for each unique entry.",
|
1158 |
+
)
|
1159 |
+
parser.add_argument(
|
1160 |
+
"--roc_test_stat",
|
1161 |
+
type=str,
|
1162 |
+
default="all",
|
1163 |
+
help="The comma separated list of test statistics to use for the ROC-AUC metric.",
|
1164 |
+
)
|
1165 |
+
parser.add_argument(
|
1166 |
+
"--target_T",
|
1167 |
+
type=int,
|
1168 |
+
default=0,
|
1169 |
+
help="The target generation length to use when dropping rows before ROC-AUC evaluation.",
|
1170 |
+
)
|
1171 |
+
parser.add_argument(
|
1172 |
+
"--lower_tolerance_T",
|
1173 |
+
type=int,
|
1174 |
+
default=25,
|
1175 |
+
help="The lower tolerance to use when dropping rows before ROC-AUC evaluation.",
|
1176 |
+
)
|
1177 |
+
parser.add_argument(
|
1178 |
+
"--upper_tolerance_T",
|
1179 |
+
type=int,
|
1180 |
+
default=25,
|
1181 |
+
help="The upper tolerance to use when dropping rows before ROC-AUC evaluation.",
|
1182 |
+
)
|
1183 |
+
parser.add_argument(
|
1184 |
+
"--filter_by_columns",
|
1185 |
+
type=str,
|
1186 |
+
default="all",
|
1187 |
+
help="The comma separated list of columns to filter by before ROC-AUC evaluation.",
|
1188 |
+
)
|
1189 |
+
parser.add_argument(
|
1190 |
+
"--wandb",
|
1191 |
+
type=str2bool,
|
1192 |
+
default=False,
|
1193 |
+
help="Whether to log to wandb.",
|
1194 |
+
)
|
1195 |
+
parser.add_argument(
|
1196 |
+
"--wandb_project",
|
1197 |
+
type=str,
|
1198 |
+
default="lm-watermarking",
|
1199 |
+
help="The name of the wandb project.",
|
1200 |
+
)
|
1201 |
+
parser.add_argument(
|
1202 |
+
"--wandb_entity",
|
1203 |
+
type=str,
|
1204 |
+
default="jwkirchenbauer",
|
1205 |
+
help="The wandb entity/user for the project.",
|
1206 |
+
)
|
1207 |
+
parser.add_argument(
|
1208 |
+
"--wandb_tags",
|
1209 |
+
type=str,
|
1210 |
+
default="",
|
1211 |
+
help="The comma separated list of tags to add to the wandb run.",
|
1212 |
+
)
|
1213 |
+
parser.add_argument(
|
1214 |
+
"--run_name",
|
1215 |
+
type=str,
|
1216 |
+
default="",
|
1217 |
+
help="The unique name for the run.",
|
1218 |
+
)
|
1219 |
+
parser.add_argument(
|
1220 |
+
"--input_dir",
|
1221 |
+
type=str,
|
1222 |
+
default="./input",
|
1223 |
+
help="The directory containing the input files.",
|
1224 |
+
)
|
1225 |
+
parser.add_argument(
|
1226 |
+
"--output_dir",
|
1227 |
+
type=str,
|
1228 |
+
default="",
|
1229 |
+
help=(
|
1230 |
+
"The directory in which to write out the dataset after adding the metrics. "
|
1231 |
+
"If not specified, will use the input_dir. Note, if the output_dir already "
|
1232 |
+
"contains the metric-enriched file, it will be overwritten :/"
|
1233 |
+
),
|
1234 |
+
)
|
1235 |
+
parser.add_argument(
|
1236 |
+
"--overwrite_output_file",
|
1237 |
+
type=str2bool,
|
1238 |
+
default=False,
|
1239 |
+
help="Whether to overwrite the output file if it already exists.",
|
1240 |
+
)
|
1241 |
+
parser.add_argument(
|
1242 |
+
"--limit_rows",
|
1243 |
+
type=int,
|
1244 |
+
default=-1,
|
1245 |
+
help="The number of rows to limit the dataset to. Useful for debugging.",
|
1246 |
+
)
|
1247 |
+
parser.add_argument(
|
1248 |
+
"--concat_rows",
|
1249 |
+
type=int,
|
1250 |
+
default=0,
|
1251 |
+
help="The number of rows to concatenate into a single row. Result is a mangled dataset, be careful",
|
1252 |
+
)
|
1253 |
+
parser.add_argument(
|
1254 |
+
"--shuffle_before_concat",
|
1255 |
+
type=str2bool,
|
1256 |
+
default=False,
|
1257 |
+
help="Whether to shuffle the dataset before concatenating rows.",
|
1258 |
+
)
|
1259 |
+
parser.add_argument(
|
1260 |
+
"--verbose",
|
1261 |
+
type=str2bool,
|
1262 |
+
default=None,
|
1263 |
+
help="Whether to verbosely print things here and there.",
|
1264 |
+
)
|
1265 |
+
parser.add_argument(
|
1266 |
+
"--log_raw_series",
|
1267 |
+
type=str2bool,
|
1268 |
+
default=True,
|
1269 |
+
help="Whether to log the raw series metric data to wandb.",
|
1270 |
+
)
|
1271 |
+
parser.add_argument(
|
1272 |
+
"--log_raw_tabular",
|
1273 |
+
type=str2bool,
|
1274 |
+
default=True,
|
1275 |
+
help="Whether to log the raw tabular metric data to wandb.",
|
1276 |
+
)
|
1277 |
+
args = parser.parse_args()
|
1278 |
+
|
1279 |
+
###########################################################################
|
1280 |
+
# Argument validation and conditional setting
|
1281 |
+
###########################################################################
|
1282 |
+
|
1283 |
+
# convert evaluation metrics to list
|
1284 |
+
assert args.evaluation_metrics, "evaluation_metrics list must be specified"
|
1285 |
+
args.evaluation_metrics = args.evaluation_metrics.split(",")
|
1286 |
+
|
1287 |
+
if args.evaluation_metrics == ["all"]:
|
1288 |
+
all_metrics = SUPPORTED_METRICS
|
1289 |
+
all_metrics.remove("ppl") # by default not running this anymore
|
1290 |
+
all_metrics.remove("detectgpt") # can't run this with other metrics
|
1291 |
+
args.evaluation_metrics = all_metrics
|
1292 |
+
if args.evaluation_metrics == ["all_w_ppl"]:
|
1293 |
+
args.evaluation_metrics = SUPPORTED_METRICS
|
1294 |
+
|
1295 |
+
# if no output dir specified, use the input dir
|
1296 |
+
if args.output_dir == "":
|
1297 |
+
args.output_dir = args.input_dir
|
1298 |
+
|
1299 |
+
# check limit_rows
|
1300 |
+
assert (args.limit_rows == -1) or (
|
1301 |
+
(args.limit_rows > 0) and isinstance(args.limit_rows, int)
|
1302 |
+
), "limit_rows must be -1 or > 0"
|
1303 |
+
|
1304 |
+
# convert normalizers to list
|
1305 |
+
if args.normalizers:
|
1306 |
+
args.normalizers = args.normalizers.split(",")
|
1307 |
+
else:
|
1308 |
+
args.normalizers = []
|
1309 |
+
|
1310 |
+
# convert roc_test_stat to list
|
1311 |
+
args.roc_test_stat = args.roc_test_stat.split(",")
|
1312 |
+
|
1313 |
+
if args.roc_test_stat == ["all"]:
|
1314 |
+
args.roc_test_stat = ROC_TEST_STAT_SUFFIXES
|
1315 |
+
|
1316 |
+
# convert filter_by_columns to list
|
1317 |
+
args.filter_by_columns = args.filter_by_columns.split(",")
|
1318 |
+
if args.filter_by_columns == ["all"]:
|
1319 |
+
args.filter_by_columns = FILTER_BY_COLUMNS
|
1320 |
+
|
1321 |
+
# split wandb tags
|
1322 |
+
if args.wandb_tags != "":
|
1323 |
+
args.wandb_tags = args.wandb_tags.split(",")
|
1324 |
+
else:
|
1325 |
+
args.wandb_tags = []
|
1326 |
+
|
1327 |
+
# split window settings
|
1328 |
+
args.window_settings = args.window_settings.split(",")
|
1329 |
+
|
1330 |
+
main(args)
|
lm-watermarking-main/watermark_reliability_release/figure_notebooks/baseline_comparison.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
lm-watermarking-main/watermark_reliability_release/figure_notebooks/baseline_comparison_transpose.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
lm-watermarking-main/watermark_reliability_release/figure_notebooks/core_robustness.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
lm-watermarking-main/watermark_reliability_release/figure_notebooks/data_model.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|