Havi999 commited on
Commit
9c8c5e4
·
1 Parent(s): 7e00faf

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.DS_Store ADDED
Binary file (10.2 kB). View file
 
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ resources/web-demo.gif filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug_report.yaml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: 🐞 Bug/Help
2
+ description: File a bug/issue
3
+ title: "[BUG/Help] <title>"
4
+ labels: []
5
+ body:
6
+ - type: checkboxes
7
+ attributes:
8
+ label: Is there an existing issue for this?
9
+ description: Please search to see if an issue already exists for the bug you encountered.
10
+ options:
11
+ - label: I have searched the existing issues
12
+ required: true
13
+ - type: textarea
14
+ attributes:
15
+ label: Current Behavior
16
+ description: |
17
+ A concise description of what you're experiencing, with screenshot attached if possible.
18
+ Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in.
19
+ validations:
20
+ required: true
21
+ - type: textarea
22
+ attributes:
23
+ label: Expected Behavior
24
+ description: A concise description of what you expected to happen.
25
+ validations:
26
+ required: false
27
+ - type: textarea
28
+ attributes:
29
+ label: Steps To Reproduce
30
+ description: Steps to reproduce the behavior.
31
+ placeholder: |
32
+ 1. In this environment...
33
+ 2. With this config...
34
+ 3. Run '...'
35
+ 4. See error...
36
+ validations:
37
+ required: true
38
+ - type: textarea
39
+ attributes:
40
+ label: Environment
41
+ description: |
42
+ examples:
43
+ - **OS**: Ubuntu 20.04
44
+ - **Python**: 3.8
45
+ - **Transformers**: 4.26.1
46
+ - **PyTorch**: 1.12
47
+ - **CUDA Support**: True
48
+ value: |
49
+ - OS:
50
+ - Python:
51
+ - Transformers:
52
+ - PyTorch:
53
+ - CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`) :
54
+ render: markdown
55
+ validations:
56
+ required: true
57
+ - type: textarea
58
+ attributes:
59
+ label: Anything else?
60
+ description: |
61
+ Links? References? Anything that will give us more context about the issue you are encountering!
62
+ validations:
63
+ required: false
.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1 @@
 
 
1
+ blank_issues_enabled: false
.github/ISSUE_TEMPLATE/feature_request.yml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Feature request
2
+ description: Suggest an idea for this project
3
+ title: "[Feature] <title>"
4
+ labels: []
5
+ body:
6
+ - type: textarea
7
+ attributes:
8
+ label: Is your feature request related to a problem? Please describe.
9
+ description: |
10
+ A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
11
+ validations:
12
+ required: false
13
+ - type: textarea
14
+ attributes:
15
+ label: Solutions
16
+ description: |
17
+ Describe the solution you'd like
18
+ A clear and concise description of what you want to happen.
19
+ validations:
20
+ required: true
21
+ - type: textarea
22
+ attributes:
23
+ label: Additional context
24
+ description: Add any other context or screenshots about the feature request here.
25
+ validations:
26
+ required: false
.gitignore ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+ history/
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ .python-version
87
+
88
+ # pipenv
89
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
91
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
92
+ # install all needed dependencies.
93
+ #Pipfile.lock
94
+
95
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96
+ __pypackages__/
97
+
98
+ # Celery stuff
99
+ celerybeat-schedule
100
+ celerybeat.pid
101
+
102
+ # SageMath parsed files
103
+ *.sage.py
104
+
105
+ # Environments
106
+ .env
107
+ .venv
108
+ env/
109
+ venv/
110
+ ENV/
111
+ env.bak/
112
+ venv.bak/
113
+
114
+ # Spyder project settings
115
+ .spyderproject
116
+ .spyproject
117
+
118
+ # Rope project settings
119
+ .ropeproject
120
+
121
+ # mkdocs documentation
122
+ /site
123
+
124
+ # mypy
125
+ .mypy_cache/
126
+ .dmypy.json
127
+ dmypy.json
128
+
129
+ # Pyre type checker
130
+ .pyre/
131
+
132
+ # Mac system file
133
+ model/
134
+ .idea
FAQ.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Q1
2
+
3
+ **Mac直接加载量化后的模型出现提示 `clang: error: unsupported option '-fopenmp'**
4
+
5
+ 这是由于Mac由于本身缺乏omp导致的,此时可运行但是单核。需要单独安装 openmp 依赖,即可在Mac下使用OMP:
6
+
7
+ ```bash
8
+ # 参考`https://mac.r-project.org/openmp/`
9
+ ## 假设: gcc(clang)是14.x版本,其他版本见R-Project提供的表格
10
+ curl -O https://mac.r-project.org/openmp/openmp-14.0.6-darwin20-Release.tar.gz
11
+ sudo tar fvxz openmp-14.0.6-darwin20-Release.tar.gz -C /
12
+ ```
13
+ 此时会安装下面几个文件:`/usr/local/lib/libomp.dylib`, `/usr/local/include/ompt.h`, `/usr/local/include/omp.h`, `/usr/local/include/omp-tools.h`。
14
+
15
+ > 注意:如果你之前运行`ChatGLM`项目失败过,最好清一下Huggingface的缓存,i.e. 默认下是 `rm -rf ${HOME}/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4`。由于使用了`rm`命令,请明确知道自己在删除什么。
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright Zhengxiao Du
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
MODEL_LICENSE ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The ChatGLM-6B License
2
+
3
+ 1. 定义
4
+
5
+ “许可方”是指分发其软件的 ChatGLM-6B 模型团队。
6
+
7
+ “软件”是指根据本许可提供的 ChatGLM-6B 模型参数。(不包括二代模型 ChatGLM2-6B 以及后续模型)
8
+
9
+ 2. 许可授予
10
+
11
+ 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。
12
+
13
+ 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
14
+
15
+ 3.限制
16
+
17
+ 您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
18
+
19
+ 您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。
20
+
21
+ 4.免责声明
22
+
23
+ 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。
24
+
25
+ 5. 责任限制
26
+
27
+ 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
28
+
29
+ 6.争议解决
30
+
31
+ 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
32
+
33
+ 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 [email protected] 与我们联系。
34
+
35
+ 1. Definitions
36
+
37
+ “Licensor” means the ChatGLM-6B Model Team that distributes its Software.
38
+
39
+ “Software” means the ChatGLM-6B model parameters made available under this license (does not include the second-generation model ChatGLM2-6B and subsequent models).
40
+
41
+ 2. License Grant
42
+
43
+ Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software.
44
+
45
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
46
+
47
+ 3. Restriction
48
+
49
+ You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.
50
+
51
+ You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
52
+
53
+ 4. Disclaimer
54
+
55
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
56
+
57
+ 5. Limitation of Liability
58
+
59
+ EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
60
+
61
+ 6. Dispute Resolution
62
+
63
+ This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
64
+
65
+ Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at [email protected].
PROJECT.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 友情链接
2
+
3
+ 对 ChatGLM 进行加速或者重新实现的开源项目:
4
+ * [lyraChatGLM](https://huggingface.co/TMElyralab/lyraChatGLM): 对 ChatGLM-6B 进行推理加速,最高可以实现 9000+ tokens/s 的推理速度
5
+ * [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer): 一个Transformer统一编程框架,ChatGLM-6B已经在SAT中进行实现并可以进行P-tuning微调。
6
+ * [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU
7
+ * [JittorLLMs](https://github.com/Jittor/JittorLLMs):最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署
8
+ * [InferLLM](https://github.com/MegEngine/InferLLM):轻量级 C++ 推理,可以实现本地 x86,Arm 处理器上实时聊天,手机上也同样可以实时运行,运行内存只需要 4G
9
+
10
+
11
+
12
+ 基于或使用了 ChatGLM-6B 的开源项目:
13
+ * [chatgpt_academic](https://github.com/binary-husky/chatgpt_academic): 支持ChatGLM-6B的学术写作与编程工具箱,具有模块化和多线程调用LLM的特点,可并行调用多种LLM。
14
+ * [闻达](https://github.com/l15y/wenda):大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能
15
+ * [glm-bot](https://github.com/initialencounter/glm-bot):将ChatGLM接入Koishi可在各大聊天平台上调用ChatGLM
16
+ * [Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain):中文langchain项目,基于ChatGLM-6b+langchain实现本地化知识库检索与智能答案生成,增加web search功能、知识库选择功能和支持知识增量更新
17
+ * [bibliothecarius](https://github.com/coderabbit214/bibliothecarius):快速构建服务以集成您的本地数据和AI模型,支持ChatGLM等本地化模型接入。
18
+ * [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM):基于 langchain 的 ChatGLM 应用,实现基于可扩展知识库的问答
19
+ * [ChatGLM-web](https://github.com/NCZkevin/chatglm-web):基于FastAPI和Vue3搭建的ChatGLM演示网站(支持chatglm流式输出、前端调整模型参数、上下文选择、保存图片、知识库问答等功能)
20
+ * [Chuanhu Chat](https://github.com/GaiZhenbiao/ChuanhuChatGPT): 为各个大语言模型和在线模型API提供美观易用、功能丰富、快速部署的用户界面,支持ChatGLM-6B。
21
+ * [ChatGLM-6B-Engineering](https://github.com/LemonQu-GIT/ChatGLM-6B-Engineering):基于 ChatGLM-6B 后期调教,网络爬虫及 [Stable Diffusion](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 实现的网络搜索及图片生成
22
+ * [ChatGLM-OpenAI-API](https://github.com/ninehills/chatglm-openai-api): 将 ChatGLM-6B 封装为 OpenAI API 风格,并通过 ngrok/cloudflare 对外提供服务,从而将 ChatGLM 快速集成到 OpenAI 的各种生态中。
23
+ * [ChatSQL](https://github.com/cubenlp/ChatSQL): 基于ChatGLM+SBERT实现NL2SQL本地化,并直接连接数据库查询数据返回结果,使得生成的SQL语句更具有实用性。
24
+
25
+ 对 ChatGLM-6B 进行微调的开源项目:
26
+ * [InstructGLM](https://github.com/yanqiangmiffy/InstructGLM):基于ChatGLM-6B进行指令学习,汇总开源中英文指令数据,基于Lora进行指令数据微调,开放了Alpaca、Belle微调后的Lora权重,修复web_demo重复问题
27
+ * [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning):实现了ChatGLM-6B模型的监督微调和完整RLHF训练,汇总10余种指令数据集和3种微调方案,实现了4/8比特量化和模型权重融合,提供微调模型快速部署方法。
28
+ * [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning):基于ChatGLM-6B模型,进行下游具体任务微调,涉及Freeze、Lora、P-tuning等,并进行实验效果对比。
29
+ * [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): 基于 LoRA 对 ChatGLM-6B 进行微调。类似的项目还包括 [Humanable ChatGLM/GPT Fine-tuning | ChatGLM 微调](https://github.com/hscspring/hcgf)
30
+
31
+
32
+ 针对 ChatGLM-6B 的教程/文档:
33
+ * [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md)
34
+ * [搭建深度学习docker容器以运行 ChatGLM-6B - Luck_zy](https://www.luckzym.com/tags/ChatGLM-6B/)
35
+
36
+ 如果你有其他好的项目/教程的话,欢迎参照上述格式添加到 README 中并提出 [Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)。
37
+
README.md CHANGED
@@ -1,12 +1,397 @@
1
  ---
2
  title: FORAI
3
- emoji: 🚀
4
- colorFrom: purple
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.40.1
8
- app_file: app.py
9
- pinned: false
10
  ---
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: FORAI
3
+ app_file: web_demo_old.py
 
 
4
  sdk: gradio
5
  sdk_version: 3.40.1
 
 
6
  ---
7
+ # ChatGLM-6B
8
 
9
+ <p align="center">
10
+ 🌐 <a href="https://chatglm.cn/blog" target="_blank">Blog</a> • 🤗 <a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2210.02414" target="_blank">[GLM-130B@ICLR 23]</a> <a href="https://github.com/THUDM/GLM-130B" target="_blank">[GitHub]</a> <br>
11
+ </p>
12
+ <p align="center">
13
+ 👋 加入我们的 <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1y7pqoloy-9b1g6T6JjA8J0KxvUjbwJw" target="_blank">Slack</a> 和 <a href="resources/WECHAT.md" target="_blank">WeChat</a>
14
+ </p>
15
+
16
+ *Read this in [English](README_en.md).*
17
+
18
+ ## 介绍
19
+
20
+ ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。
21
+ ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答,更多信息请参考我们的[博客](https://chatglm.cn/blog)。欢迎通过 [chatglm.cn](https://chatglm.cn) 体验更大规模的 ChatGLM 模型。
22
+
23
+ 为了方便下游开发者针对自己的应用场景定制模型,我们同时实现了基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调方法 [(使用指南)](ptuning/README.md) ,INT4 量化级别下最低只需 7GB 显存即可启动微调。
24
+
25
+ ChatGLM-6B 权重对学术研究**完全开放**,在填写[问卷](https://open.bigmodel.cn/mla/form)进行登记后**亦允许免费商业使用**。
26
+
27
+ 想让 ChatGLM-6B 更符合你的应用场景?欢迎参与 [Badcase 反馈计划](improve/README.md)。
28
+
29
+ -----
30
+
31
+ ChatGLM-6B 开源模型旨在与开源社区一起推动大模型技术发展,恳请开发者和大家遵守[开源协议](MODEL_LICENSE),勿将开源模型和代码及基于开源项目产生的衍生物用于任何可能给国家和社会带来危害的用途以及用于任何未经过安全评估和备案的服务。**目前,本项目团队未基于 ChatGLM-6B 开发任何应用,包括网页端、安卓、苹果 iOS 及 Windows App 等应用。**
32
+
33
+ 尽管模型在训练的各个阶段都尽力确保数据的合规性和准确性,但由于 ChatGLM-6B 模型规模较小,且模型受概率随机性因素影响,无法保证输出内容的准确性,且模型易被误导(详见[局限性](README.md#局限性))。**本项目不承担开源模型和代码导致的数据安全、舆情风险或发生任何模型被误导、滥用、传播、不当利用而产生的风险和责任。**
34
+
35
+ ## 更新信息
36
+ **[2023/07/25]** 发布 [CodeGeeX2](https://github.com/THUDM/CodeGeeX2) ,基于 ChatGLM2-6B 的代码生成模型,代码能力全面提升,更多特性包括:
37
+
38
+ * **更强大的代码能力**:CodeGeeX2-6B 进一步经过了 600B 代码数据预训练,相比 CodeGeeX 一代模型,在代码能力上全面提升,[HumanEval-X](https://huggingface.co/datasets/THUDM/humaneval-x) 评测集的六种编程语言均大幅提升 (Python +57%, C++ +71%, Java +54%, JavaScript +83%, Go +56%, Rust +321\%),在Python上达到 35.9\% 的 Pass@1 一次通过率,超越规模更大的 StarCoder-15B。
39
+ * **更优秀的模型特性**:继承 ChatGLM2-6B 模型特性,CodeGeeX2-6B 更好支持中英文输入,支持最大 8192 序列长度,推理速度较一代 大幅提升,量化后仅需6GB显存即可运行,支持轻量级本地化部署。
40
+ * **更全面的AI编程助手**:CodeGeeX插件([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex))后端升级,支持超过100种编程语言,新增上下文补全、跨文件补全等实用功能。结合 Ask CodeGeeX 交互式AI编程助手,支持中英文对话解决各种编程问题,包括且不限于代码解释、代码翻译、代码纠错、文档生成等,帮助程序员更高效开发。
41
+
42
+ **[2023/06/25]** 发布 [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B),ChatGLM-6B 的升级版本,在保留了了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM**2**-6B 引入了如下新特性:
43
+
44
+ 1. **更强大的性能**:基于 ChatGLM 初代模型的开发��验,我们全面升级了 ChatGLM2-6B 的基座模型。ChatGLM2-6B 使用了 [GLM](https://github.com/THUDM/GLM) 的混合目标函数,经过了 1.4T 中英标识符的预训练与人类偏好对齐训练,[评测结果](#评测结果)显示,相比于初代模型,ChatGLM2-6B 在 MMLU(+23%)、CEval(+33%)、GSM8K(+571%) 、BBH(+60%)等数据集上的性能取得了大幅度的提升,在同尺寸开源模型中具有较强的竞争力。
45
+ 2. **更长的上下文**:基于 [FlashAttention](https://github.com/HazyResearch/flash-attention) 技术,我们将基座模型的上下文长度(Context Length)由 ChatGLM-6B 的 2K 扩展到了 32K,并在对话阶段使用 8K 的上下文长度训练,允许更多轮次的对话。但当前版本的 ChatGLM2-6B 对单轮超长文档的理解能力有限,我们会在后续迭代升级中着重进行优化。
46
+ 3. **更高效的推理**:基于 [Multi-Query Attention](http://arxiv.org/abs/1911.02150) 技术,ChatGLM2-6B 有更高效的推理速度和更低的显存占用:在官方的模型实现下,推理速度相比初代提升了 42%,INT4 量化下,6G 显存支持的对话长度由 1K 提升到了 8K。
47
+
48
+ 更多信息参见 [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B)。
49
+
50
+ **[2023/06/14]** 发布 [WebGLM](https://github.com/THUDM/WebGLM),一项被接受于KDD 2023的研究工作,支持利用网络信息生成带有准确引用的长回答。
51
+
52
+ ![](resources/webglm.jpg)
53
+
54
+ **[2023/05/17]** 发布 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B),一个支持图像理解的多模态对话语言模型。
55
+
56
+ ![](resources/visualglm.png)
57
+
58
+ 可以通过本仓库中的 [cli_demo_vision.py](cli_demo_vision.py) 和 [web_demo_vision.py](web_demo_vision.py) 来运行命令行和网页 Demo。注意 VisualGLM-6B 需要额外安装 [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) 和 torchvision。更多信息参见 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B)。
59
+
60
+ **[2023/05/15]** 更新 v1.1 版本 checkpoint,训练数据增加英文指令微调数据以平衡中英文数据比例,解决英文回答中夹杂中文词语的现象。
61
+
62
+ <details><summary><b>以下是更新前后的英文问题对比:</b></summary>
63
+
64
+ * 问题:Describe a time when you had to make a difficult decision.
65
+ - v1.0:
66
+ ![](resources/english-q1-old.png)
67
+ - v1.1:
68
+ ![](resources/english-q1-new.png)
69
+ * 问题:Describe the function of a computer motherboard
70
+ - v1.0:
71
+ ![](resources/english-q2-old.png)
72
+ - v1.1:
73
+ ![](resources/english-q2-new.png)
74
+ * 问题:Develop a plan to reduce electricity usage in a home.
75
+ - v1.0:
76
+ ![](resources/english-q3-old.png)
77
+ - v1.1:
78
+ ![](resources/english-q3-new.png)
79
+ * 问题:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语
80
+ - v1.0:
81
+ ![](resources/english-q4-old.png)
82
+ - v1.1:
83
+ ![](resources/english-q4-new.png)
84
+ </details>
85
+
86
+ 更多更新信息参见 [UPDATE.md](UPDATE.md)
87
+
88
+ ## 友情链接
89
+ 对 ChatGLM 进行加速的开源项目:
90
+ * [lyraChatGLM](https://huggingface.co/TMElyralab/lyraChatGLM): 对 ChatGLM-6B 进行推理加速,最高可以实现 9000+ tokens/s 的推理速度
91
+ * [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU
92
+ * [JittorLLMs](https://github.com/Jittor/JittorLLMs):最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署
93
+ * [InferLLM](https://github.com/MegEngine/InferLLM):轻量级 C++ 推理,可以实现本地 x86,Arm 处理器上实时聊天,手机上也同样可以实时运行,运行内存只需要 4G
94
+
95
+ 基于或使用了 ChatGLM-6B 的开源项目:
96
+ * [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM):基于 langchain 的 ChatGLM 应用,实现基于可扩展知识库的问答
97
+ * [闻达](https://github.com/l15y/wenda):大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能
98
+ * [glm-bot](https://github.com/initialencounter/glm-bot):将ChatGLM接入Koishi可在各大聊天平台上调用ChatGLM
99
+ * [Chuanhu Chat](https://github.com/GaiZhenbiao/ChuanhuChatGPT): 为各个大语言模型和在线模型API提供美观易用、功能丰富、快速部署的用户界面,支持ChatGLM-6B。
100
+
101
+ 支持 ChatGLM-6B 和相关应用在线训练的示例项目:
102
+ * [ChatGLM-6B 的部署与微调教程](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e)
103
+ * [ChatGLM-6B 结合 langchain 实现本地知识库 QA Bot](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
104
+
105
+ 第三方评测:
106
+ * [Measuring Massive Multitask Chinese Understanding](https://arxiv.org/abs/2304.12986)
107
+
108
+ 更多开源��目参见 [PROJECT.md](PROJECT.md)
109
+
110
+ ## 使用方式
111
+
112
+ ### 硬件需求
113
+
114
+ | **量化等级** | **最低 GPU 显存**(推理) | **最低 GPU 显存**(高效参数微调) |
115
+ | -------------- | ------------------------- | --------------------------------- |
116
+ | FP16(无量化) | 13 GB | 14 GB |
117
+ | INT8 | 8 GB | 9 GB |
118
+ | INT4 | 6 GB | 7 GB |
119
+ ### 环境安装
120
+
121
+ 使用 pip 安装依赖:`pip install -r requirements.txt`,其中 `transformers` 库版本推荐为 `4.27.1`,但理论上不低于 `4.23.1` 即可。
122
+
123
+ 此外,如果需要在 cpu 上运行量化后的模型,还需要安装 `gcc` 与 `openmp`。多数 Linux 发行版默认已安装。对于 Windows ,可在安装 [TDM-GCC](https://jmeubank.github.io/tdm-gcc/) 时勾选 `openmp`。 Windows 测试环境 `gcc` 版本为 `TDM-GCC 10.3.0`, Linux 为 `gcc 11.3.0`。在 MacOS 上请参考 [Q1](FAQ.md#q1)。
124
+
125
+ ### 代码调用
126
+
127
+ 可以通过如下代码调用 ChatGLM-6B 模型来生成对话:
128
+
129
+ ```python
130
+ >>> from transformers import AutoTokenizer, AutoModel
131
+ >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
132
+ >>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
133
+ >>> model = model.eval()
134
+ >>> response, history = model.chat(tokenizer, "你好", history=[])
135
+ >>> print(response)
136
+ 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
137
+ >>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
138
+ >>> print(response)
139
+ 晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:
140
+
141
+ 1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。
142
+ 2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。
143
+ 3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。
144
+ 4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。
145
+ 5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。
146
+ 6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。
147
+
148
+ 如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。
149
+ ```
150
+ 模型的实现仍然处在变动中。如果希望固定使用的模型实现以保证兼容性,可以在 `from_pretrained` 的调用中增加 `revision="v1.1.0"` 参数。`v1.1.0` 是当前最新的版本号,完整的版本列表参见 [Change Log](https://huggingface.co/THUDM/chatglm-6b#change-log)。
151
+
152
+ ### 从本地加载模型
153
+ 以上代码会由 `transformers` 自动下载模型实现和参数。完整的模型实现可以在 [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b)。如果你的网络环境较差,下载模型参数可能会花费较长时间甚至失败。此时可以先将模型下载到本地,然后从本地加载。
154
+
155
+ 从 Hugging Face Hub 下载模型需要先[安装Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage),然后运行
156
+ ```Shell
157
+ git clone https://huggingface.co/THUDM/chatglm-6b
158
+ ```
159
+
160
+ 如果你从 Hugging Face Hub 上下载 checkpoint 的速度较慢,可以只下载模型实现
161
+ ```Shell
162
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b
163
+ ```
164
+ 然后从[这里](https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/)手动下载模型参数文件,并将下载的文件替换到本地的 `chatglm-6b` 目录下。
165
+
166
+ 将模型下载到本地之后,将以上代码中的 `THUDM/chatglm-6b` 替换为你本地的 `chatglm-6b` 文件夹的路径,即可从本地加载模型。
167
+
168
+ **Optional** 模型的实现仍然处在变动中。如果希望固定使用的模型实现以保证兼容性,可以执行
169
+ ```Shell
170
+ git checkout v1.1.0
171
+ ```
172
+
173
+ ## Demo & API
174
+
175
+ 我们提供了一个基于 [Gradio](https://gradio.app) 的网页版 Demo 和一个命令行 Demo。使用时首先需要下载本仓库:
176
+
177
+ ```shell
178
+ git clone https://github.com/THUDM/ChatGLM-6B
179
+ cd ChatGLM-6B
180
+ ```
181
+
182
+ ### 网页版 Demo
183
+
184
+ ![web-demo](resources/web-demo.gif)
185
+
186
+ 首先安装 Gradio:`pip install gradio`,然后运行仓库中的 [web_demo.py](web_demo.py):
187
+
188
+ ```shell
189
+ python web_demo.py
190
+ ```
191
+
192
+ 程序会运行一个 Web Server,并输出地址。在浏览器中打开输出的地址即可使用。最新版 Demo 实现了打字机效果,速度体验大大提升。注意,由于国内 Gradio 的网络访问较为缓慢,启用 `demo.queue().launch(share=True, inbrowser=True)` 时所有网络会经过 Gradio 服务器转发,导致打字机体验大幅下降,现在默认启动方式已经改为 `share=False`,如有需要公网访问的需求,可以重新修改为 `share=True` 启动。
193
+
194
+ 感谢 [@AdamBear](https://github.com/AdamBear) 实现了基于 Streamlit 的网页版 Demo,运行方式见[#117](https://github.com/THUDM/ChatGLM-6B/pull/117).
195
+
196
+ ### 命令行 Demo
197
+
198
+ ![cli-demo](resources/cli-demo.png)
199
+
200
+ 运行仓库中 [cli_demo.py](cli_demo.py):
201
+
202
+ ```shell
203
+ python cli_demo.py
204
+ ```
205
+
206
+ 程序会在命令行中进行交互式的对话,在命令行中输入指示并回车即可生成回复,输入 `clear` 可以清空对话历史,输入 `stop` 终止程序。
207
+
208
+ ### API部署
209
+ 首先需要安装额外的依赖 `pip install fastapi uvicorn`,然后运行仓库中的 [api.py](api.py):
210
+ ```shell
211
+ python api.py
212
+ ```
213
+ 默认部署在本地的 8000 端口,通过 POST 方法进行调用
214
+ ```shell
215
+ curl -X POST "http://127.0.0.1:8000" \
216
+ -H 'Content-Type: application/json' \
217
+ -d '{"prompt": "你好", "history": []}'
218
+ ```
219
+ 得到的返回值为
220
+ ```shell
221
+ {
222
+ "response":"你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。",
223
+ "history":[["你好","你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"]],
224
+ "status":200,
225
+ "time":"2023-03-23 21:38:40"
226
+ }
227
+ ```
228
+
229
+ ## 低成本部署
230
+ ### 模型量化
231
+ 默认情况下,模型以 FP16 精度加载,运行上述代码需要大概 13GB 显存。如果你的 GPU 显存有限,可以尝试以量化方式加载模型,使用方法如下:
232
+
233
+ ```python
234
+ # 按需修改,目前只支持 4/8 bit 量化
235
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()
236
+ ```
237
+
238
+ 进行 2 至 3 轮对话后,8-bit 量化下 GPU 显存占用约为 10GB,4-bit 量化下仅需 6GB 占用。随着对话轮数的增多,对应消耗显存也随之增长,由于采用了相对位置编码,理论上 ChatGLM-6B 支持无限长的 context-length,但总长度超过 2048(训练长度)后性能会逐渐下降。
239
+
240
+ 模型量化会带来一定的性能损失,经过测试,ChatGLM-6B 在 4-bit 量化下仍然能够进行自然流畅的生成。使用 [GPT-Q](https://arxiv.org/abs/2210.17323) 等量化方案可以进一步压缩量化精度/提升相同量化精度下的模型性能,欢迎大家提出对应的 Pull Request。
241
+
242
+ 量化过程需要在内存中首先加载 FP16 格式的模型,消耗大概 13GB 的内存。如果你的内存不足的话,可以直接加载量化后的模型,INT4 量化后的模型仅需大概 5.2GB 的内存:
243
+ ```python
244
+ # INT8 量化的模型将"THUDM/chatglm-6b-int4"改为"THUDM/chatglm-6b-int8"
245
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
246
+ ```
247
+ 量化模型的参数文件也可以从[这里](https://cloud.tsinghua.edu.cn/d/674208019e314311ab5c/)手动下载。
248
+
249
+ ### CPU 部署
250
+ 如果你没有 GPU 硬件的话,也可以在 CPU 上进行推理,但是推理速度会更慢。使用方法如下(需要大概 32GB 内存)
251
+ ```python
252
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
253
+ ```
254
+
255
+ 如果你的内存不足,可以直接加载量化后的模型:
256
+ ```python
257
+ # INT8 量化的模型将"THUDM/chatglm-6b-int4"改为"THUDM/chatglm-6b-int8"
258
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float()
259
+ ```
260
+
261
+ 如果遇到了报错 `Could not find module 'nvcuda.dll'` 或者 `RuntimeError: Unknown platform: darwin` (MacOS) ,请[从本地加载模型](README.md#从本地加载模型)
262
+
263
+ ### Mac 部署
264
+ 对于搭载了 Apple Silicon 或者 AMD GPU 的Mac,可以使用 MPS 后端来在 GPU 上运行 ChatGLM-6B。需要参考 Apple 的 [官方说明](https://developer.apple.com/metal/pytorch) 安装 PyTorch-Nightly(正确的版本号应该是2.1.0.dev2023xxxx,而不是2.0.0)。
265
+
266
+ 目前在 MacOS 上只支持[从本地加载模型](README.md#从本地加载模型)。将代码中的模型加载改为从本地加载,并使用 mps 后端:
267
+ ```python
268
+ model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps')
269
+ ```
270
+
271
+ 加载半精度的 ChatGLM-6B 模型需要大概 13GB 内存。内存较小的机器(比如 16GB 内存的 MacBook Pro),在空余内存不足的情况下会使用硬盘上的虚拟内存,导致推理速度严重变慢。此时可以使用量化后的模型如 chatglm-6b-int4。因为 GPU 上量化的 kernel 是使用 CUDA 编写的,因此无法在 MacOS 上使用,只能使用 CPU 进行推理。
272
+ ```python
273
+ # INT8 量化的模型将"THUDM/chatglm-6b-int4"改为"THUDM/chatglm-6b-int8"
274
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float()
275
+ ```
276
+ 为了充分使用 CPU 并行,还需要[单独安装 OpenMP](FAQ.md#q1)。
277
+
278
+ ### 多卡部署
279
+ 如果你有多张 GPU,但是每张 GPU 的显存大小都不足以容纳完整的模型,那么可以将模型切分在多张GPU上。首先安装 accelerate: `pip install accelerate`,然后通过如下方法加载模型:
280
+ ```python
281
+ from utils import load_model_on_gpus
282
+ model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
283
+ ```
284
+ 即可将模型部署到两张 GPU 上进行推理。你可以将 `num_gpus` 改为你希望使用的 GPU 数。默认是均匀切分的,你也可以传入 `device_map` 参数来自己指定。
285
+
286
+ ## 高效参数微调
287
+ 基于 [P-tuning v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调。具体使用方法详见 [ptuning/README.md](ptuning/README.md)。
288
+
289
+ ## ChatGLM-6B 示例
290
+
291
+ 以下是一些使用 `web_demo.py` 得到的示例截图。更多 ChatGLM-6B 的可能,等待你来探索发现!
292
+
293
+ <details><summary><b>自我认知</b></summary>
294
+
295
+ ![](examples/self-introduction.png)
296
+
297
+ </details>
298
+
299
+ <details><summary><b>提纲写作</b></summary>
300
+
301
+ ![](examples/blog-outline.png)
302
+
303
+ </details>
304
+
305
+ <details><summary><b>文案写作</b></summary>
306
+
307
+ ![](examples/ad-writing-2.png)
308
+
309
+ ![](examples/comments-writing.png)
310
+
311
+ </details>
312
+
313
+ <details><summary><b>邮件写作助手</b></summary>
314
+
315
+ ![](examples/email-writing-1.png)
316
+
317
+ ![](examples/email-writing-2.png)
318
+
319
+ </details>
320
+
321
+ <details><summary><b>信息抽取</b></summary>
322
+
323
+ ![](examples/information-extraction.png)
324
+
325
+ </details>
326
+
327
+ <details><summary><b>角色扮演</b></summary>
328
+
329
+ ![](examples/role-play.png)
330
+
331
+ </details>
332
+
333
+ <details><summary><b>评论比较</b></summary>
334
+
335
+ ![](examples/sport.png)
336
+
337
+ </details>
338
+
339
+ <details><summary><b>旅游向导</b></summary>
340
+
341
+ ![](examples/tour-guide.png)
342
+
343
+ </details>
344
+
345
+ ## 局限性
346
+
347
+ 由于 ChatGLM-6B 的小规模,其能力仍然有许多局限性。以下是我们目前发现的一些问题:
348
+
349
+ - 模型容量较小:6B 的小容量,决定了其相对较弱的模型记忆和语言能力。在面对许多事实性知识任务时,ChatGLM-6B 可能会生成不正确的信息;它也不擅长逻辑类问题(如数学、编程)的解答。
350
+ <details><summary><b>点击查看例子</b></summary>
351
+
352
+ ![](limitations/factual_error.png)
353
+
354
+ ![](limitations/math_error.png)
355
+
356
+ </details>
357
+
358
+ - 产生有害说明或有偏见的内容:ChatGLM-6B 只是一个初步与人类意图对齐的语言模型,可能会生成有害、有偏见的内容。(内容可能具有冒犯性,此处不展示)
359
+
360
+ - 英文能力不足:ChatGLM-6B 训练时使用的指示/回答大部分都是中文的,仅有极小一部分英文内容。因此,如果输入英文指示,回复的质量远不如中文,甚至与中文指示下的内容矛盾,并且出现中英夹杂的情况。
361
+
362
+ - 易被误导,对话能力较弱:ChatGLM-6B 对话能力还比较弱,而且 “自我认知” 存在问题,并很容易被误导并产生错误的言论。例如当前版本的模型在被误导的情况下,会在自我认知上发生偏差。
363
+ <details><summary><b>点击查看例子</b></summary>
364
+
365
+ ![](limitations/self-confusion_google.jpg)
366
+
367
+ ![](limitations/self-confusion_openai.jpg)
368
+
369
+ ![](limitations/self-confusion_tencent.jpg)
370
+
371
+ </details>
372
+
373
+ ## 协议
374
+
375
+ 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。ChatGLM-6B 权重对学术研究**完全开放**,在填写[问卷](https://open.bigmodel.cn/mla/form)进行登记后**亦允许免费商业使用**。
376
+
377
+ ## 引用
378
+
379
+ 如果你觉得我们的工作有帮助的话,请考虑引用下列论文
380
+
381
+ ```
382
+ @article{zeng2022glm,
383
+ title={Glm-130b: An open bilingual pre-trained model},
384
+ author={Zeng, Aohan and Liu, Xiao and Du, Zhengxiao and Wang, Zihan and Lai, Hanyu and Ding, Ming and Yang, Zhuoyi and Xu, Yifan and Zheng, Wendi and Xia, Xiao and others},
385
+ journal={arXiv preprint arXiv:2210.02414},
386
+ year={2022}
387
+ }
388
+ ```
389
+ ```
390
+ @inproceedings{du2022glm,
391
+ title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
392
+ author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
393
+ booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
394
+ pages={320--335},
395
+ year={2022}
396
+ }
397
+ ```
README_en.md ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ChatGLM-6B
2
+
3
+ <p align="center">
4
+ 🌐 <a href="https://chatglm.cn/blog" target="_blank">Blog</a> • 🤗 <a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2210.02414" target="_blank">[GLM-130B@ICLR 23]</a> <a href="https://github.com/THUDM/GLM-130B" target="_blank">[GitHub]</a> <br>
5
+ </p>
6
+ <p align="center">
7
+ 👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1y7pqoloy-9b1g6T6JjA8J0KxvUjbwJw" target="_blank">Slack</a> and <a href="resources/WECHAT.md" target="_blank">WeChat</a>
8
+ </p>
9
+
10
+ ## Introduction
11
+
12
+ ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). Welcome to use the larger ChatGLM model on [chatglm.cn](https://chatglm.cn)
13
+
14
+ ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning wit human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.
15
+
16
+ In order to facilitate downstream developers to customize the model for their own application scenarios, we also implements an parameter-efficient tuning method based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2)[(Guidelines)](ptuning/README_en.md). Tuning requires at least 7GB of GPU memory at INT4 quantization level.
17
+
18
+ ChatGLM-6B weights are **completely open** for academic research, and **free commercial use** is also allowed after completing the [questionnaire](https://open.bigmodel.cn/mla/form).
19
+
20
+ Try the [online demo](https://huggingface.co/spaces/ysharma/ChatGLM-6b_Gradio_Streaming) on Huggingface Spaces.
21
+
22
+ ## Update
23
+ **[2023/07/25]** Release [CodeGeeX2](https://github.com/THUDM/CodeGeeX2), which is based on ChatGLM2-6B and trained on more code data. It has the following features:
24
+
25
+ * **More Powerful Coding Capabilities**: CodeGeeX2-6B has been further pre-trained on 600B code tokens, which has been comprehensively improved in coding capability compared to the first-generation. On the [HumanEval-X](https://huggingface.co/datasets/THUDM/humaneval-x) benchmark, all six languages have been significantly improved (Python +57%, C++ +71%, Java +54%, JavaScript +83%, Go +56%, Rust +321\%), and in Python it reached 35.9% of Pass@1 one-time pass rate, surpassing the larger StarCoder-15B.
26
+ * **More Useful Features**: Inheriting the ChatGLM2-6B model features, CodeGeeX2-6B better supports both Chinese and English prompts, maximum 8192 sequence length, and the inference speed is significantly improved compared to the first-generation. After quantization, it only needs 6GB of GPU memory for inference, thus supports lightweight local deployment.
27
+ * **Comprehensive AI Coding Assistant**: The backend of CodeGeeX plugin ([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex)) is upgraded, supporting 100+ programming languages, and adding practical functions such as infilling and cross-file completion. Combined with the "Ask CodeGeeX" interactive AI coding assistant, it can be used to solve various programming problems via Chinese or English dialogue, including but not limited to code summarization, code translation, debugging, and comment generation, which helps increasing the efficiency of developpers.
28
+
29
+ **[2023/06/25]** Release [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), the second-generation version of ChatGLM-6B. It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing the following new features:
30
+
31
+ 1. **Stronger Performance**: Based on the development experience of the first-generation ChatGLM model, we have fully upgraded the base model of ChatGLM2-6B. ChatGLM2-6B uses the hybrid objective function of [GLM](https://github.com/THUDM/GLM), and has undergone pre-training with 1.4T bilingual tokens and human preference alignment training. The [evaluation results](README.md#evaluation-results) show that, compared to the first-generation model, ChatGLM2-6B has achieved substantial improvements in performance on datasets like MMLU (+23%), CEval (+33%), GSM8K (+571%), BBH (+60%), showing strong competitiveness among models of the same size.
32
+ 2. **Longer Context**: Based on [FlashAttention](https://github.com/HazyResearch/flash-attention) technique, we have extended the context length of the base model from 2K in ChatGLM-6B to 32K, and trained with a context length of 8K during the dialogue alignment, allowing for more rounds of dialogue. However, the current version of ChatGLM2-6B has limited understanding of single-round ultra-long documents, which we will focus on optimizing in future iterations.
33
+ 3. **More Efficient Inference**: Based on [Multi-Query Attention](http://arxiv.org/abs/1911.02150) technique, ChatGLM2-6B has more efficient inference speed and lower GPU memory usage: under the official implementation, the inference speed has increased by 42% compared to the first generation; under INT4 quantization, the dialogue length supported by 6G GPU memory has increased from 1K to 8K.
34
+
35
+ Fore more information, please refer to [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B).
36
+
37
+ **[2023/05/17]** Release [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B), a multimodal conversational language model supporting image understanding.
38
+
39
+ ![](resources/visualglm.png)
40
+
41
+ You can run the command line and web demo through [cli_demo_vision.py](cli_demo_vision.py) and [web_demo_vision.py](web_demo_vision.py) in the repository. Note that VisualGLM-6B requires additional installation of [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) and torchvision. For more information, please refer to [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B).
42
+
43
+ **[2023/05/15]** Update the checkpoint of v1.1 version, add English instruction data for training to balance the proportion of Chinese and English data, which solves the phenomenon of Chinese words mixed in English answers .
44
+
45
+ <details><summary><b>The following is a comparison of English questions before and after the update</b></summary>
46
+
47
+ * Question: Describe a time when you had to make a difficult decision.
48
+ - v1.0:
49
+ ![](resources/english-q1-old.png)
50
+ - v1.1:
51
+ ![](resources/english-q1-new.png)
52
+ * Question: Describe the function of a computer motherboard
53
+ - v1.0:
54
+ ![](resources/english-q2-old.png)
55
+ - v1.1:
56
+ ![](resources/english-q2-new.png)
57
+ * Question: Develop a plan to reduce electricity usage in a home.
58
+ - v1.0:
59
+ ![](resources/english-q3-old.png)
60
+ - v1.1:
61
+ ![](resources/english-q3-new.png)
62
+ * Question:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语
63
+ - v1.0:
64
+ ![](resources/english-q4-old.png)
65
+ - v1.1:
66
+ ![](resources/english-q4-new.png)
67
+ </details>
68
+
69
+ For more update info, please refer to [UPDATE.md](UPDATE.md).
70
+
71
+ ## Projects
72
+ Open source projects that accelerate ChatGLM:
73
+ * [lyraChatGLM](https://huggingface.co/TMElyralab/lyraChatGLM): Inference acceleration for ChatGLM-6B, up to 9000+ tokens/s inference speed.
74
+ * [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): An MNN-based implementation of ChatGLM-6B C++ inference, which supports automatic allocation of computing tasks to GPU and CPU according to the size of GPU memory
75
+ * [JittorLLMs](https://github.com/Jittor/JittorLLMs): Running ChatGLM-6B in FP16 with a minimum of 3G GPU memory or no GPU at all, with Linux, windows, and Mac support
76
+ * [InferLLM](https://github.com/MegEngine/InferLLM): Lightweight C++ inference, which can realize real-time chat on local x86 and Arm processors, and can also run in real time on mobile phones. It only requires 4G of running memory.
77
+
78
+ Open source projects using ChatGLM-6B:
79
+ * [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM): ChatGLM application based on langchain, realizing Q&A based on extensible knowledge base
80
+ * [Wenda](https://github.com/l15y/wenda): Large-scale language model call platform, based on ChatGLM-6B to achieve ChatPDF-like functions
81
+ * [chatgpt_academic](https://github.com/binary-husky/chatgpt_academic): An academic writing and programming toolbox that supports ChatGLM-6B. It has the characteristics of modularization and multi-thread calling LLM, and can call multiple LLMs in parallel.
82
+ * [glm-bot](https://github.com/initialencounter/glm-bot): Connect ChatGLM to Koishi to call ChatGLM on major chat platforms
83
+
84
+ Example projects supporting online training of ChatGLM-6B and related applications:
85
+ * [ChatGLM-6B deployment and fine-tuning tutorial](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e)
86
+ * [ChatGLM-6B combined with langchain to implement local knowledge base QA Bot](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
87
+
88
+ Third-party evaluation:
89
+ * [Measuring Massive Multitask Chinese Understanding](https://arxiv.org/abs/2304.12986)
90
+
91
+ For more open source projects, see [PROJECT.md](PROJECT.md).
92
+
93
+ ## Getting Started
94
+
95
+ ### Hardware Requirements
96
+
97
+ | **Quantization Level** | **GPU Memory** |
98
+ |------------------------|----------------|
99
+ | FP16(no quantization) | 13 GB |
100
+ | INT8 | 10 GB |
101
+ | INT4 | 6 GB |
102
+
103
+ ### Environment Setup
104
+
105
+ Install the requirements with pip: `pip install -r requirements.txt`. `transformers` library version is recommended to be `4.27.1`, but theoretically any version no lower than `4.23.1` is acceptable.
106
+
107
+ In addition, if you need to run the quantified model on the CPU, you also need to install `gcc` and `openmp`. Most Linux distributions are installed by default. For Windows, you can check `openmp` when installing [TDM-GCC](https://jmeubank.github.io/tdm-gcc/). On Windows testing environment, the `gcc` version is `TDM-GCC 10.3.0`, and on Linux is `gcc 11.3.0`.
108
+
109
+ ### Usage
110
+
111
+ Generate dialogue with the following code
112
+
113
+ ```python
114
+ >>> from transformers import AutoTokenizer, AutoModel
115
+ >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
116
+ >>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
117
+ >>> model = model.eval()
118
+ >>> response, history = model.chat(tokenizer, "你好", history=[])
119
+ >>> print(response)
120
+ 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
121
+ >>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
122
+ >>> print(response)
123
+ 晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:
124
+
125
+ 1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。
126
+ 2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。
127
+ 3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。
128
+ 4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。
129
+ 5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。
130
+ 6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。
131
+
132
+ 如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。
133
+ ```
134
+ The implementation of the model is still in development. If you want to fix the used model implementation to ensure compatibility, you can add the `revision="v1.1.0"` parameter in the `from_pretrained` call. `v1.1.0` is the latest version number. For a complete list of versions, see [Change Log](https://huggingface.co/THUDM/chatglm-6b#change-log).
135
+
136
+ ### Load the model locally
137
+ The above code will automatically download the model implementation and checkpoints by [transformers](https://github.com/huggingface/transformers). The full model implementation can be found at [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b). If your network environment is poor, downloading model parameters may take a long time or even fail. At this point, you can download the model to the local first, and then load it from the local.
138
+
139
+ To download models from Hugging Face Hub, you need to [install Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) , then run
140
+ ```Shell
141
+ git clone https://huggingface.co/THUDM/chatglm-6b
142
+ ```
143
+
144
+ After downloading the model locally, replace `THUDM/chatglm-6b` in the above code with the path of your local `chatglm-6b` folder to load the model locally.
145
+
146
+ **Optional**: The implementation of the model is still in development. If you want to fix the used model implementation to ensure compatibility, you can execute
147
+ ```Shell
148
+ git checkout v1.1.0
149
+ ```
150
+
151
+ ## Demo & API
152
+
153
+ We provide a Web demo based on [Gradio](https://gradio.app) and a command line demo in the repo. First clone our repo with:
154
+
155
+ ```shell
156
+ git clone https://github.com/THUDM/ChatGLM-6B
157
+ cd ChatGLM-6B
158
+ ```
159
+
160
+ ### Web Demo
161
+
162
+ ![web-demo](resources/web-demo.gif)
163
+
164
+ Install Gradio `pip install gradio`,and run [web_demo.py](web_demo.py):
165
+
166
+ ```shell
167
+ python web_demo.py
168
+ ```
169
+
170
+ The program runs a web server and outputs the URL. Open the URL in the browser to use the web demo.
171
+
172
+ Thanks to [@AdamBear](https://github.com/AdamBear) for implementing a web demo based on Streamlit, see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117 ).
173
+
174
+ #### CLI Demo
175
+
176
+ ![cli-demo](resources/cli-demo.png)
177
+
178
+ Run [cli_demo.py](cli_demo.py) in the repo:
179
+
180
+ ```shell
181
+ python cli_demo.py
182
+ ```
183
+
184
+ The command runs an interactive program in the shell. Type your instruction in the shell and hit enter to generate the response. Type `clear` to clear the dialogue history and `stop` to terminate the program.
185
+
186
+ ## API Deployment
187
+ First install the additional dependency `pip install fastapi uvicorn`. The run [api.py](api.py) in the repo.
188
+ ```shell
189
+ python api.py
190
+ ```
191
+ By default the api runs at the`8000`port of the local machine. You can call the API via
192
+ ```shell
193
+ curl -X POST "http://127.0.0.1:8000" \
194
+ -H 'Content-Type: application/json' \
195
+ -d '{"prompt": "你好", "history": []}'
196
+ ```
197
+ The returned value is
198
+ ```shell
199
+ {
200
+ "response":"你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。",
201
+ "history":[["你好","你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"]],
202
+ "status":200,
203
+ "time":"2023-03-23 21:38:40"
204
+ }
205
+ ```
206
+
207
+ ## Deployment
208
+
209
+ ### Quantization
210
+
211
+ By default, the model parameters are loaded with FP16 precision, which require about 13GB of GPU memory. It your GPU memory is limited, you can try to load the model parameters with quantization:
212
+
213
+ ```python
214
+ # Change according to your hardware. Only support 4/8 bit quantization now.
215
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(8).cuda()
216
+ ```
217
+
218
+ After 2 to 3 rounds of dialogue, the GPU memory usage is about 10GB under 8-bit quantization, and only 6GB under 4-bit quantization. As the number of dialogue rounds increases, the corresponding GPU memory consumption also increases. Due to the use of relative position encoding, ChatGLM-6B theoretically supports an infinitely long context-length, but the performance will gradually decline after the total length exceeds 2048 (training length).
219
+
220
+ Model quantization brings a certain performance decline. After testing, ChatGLM-6B can still perform natural and smooth generation under 4-bit quantization. using [GPT-Q](https://arxiv.org/abs/2210.17323) etc. The quantization scheme can further compress the quantization accuracy/improve the model performance under the same quantization accuracy. You are welcome to submit corresponding Pull Requests.
221
+
222
+ The quantization costs about 13GB of CPU memory to load the FP16 model. If your CPU memory is limited, you can directly load the quantized model, which costs only 5.2GB CPU memory:
223
+ ```python
224
+ # For INT8-quantized model, change "chatglm-6b-int4" to "chatglm-6b-int8"
225
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
226
+ ```
227
+
228
+ ### CPU Deployment
229
+
230
+ If your computer is not equipped with GPU, you can also conduct inference on CPU, but the inference speed is slow (and taking about 32GB of memory):
231
+
232
+ ```python
233
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
234
+ ```
235
+
236
+ If your CPU memory is limited, you can directly load the quantized model:
237
+ ```python
238
+ # For INT8-quantized model, change "chatglm-6b-int4" to "chatglm-6b-int8"
239
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).float()
240
+ ```
241
+
242
+ If your encounter the error `Could not find module 'nvcuda.dll'` or `RuntimeError: Unknown platform: darwin`(MacOS), please [load the model locally](README_en.md#load-the-model-locally).
243
+
244
+ ### Inference on Mac
245
+ For Macs (and MacBooks) with Apple Silicon, it is possible to use the MPS backend to run ChatGLM-6B on the GPU. First, you need to refer to Apple's [official instructions](https://developer.apple.com/metal/pytorch) to install PyTorch-Nightly. (The correct version number should be 2.1.0.dev2023xxxx, not 2.0.0).
246
+
247
+ Currently you must [load the model locally](README_en.md#load-the-model-locally) on MacOS. Change the code to load the model from your local path, and use the mps backend:
248
+ ```python
249
+ model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps')
250
+ ```
251
+
252
+ Loading a FP16 ChatGLM-6B model requires about 13GB of memory. Machines with less memory (such as a MacBook Pro with 16GB of memory) will use the virtual memory on the hard disk when there is insufficient free memory, resulting in a serious slowdown in inference speed. At this time, a quantized model such as chatglm-6b-int4 can be used. Because the quantized kernel on the GPU is written in CUDA, it cannot be used on MacOS, and can only be inferred using the CPU:
253
+
254
+ ```python
255
+ # For INT8-quantized model, change "chatglm-6b-int4" to "chatglm-6b-int8"
256
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).float()
257
+ ```
258
+
259
+ ### Multi-GPU Deployment
260
+ If you have multiple GPUs, but the memory size of each GPU is not sufficient to accommodate the entire model, you can split the model across multiple GPUs.
261
+
262
+ First, install accelerate: `pip install accelerate`, and then load the model using the following method:
263
+ ```python
264
+ from utils import load_model_on_gpus
265
+ model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
266
+ ```
267
+
268
+ This will deploy the model onto two GPUs for inference. You can change `num_gpus` to the number of GPUs you want to use. By default, the model is split evenly, but you can also specify the `device_map` parameter to customize the splitting.
269
+
270
+ ## Parameter-efficient Tuning
271
+ Parameter-efficient tuning based on [P-tuning v2](https://github.com/THUDM/P-tuning-v2). See [ptuning/README.md](ptuning/README.md) for details on how to use it.
272
+
273
+ ## ChatGLM-6B Examples
274
+
275
+ The following are some Chinese examples with `web_demo.py`. Welcome to explore more possibility with ChatGLM-6B.
276
+
277
+ <details><summary><b>Self Cognition</b></summary>
278
+
279
+ ![](examples/self-introduction.png)
280
+
281
+ </details>
282
+
283
+ <details><summary><b>Outline</b></summary>
284
+
285
+ ![](examples/blog-outline.png)
286
+
287
+ </details>
288
+
289
+ <details><summary><b>Ad</b></summary>
290
+
291
+ ![](examples/ad-writing-2.png)
292
+
293
+ ![](examples/comments-writing.png)
294
+
295
+ </details>
296
+
297
+ <details><summary><b>Email</b></summary>
298
+
299
+ ![](examples/email-writing-1.png)
300
+
301
+ ![](examples/email-writing-2.png)
302
+
303
+ </details>
304
+
305
+ <details><summary><b>Information Extraction</b></summary>
306
+
307
+ ![](examples/information-extraction.png)
308
+
309
+ </details>
310
+
311
+ <details><summary><b>Role Play</b></summary>
312
+
313
+ ![](examples/role-play.png)
314
+
315
+ </details>
316
+
317
+ <details><summary><b>Comparison</b></summary>
318
+
319
+ ![](examples/sport.png)
320
+
321
+ </details>
322
+
323
+ <details><summary><b>Travel Guide</b></summary>
324
+
325
+ ![](examples/tour-guide.png)
326
+
327
+ </details>
328
+
329
+ ## License
330
+
331
+ This repository is licensed under the [Apache-2.0 License](LICENSE). The use of ChatGLM-6B model weights is subject to the [Model License](MODEL_LICENSE)。
332
+
333
+ ## Citation
334
+
335
+ If you find our work useful, please consider citing the following papers:
336
+
337
+ ```
338
+ @inproceedings{
339
+ zeng2023glm-130b,
340
+ title={{GLM}-130B: An Open Bilingual Pre-trained Model},
341
+ author={Aohan Zeng and Xiao Liu and Zhengxiao Du and Zihan Wang and Hanyu Lai and Ming Ding and Zhuoyi Yang and Yifan Xu and Wendi Zheng and Xiao Xia and Weng Lam Tam and Zixuan Ma and Yufei Xue and Jidong Zhai and Wenguang Chen and Zhiyuan Liu and Peng Zhang and Yuxiao Dong and Jie Tang},
342
+ booktitle={The Eleventh International Conference on Learning Representations (ICLR)},
343
+ year={2023},
344
+ url={https://openreview.net/forum?id=-Aw0rrrPUF}
345
+ }
346
+ ```
347
+
348
+ ```
349
+ @inproceedings{du2022glm,
350
+ title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
351
+ author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
352
+ booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
353
+ pages={320--335},
354
+ year={2022}
355
+ }
356
+ ```
UPDATE.md ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## 更新信息
2
+ **[2023/05/17]** 发布 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B),一个支持图像理解的多模态对话语言模型。
3
+
4
+ ![](resources/visualglm.png)
5
+
6
+ 可以通过本仓库中的 [cli_demo_vision.py](cli_demo_vision.py) 和 [web_demo_vision.py](web_demo_vision.py) 来运行命令行和网页 Demo。注意 VisualGLM-6B 需要额外安装 [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) 和 torchvision。更多信息参见 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B)。
7
+
8
+ **[2023/05/15]** 更新 v1.1 版本 checkpoint,训练数据增加英文数据以平衡中英文数据比例,解决英文回答中夹杂中文词语的现象。
9
+
10
+ <details><summary><b>以下是更新前后的英文问题对比:</b></summary>
11
+
12
+ * 问题:Describe a time when you had to make a difficult decision.
13
+ - v1.0:
14
+ ![](resources/english-q1-old.png)
15
+ - v1.1:
16
+ ![](resources/english-q1-new.png)
17
+ * 问题:Describe the function of a computer motherboard
18
+ - v1.0:
19
+ ![](resources/english-q2-old.png)
20
+ - v1.1:
21
+ ![](resources/english-q2-new.png)
22
+ * 问题:Develop a plan to reduce electricity usage in a home.
23
+ - v1.0:
24
+ ![](resources/english-q3-old.png)
25
+ - v1.1:
26
+ ![](resources/english-q3-new.png)
27
+ * 问题:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语
28
+ - v1.0:
29
+ ![](resources/english-q4-old.png)
30
+ - v1.1:
31
+ ![](resources/english-q4-new.png)
32
+ </details>
33
+
34
+ **[2023/04/16]** 增加 INT8 量化后的模型 [ChatGLM-6B-INT8](https://huggingface.co/THUDM/chatglm-6b-int8)。增加多卡部署(感谢 [@Cherrysaber](https://github.com/Cherrysaber))。
35
+
36
+ **[2023/04/06]** 优化web demo的界面(感谢 [@tuteng0915](https://github.com/tuteng0915))。移除embedding中的image token以减小显存占用(需要更新模型文件`pytorch_model-00001-of-00008.bin`和`pytorch_model-00008-of-00008.bin`,感谢 [@silverriver](https://github.com/silverriver) 提出的想法)。去掉了对 `icetk` 的依赖(需要更新模型文件`ice_text.model`)。
37
+
38
+ **[2023/03/31]** 增加基于 [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调实现,INT4 量化级别下最低只需 7GB 显存即可进行模型微调。详见[高效参数微调方法](ptuning/README.md)。
39
+
40
+ **[2023/03/23]** 增加 API 部署(感谢 [@LemonQu-GIT](https://github.com/LemonQu-GIT))。~~增加 Embedding 量化模型 [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe)~~ (已停止维护)。增加配备 Apple Silicon 芯片的 Mac 上 GPU 加速的支持。
41
+
42
+ **[2023/03/19]** 增加流式输出接口 `stream_chat`,已更新到网页版和命令行 Demo。修复输出中的中文标点。增加 INT4 量化后的模型 [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4)
43
+
44
+
45
+ ## Update
46
+ **[2023/05/17]** Release [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B), a multimodal conversational language model supporting image understanding.
47
+
48
+ ![](resources/visualglm.png)
49
+
50
+ You can run the command line and web demo through [cli_demo_vision.py](cli_demo_vision.py) and [web_demo_vision.py](web_demo_vision.py) in the repository. Note that VisualGLM-6B requires additional installation of [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) and torchvision. For more information, please refer to [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B).
51
+
52
+ **[2023/05/15]** Update the checkpoint of v1.1 version, add English instruction data for training to balance the proportion of Chinese and English data, which solves the phenomenon of Chinese words mixed in English answers .
53
+
54
+ <details><summary><b>The following is a comparison of English questions before and after the update</b></summary>
55
+
56
+ * Question: Describe a time when you had to make a difficult decision.
57
+ - v1.0:
58
+ ![](resources/english-q1-old.png)
59
+ - v1.1:
60
+ ![](resources/english-q1-new.png)
61
+ * Question: Describe the function of a computer motherboard
62
+ - v1.0:
63
+ ![](resources/english-q2-old.png)
64
+ - v1.1:
65
+ ![](resources/english-q2-new.png)
66
+ * Question: Develop a plan to reduce electricity usage in a home.
67
+ - v1.0:
68
+ ![](resources/english-q3-old.png)
69
+ - v1.1:
70
+ ![](resources/english-q3-new.png)
71
+ * Question:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语
72
+ - v1.0:
73
+ ![](resources/english-q4-old.png)
74
+ - v1.1:
75
+ ![](resources/english-q4-new.png)
76
+ </details>
77
+
78
+ **[2023/04/16]** Added INT8 quantized model [ChatGLM-6B-INT8](https://huggingface.co/THUDM/chatglm-6b-int8). Added multi-GPU deployment (thanks to [@Cherrysaber](https://github.com/Cherrysaber)).
79
+
80
+ **[2023/04/06]** Improve the web demo interface (thanks to [@tuteng0915](https://github.com/tuteng0915)). Remove the image tokens in the embedding layer to reduce the memory usage (need to update the model files `pytorch_model-00001-of-00008.bin` and `pytorch_model-00008-of-00008.bin`, thanks to [@silverriver](https:/ /github.com/silverriver) for proposing the idea). Removed dependency on `icetk` (need to update model file `ice_text.model`).
81
+
82
+ **[2023/03/31]** Added a parameter-efficient tuning implementation based on [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2). The minimum INT4 quantization level only needs 7GB GPU memory is enough for model tuning. See [Parameter-efficient tuning method](ptuning/README.md) for details.
83
+
84
+ **[2023/03/23]** Add API deployment, thanks to [@LemonQu-GIT](https://github.com/LemonQu-GIT). Add embedding-quantized model [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe). Add support for GPU inference on Mac with Apple Silicon.
85
+
86
+ **[2023/03/19]** Add streaming output function `stream_chat`, already applied in web and CLI demo. Fix Chinese punctuations in output. Add quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4).
api.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import uvicorn, json, datetime
4
+ import torch
5
+
6
+ DEVICE = "cuda"
7
+ DEVICE_ID = "0"
8
+ CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
9
+
10
+
11
+ def torch_gc():
12
+ if torch.cuda.is_available():
13
+ with torch.cuda.device(CUDA_DEVICE):
14
+ torch.cuda.empty_cache()
15
+ torch.cuda.ipc_collect()
16
+
17
+
18
+ app = FastAPI()
19
+
20
+
21
+ @app.post("/")
22
+ async def create_item(request: Request):
23
+ global model, tokenizer
24
+ json_post_raw = await request.json()
25
+ json_post = json.dumps(json_post_raw)
26
+ json_post_list = json.loads(json_post)
27
+ prompt = json_post_list.get('prompt')
28
+ history = json_post_list.get('history')
29
+ max_length = json_post_list.get('max_length')
30
+ top_p = json_post_list.get('top_p')
31
+ temperature = json_post_list.get('temperature')
32
+ response, history = model.chat(tokenizer,
33
+ prompt,
34
+ history=history,
35
+ max_length=max_length if max_length else 2048,
36
+ top_p=top_p if top_p else 0.7,
37
+ temperature=temperature if temperature else 0.95)
38
+ now = datetime.datetime.now()
39
+ time = now.strftime("%Y-%m-%d %H:%M:%S")
40
+ answer = {
41
+ "response": response,
42
+ "history": history,
43
+ "status": 200,
44
+ "time": time
45
+ }
46
+ log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
47
+ print(log)
48
+ torch_gc()
49
+ return answer
50
+
51
+
52
+ if __name__ == '__main__':
53
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
54
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
55
+ model.eval()
56
+ uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
cli_demo.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import platform
3
+ import signal
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import readline
6
+
7
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
8
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
9
+ model = model.eval()
10
+
11
+ os_name = platform.system()
12
+ clear_command = 'cls' if os_name == 'Windows' else 'clear'
13
+ stop_stream = False
14
+
15
+
16
+ def build_prompt(history):
17
+ prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
18
+ for query, response in history:
19
+ prompt += f"\n\n用户:{query}"
20
+ prompt += f"\n\nChatGLM-6B:{response}"
21
+ return prompt
22
+
23
+
24
+ def signal_handler(signal, frame):
25
+ global stop_stream
26
+ stop_stream = True
27
+
28
+
29
+ def main():
30
+ history = []
31
+ global stop_stream
32
+ print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
33
+ while True:
34
+ query = input("\n用户:")
35
+ if query.strip() == "stop":
36
+ break
37
+ if query.strip() == "clear":
38
+ history = []
39
+ os.system(clear_command)
40
+ print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
41
+ continue
42
+ count = 0
43
+ for response, history in model.stream_chat(tokenizer, query, history=history):
44
+ if stop_stream:
45
+ stop_stream = False
46
+ break
47
+ else:
48
+ count += 1
49
+ if count % 8 == 0:
50
+ os.system(clear_command)
51
+ print(build_prompt(history), flush=True)
52
+ signal.signal(signal.SIGINT, signal_handler)
53
+ os.system(clear_command)
54
+ print(build_prompt(history), flush=True)
55
+
56
+
57
+ if __name__ == "__main__":
58
+ main()
cli_demo_vision.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import platform
3
+ import signal
4
+ import sys
5
+
6
+ from transformers import AutoTokenizer, AutoModel
7
+ import readline
8
+
9
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
10
+ model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
11
+ model = model.eval()
12
+
13
+ os_name = platform.system()
14
+ clear_command = 'cls' if os_name == 'Windows' else 'clear'
15
+ stop_stream = False
16
+
17
+
18
+ def build_prompt(history, prefix):
19
+ prompt = prefix
20
+ for query, response in history:
21
+ prompt += f"\n\n用户:{query}"
22
+ prompt += f"\n\nChatGLM-6B:{response}"
23
+ return prompt
24
+
25
+
26
+ def signal_handler(signal, frame):
27
+ global stop_stream
28
+ stop_stream = True
29
+
30
+
31
+ def main():
32
+ global stop_stream
33
+ while True:
34
+ history = []
35
+ prefix = "欢迎使用 VisualGLM-6B 模型,输入图片路径和内容即可进行对话,clear 清空对话历史,stop 终止程序"
36
+ print(prefix)
37
+ image_path = input("\n请输入图片路径:")
38
+ if image_path == "stop":
39
+ break
40
+ prefix = prefix + "\n" + image_path
41
+ query = "描述这张图片。"
42
+ while True:
43
+ count = 0
44
+ for response, history in model.stream_chat(tokenizer, image_path, query, history=history):
45
+ if stop_stream:
46
+ stop_stream = False
47
+ break
48
+ else:
49
+ count += 1
50
+ if count % 8 == 0:
51
+ os.system(clear_command)
52
+ print(build_prompt(history, prefix), flush=True)
53
+ signal.signal(signal.SIGINT, signal_handler)
54
+ os.system(clear_command)
55
+ print(build_prompt(history, prefix), flush=True)
56
+ query = input("\n用户:")
57
+ if query.strip() == "clear":
58
+ break
59
+ if query.strip() == "stop":
60
+ sys.exit(0)
61
+
62
+
63
+ if __name__ == "__main__":
64
+ main()
examples/ad-writing-2.png ADDED
examples/blog-outline.png ADDED
examples/comments-writing.png ADDED
examples/email-writing-1.png ADDED
examples/email-writing-2.png ADDED
examples/information-extraction.png ADDED
examples/role-play.png ADDED
examples/self-introduction.png ADDED
examples/sport.png ADDED
examples/tour-guide.png ADDED
improve/README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ChatGLM-6B Badcase 反馈计划
2
+ ChatGLM-6B 自3月14号发布以来受到了广大开发者和用户的喜爱,截至4月22号 GitHub 的 star 数达到 2 万,各个渠道模型的累计下载量过 100 万,并连续 12 天居 Hugging Face 全球大模型下载榜第一名。 与此同时,有一批基于 ChatGLM-6B 的[优秀开源项目](https://github.com/THUDM/ChatGLM-6B)出现,在各个平台也引起了广泛好评和关注。此外,基于 GLM-130B 的千亿对话模型 ChatGLM 也自3月14号开始了第一阶段的邀请制内测,得到了内测用户的好评和支持。谢谢大家对 ChatGLM 及其 6B 开源版本的大力支持!
3
+
4
+ 接下来,我们想邀请大家一起推动 ChatGLM-6B 的进一步提升,一起推动模型的发展。尽管ChatGLM-6B已初具符合人类偏好的问答对话能力,在相当多的指令和问题上,其回答仍存在不理解复杂指令和任务含义,缺乏领域概念理解,事实性错误,生成有害内容,对话上下文不一致等诸多问题。尽管我们提供的[微调代码](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning)能够让用户通过自主训练修复部分问题,但因为神经网络的[灾难性遗忘](https://picture.iczhiku.com/weixin/message1587593113355.html)问题,微调后的模型往往会失去在通用领域的对话能力或者因数据较少而缺乏泛化能力。为了解决这些问题,进一步提升 ChatGLM-6B 的能力,我们启动了 ChatGLM-6B Badcase 反馈计划。
5
+
6
+ 具体来说,对于在使用 ChatGLM-6B 过程中遇到的表现不佳的Badcase对应的具体指令和提问,您可以修改或从头撰写您认为合适的正确答案,并反馈给我们改进 ChatGLM-6B。**请您确保提交的数据不包含任何个人信息、商业秘密或可能危害国家安全、侵害第三方知识产权的内容。** 我们会定期(每2-4周)对数据的有用性与正确性进行筛选,将筛选通过的数据,与通用域的对话数据一起加入到模型训练中,并**更新发布开源的模型参数**。**您提供的数据无论是否筛选通过,除非获得您的许可或根据国家法律规定和监管要求外,我们不会将您提供的数据对外公开。**
7
+
8
+ 您提供的数据如被筛选通过,您将同时优先获得最新版本的 ChatGLM-6B 模型的体验资格。此外,如果您愿意,您的用户名还将出现在 ChatGLM-6B Github页面的数据贡献者名单中,以此感谢您对推进大模型开源事业发展的帮助。您的支持和建议将为我们优化大模型提供源源不断的动力,在此向您表达我们最深的敬意与感谢!
9
+
10
+ 如果您希望参与反馈,请填写[问卷](https://www.wjx.cn/vm/rAoGx9X.aspx#)并按照具体要求上传。提交的数据为 jsonline 格式,每行的内容为
11
+ ```json lines
12
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿x s裤", "response": "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"}
13
+ ```
14
+ 其中,`prompt` 部分为模型的输入,`response` 部分为期望的模型输出。为了保证模型的性能,请在输入中尽可能详细地表述任务的类型和期望的输出格式。针对某项具体的任务,为了使模型得到充分的训练,一般需要100条左右的训练数据。
improve/data_sample.jsonl ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤", "response": "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"}
2
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#简约*图案#条纹*图案#线条*图案#撞色*裙型#鱼尾裙*裙袖长#无袖", "response": "圆形领口修饰脖颈线条,适合各种脸型,耐看有气质。无袖设计,尤显清凉,简约横条纹装饰,使得整身人鱼造型更为生动立体。加之撞色的鱼尾下摆,深邃富有诗意。收腰包臀,修饰女性身体曲线,结合别出心裁的鱼尾裙摆设计,勾勒出自然流畅的身体轮廓,展现了婀娜多姿的迷人姿态。"}
3
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#粉红色*图案#字母*图案#文字*图案#线条*衣样式#卫衣*衣款式#不规则", "response": "宽松的卫衣版型包裹着整个身材,宽大的衣身与身材形成鲜明的对比描绘出纤瘦的身形。下摆与袖口的不规则剪裁设计,彰显出时尚前卫的形态。被剪裁过的样式呈现出布条状自然地垂坠下来,别具有一番设计感。线条分明的字母样式有着花式的外观,棱角分明加上具有少女元气的枣红色十分有年轻活力感。粉红色的衣身把肌肤衬托得很白嫩又健康。"}
4
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*材质#雪纺*风格#清新*裙型#a字*裙长#连衣裙", "response": "踩着轻盈的步伐享受在午后的和煦风中,让放松与惬意感为你免去一身的压力与束缚,仿佛要将灵魂也寄托在随风摇曳的雪纺连衣裙上,吐露出<UNK>微妙而又浪漫的清新之意。宽松的a字版型除了能够带来足够的空间,也能以上窄下宽的方式强化立体层次,携带出自然优雅的曼妙体验。"}
5
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#棉*颜色#蓝色*风格#潮*衣样式#polo*衣领型#polo领*衣袖长#短袖*衣款式#拼接", "response": "想要在人群中脱颖而出吗?那么最适合您的莫过于这款polo衫短袖,采用了经典的polo领口和柔软纯棉面料,让您紧跟时尚潮流。再配合上潮流的蓝色拼接设计,使您的风格更加出众。就算单从选料上来说,这款polo衫的颜色沉稳经典,是这个季度十分受大众喜爱的风格了,而且兼具舒适感和时尚感。"}
6
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#h*材质#蚕丝*风格#复古*图案#条纹*图案#复古*图案#撞色*衣样式#衬衫*衣领型#小立领", "response": "小女人十足的条纹衬衣,缎面一点点的复古,还有蓝绿色这种高级气质复古色,真丝材质,撞色竖条纹特别的现代感味道,直h型的裁剪和特别的衣长款式,更加独立性格。双层小立领,更显脸型。"}
7
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#网纱*颜色#粉红色*图案#线条*图案#刺绣*裙腰型#高腰*裙长#连衣裙*裙袖长#短袖*裙领型#圆领", "response": "这款连衣裙,由上到下都透出一丝迷人诱惑的女性魅力,经典圆领型,开口度恰好,露出你的迷人修长的脖颈线条,很是优雅气质,短袖设计,在这款上竟是撩人美貌,高腰线,散开的裙摆,到小腿的长度,遮住了腿部粗的部分,对身材有很好的修饰作用,穿起来很女神;裙身粉红色花枝重工刺绣,让人一眼难忘!而且在这种网纱面料上做繁复图案的绣花,是很考验工艺的,对机器的要求会更高,更加凸显我们的高品质做工;"}
8
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#文字*图案#印花*衣样式#卫衣", "response": "一款非常简洁大方的纯色卫衣,设计点在于胸前的“<UNK><UNK>”的中文字印花,新颖特别,让人眼前一亮。简单又吸睛的款式,而且不失时髦感,很适合个性年轻人。"}
9
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#黑色*颜色#灰色*颜色#姜黄色*风格#休闲*图案#线条*图案#撞色*衣样式#毛衣*衣袖型#落肩袖", "response": "看惯了灰色的冷淡和黑色的沉闷感,来一点醒目的彩色增添点活力吧。亮眼又吸睛的姜黄色色调,嫩肤显白非常的有设计感。趣味的撞色和宽松的版型相交辉映,修饰身形小缺点的同时,时尚又百搭。优雅的��肩袖,轻松修饰肩部线条,让毛衣上身凸显出一丝慵懒随性的休闲感,时尚魅力尽显。"}
10
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*风格#潮*图案#印花*图案#撞色*衣样式#衬衫*衣领型#圆领*衣长#中长款*衣长#常规*衣袖长#无袖", "response": "黑与白,两种最极端的颜色却轻松搭配成了经典,就像此款衬衣,无需过多装饰,仅色调就足够醒目个性,受潮<UNK>所喜欢。做了无袖中长款的样式,走路带风的感觉着实不错,圆领的设计,不是常规的衬衫领,少了点正式反而有种休闲感觉,适合孩子们穿着。后背大面积撞色印花装点,是时尚潮流的象征,也让衣衣不至于单调,轻松就能穿出彩。"}
11
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#街头*风格#休闲*风格#朋克*图案#字母*图案#文字*图案#印花*衣样式#卫衣*衣款式#连帽*衣款式#对称", "response": "个性休闲风的连帽卫衣造型时髦大方,宽松的版型剪裁让肉肉的小宝贝也可以穿着,保暖的连帽设计时刻给予宝贝温柔的呵护,袖子和后背别致时髦的字母印花点缀,满满的街头元素融入,演绎休闲朋克风,对称的小口袋美观大方,方便放置更多的随身物品。"}
12
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙款式#链条", "response": "简单大气的设计,不费吹灰之力就能搭配的时髦范儿。时尚的配色一点都不觉得平淡了,有种浑然天成的大气感。强调了整体的装饰,和谐又不失个性,搭配裤装帅气十足,搭配裙子精致优雅。链条和肩带的搭配让使用感更加舒服,单肩手提都好看。"}
13
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#显瘦*材质#牛仔布*颜色#深蓝色*风格#复古*图案#复古*图案#线条*裤腰型#高腰*裤口#微喇裤", "response": "深蓝色的高腰牛仔裤,修身的款式勾勒出纤细的美腿。牛仔裤的裤脚设计<UNK>张开的喇叭型,巧妙地修饰了小腿的线条,洋溢着复古的年代感。"}
14
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#清新*风格#潮*风格#性感*图案#条纹*图案#蝴蝶结*衣样式#衬衫*衣领型#一字领*衣门襟#系带*衣款式#不对称", "response": "这是一件显得特别清新的衬衣,采用了条纹的设计,给予人一种甜美可人的气质。并且融合了别致的斜肩一字领设计,高调的展示出性感的锁骨,将迷人的香肩展现在外,性感中不失去清纯的气息。袖口处的蝴蝶结系带装饰,增添了俏皮的韵味,简洁大方。且在下摆处采用了不对称的设计,增强了视觉效果,更显潮流。"}
15
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*风格#复古*图案#复古*裤型#直筒裤*裤款式#纽扣*裤腰型#高腰", "response": "作为基础款单品,牛仔裤也<UNK><UNK>,想要呈现给大家的是——每次搭配都有新感觉。裤子经过复古做旧处理,风格鲜明,也很注重细节,连纽扣也做了统一的做旧处理,融入个性十足的磨破设计,高腰直筒basic裤型,修饰身材,穿出高挑长腿。"}
16
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*版型#显瘦*图案#线条*图案#刺绣*衣样式#针织衫*衣领型#v领", "response": "一款温暖柔软又富有弹性的针织衫,不仅可以抵御严寒侵袭,还能更好地进行搭配。v领的设计,能勾勒出迷人的天鹅颈以及衬托出娇小的脸型。宽松又别致的剪裁,能从视觉上显露纤长的下半身,起到显瘦的效果。直筒造型的袖子,修饰出优美的手臂线条,衣身上的方格刺绣,时尚又吸睛。"}
17
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#绿色*风格#清新*图案#线条*衣样式#衬衫*衣领型#翻领", "response": "绿色的衣身上镶嵌着<UNK>,就是这款衬衫最大的迷人之处,“红花配绿叶”般的色调,将清新气息阐述的淋漓尽致。经典的翻领更是贴心,修饰颈部线条的同时,尽显精致干练的气质,出街轻松凹造型。"}
18
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*图案#字母*图案#文字*图案#印花*图案#撞色*衣样式#外套*衣门襟#拉链*衣款式#拉链", "response": "这款外套采用了撞色拉链织带以及字母印花设计。这两种元素的融入使外套不会显得过于单调沉闷,吸睛而亮眼,充满年轻与朝气感,非常减龄。"}
19
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*版型#h*风格#复古*图案#复古*图案#刺绣*裙长#连衣裙*裙袖长#长袖*裙领型#翻领*裙衣门襟#单排扣", "response": "本款连衣裙整体采用h型的轮廓设计,藏肉显瘦,不挑身材,��合各种身形的人穿着。小翻领的领口设计,使得本款连衣裙穿在身上看起来十分的精神帅气,具有青春活力。单排扣的衣门襟设计,又给本款连衣裙带来了一丝的复古味道。裙身上的刺绣花朵装饰,使得本款连衣裙不显得单调,富有层次感,上身给人一种独特的时尚魅力。长袖的设计,更加的贴合手臂曲线,上身更加的舒适贴身。"}
20
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#粉色*风格#清新*衣样式#外套*衣样式#西装*衣门襟#双排扣", "response": "这款外套设计成西装的版型,彰显经典优雅的气质,结合了粉色又添清新气息,甜美百搭时尚感满满。利落的版型简洁流畅,亮色双排扣更添精致感。"}
21
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*图案#线条*衣样式#风衣*衣样式#外套*衣门襟#拉链*衣款式#拉链*衣款式#松紧带*衣款式#连帽*衣款式#收腰", "response": "选自品牌江南布衣的一款女士长风衣外套,选用轻薄的<UNK><UNK>,穿着灵活毫无压力。直筒版型简洁利落,长过膝盖的长度穿着个性十足,连帽宽大有型,富有活力,<UNK>拉链开合,拉上拉链有一丝酷劲,敞开穿则更休闲,连帽领翻开修饰颈部线条。松紧带收腰设计,低调的分割上下比例,打造显高小心机。"}
22
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#棉*材质#牛仔布*风格#街头*风格#简约*图案#刺绣*裤长#短裤*裤款式#钉珠*裤口#毛边", "response": "又到了光腿穿裙子和短裤的时候了,BRAND的这款短裤,采用柔软透气的纯棉牛仔面料,穿着舒适无束缚感。而简约的版型加入了精美的刺绣和钉珠装饰,提升了整体的品质感,显得精美而又立体饱满。搭配下摆的毛边装饰,散发出不羁的街头感。"}
23
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#牛仔布*颜色#黑色*图案#条纹*衣样式#衬衫*衣领型#翻领*衣门襟#系带*衣款式#拼接*衣款式#露肩", "response": "一款老鹰图案露肩衬衫,露肩系带的设计,少女感十足。老鹰图案的设计,更添几分趣味感。条纹面料和牛仔面料的拼接设计,给人一种风度的层次感。小翻领的设计十分的精致,搭配一件黑色打底裤也吸晴万分。"}
24
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*裙型#百褶*裙长#半身裙*裙款式#拼接*裙款式#腰带", "response": "一款颇有设计感的半身裙,单侧雪纺百褶的拼接设计,规整排列的层次感带来立体效果,增加了裙身的廓形,行走间更是带来柔美的灵动气息,轻而易举穿出优雅的轻熟风,呈现十足的女人味来。同面料延伸处理的半固定腰带,可以自然的垂落下来,也算是为整体打造造型亮点,彰显你独特的时尚品味,迎合早春对轻盈的追求。"}
25
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#性感*裙型#包臀裙*裙型#鱼尾裙", "response": "修身包臀版型结合性感鱼尾裙摆设计,彰显婉约优雅风情之余,为整体注入几分俏皮灵动气息。且下摆辅以律动感摺裥元素,更烘托出女性浪漫精致的一面。"}
26
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*颜色#绿色*图案#线条*裙长#连衣裙*裙领型#v领*裙款式#勾花镂空", "response": "连衣裙可以让你在旋转与跳跃间,都散发出万种风情,受到了万千女性的喜爱。这款连衣裙选用绿色调,既散发出活力气息,又增添了高雅的气质。而镂空的钩花设计,则为其增添了浪漫的风情,同时更显美观与时尚。再加上v领的设计,不仅映衬出精致的脸颊,还打造出优美的颈部线条。"}
27
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#淑女*衣样式#毛衣*衣领型#高领", "response": "高领毛衣一直是网红妹子,因为穿着它有一种淑女甜美气质。它最大的亮点在于它的高领设计和花边装饰。在淑女干练的气质基础上又增加了一些少女的甜美气息,穿着非常有型,最佳搭配小白鞋。"}
28
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*裤长#九分裤*裤型#阔腿裤*裤款式#拉链*裤腰型#高腰*裤口#开叉", "response": "九分裤长,把妹子的拉长了腿的比例,配合高腰设计,瞬间显得妹子的腿长了很多,一下子自信满满啦。采用侧面隐藏拉链设计,穿脱方便又舒适。设计感十足的开叉裤脚,身上的摩登<UNK>浓了。这个春天妹子的腿型,就交给阔腿裤啦。"}
29
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#黑白*风格#复古*风格#文艺*图案#格子*图案#复古*衣样式#衬衫*衣领型#翻领*衣门襟#套头*衣��式#纽扣", "response": "经典的套头翻领衬衫与黑白格纹元素组合,一直以来的气场经久不衰。而采用复古精致的纽扣装点的半门襟设计,简单的小细节处理,彰显出浓浓的文艺气息。"}
30
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#黑白*风格#复古*风格#文艺*图案#格子*图案#复古*衣样式#衬衫*衣领型#翻领*衣门襟#套头*衣款式#纽扣", "response": "套头翻领衬衫与黑白格纹元素组合,一直以来的气场经久不衰。而采用复古精致的纽扣装点的半门襟设计,简单的小细节处理,彰显出浓浓的文艺气息。"}
31
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#卡通*裙长#连衣裙", "response": "传奇而又经典的卡通形象,米老鼠似乎已经成为了孩童风格的一种标志,大小不一的头像以及奇趣的表情设计。满版的点缀风格让整个连衣裙洋溢着独特的天真气质,加之面料小口袋的点缀,小小的造型呈现出灵巧而又可爱的格调,让宝贝俏皮萌动。"}
32
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#条纹*图案#刺绣*裙型#背带裙*裙下摆#毛边", "response": "假两件版型的设计,给人一种错觉,大大增添自身时髦感。毛边裙摆的采用,看起来活力十足。设计师解决了以往穿脱不方便的问题,应用的可调节背带设计,非常的人性化。裙子上的花朵刺绣图案,看起来也栩栩如生,同时也展示出了精湛的做工手艺。为了与女人自身清纯的一面形成呼应,应用的条纹图案非常完美。"}
33
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#牛仔布*风格#街头*风格#休闲*裤长#五分裤*裤腰型#松紧腰", "response": "这款休闲五分裤,采用亲肤软牛仔,洗水磨白形成深浅对比,更加个性。大弹力松紧腰,舒适贴合,一点都不紧勒。裤子门襟时尚的设计,为细节加分。立体双贴袋,腰间系带的点缀更吸睛。精致的裁剪,或是干练整洁的走线和宽松版型,是对街头的描写。"}
34
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*图案#条纹*图案#蕾丝*衣样式#衬衫*衣领型#立领*衣款式#荷叶边", "response": "条纹衬衫是引领时尚圈的常青树,尤其给人舒适感官享受的蓝白条纹,更是深得时尚icon的喜爱。加之搭载经典的立领秀出纤长的玉颈,更显气质优雅。肩膀上饰有薄薄的蕾丝,打破了条纹衬衫的干练,更添别样风情。荷叶边的蔓延更显气质甜美,自然吸睛无数。"}
35
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#牛仔布*风格#休闲*图案#线条*裙型#牛仔裙*裙长#半身裙", "response": "牛仔半身裙作为时尚宠儿,一直被很多潮人捧在手心,zui妙的莫过于它<UNK>的时髦感以及百搭易驾驭的属性。裙身设计了自然的腰型,拉长腿部线条,让小仙女们感受到大长腿。以及两侧插袋的造型,显得比较随意,休闲的感觉。"}
36
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#刺绣*衣样式#卫衣*衣袖型#罗纹袖口", "response": "乍一看很平凡的纯色系卫衣,<UNK>暗藏的刺绣玄机,就足够把时髦的张力表现得不凡。很有包容性的廓形,舒适的罗纹收口,宽大的样子却依旧<UNK>既定的风格,让你的潇洒随性表现得收放自如。呆萌查理的袖间刺绣,极简的漫画笔触巧塑生动有趣的风格。"}
37
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*衣样式#衬衫*衣样式#风衣*衣款式#抽绳", "response": "风衣带有一种随性大气的感觉,在春风拂面的日子里能衬托出你的气质。草绿色的衣身配色,与与自身清纯干净的性格形成了呼应。具有一定实用性的下摆抽绳,可以让你任意的变换风格。抛弃了衬衫领的设计应用的设计,更能将你帅气的一面展示出来。"}
38
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "优质挺括的面料。包臀版型,长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚"}
39
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "优质挺括的面料。包臀版型,长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚。"}
40
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚"}
41
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#针织*风格#复古*风格#清新*图案#条纹*图案#复古*衣样式#针织衫*衣样式#开衫*衣长#常规*衣款式#拼接*衣款式#纽扣*衣款式#罗纹", "response": "慵懒气质的针织开衫,充满了复古的情调,奶奶级的麻花编织手法,充满立体感的同时保暖效果也是满分。下摆的罗纹拼接,让针织衫回暖性更棒。活泼的<UNK>条纹拼接,跳脱出常规配色,清新色调的选用,更加衬托出肌肤的雪白。精致的纽扣点缀,反光的质感让针织衫充满现代感。"}
42
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#复古*图案#蝴蝶结*图案#复古*图案#波点*衣样式#衬衫*衣领型#立领*衣门襟#系带*衣款式#木耳", "response": "【<UNK>说】<UNK>衬衫,大波点气质复古从立领上延伸的长系带,可轻松绑成蝴蝶结,甜美感加分采用打缆工艺的松紧袖口边边处的木耳<UNK>很可爱"}
43
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#简约*风格#青春*图案#字母*图案#文字*裙型#网纱裙*裙袖长#无袖*裙领型#圆领", "response": "大气的圆领舒适贴合,彰显出女孩儿精神的气质。无袖的款式与圆领相迎合,简约的同时又不失时尚风采。前身由可爱蝴蝶图案点缀,亮丽的字母映衬其上,诉说着一丝精美感。橙色网纱裙摆造型优雅唯美,与上身的图案相呼应,十分富有青春的气息,伴随着步伐的行走间,带出一丝别致浪漫的风情。"}
44
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#丝绒*风格#复古*图案#复古*衣样式#雪纺衫*衣袖型#喇叭袖*衣款式#木耳边*衣款式#飘带*衣款式#荷叶边", "response": "这款雪纺衫,采用具有复古韵味的荷叶边元素,加上丝绒质感的加长飘带,洋溢着浪漫古典的韵味。<UNK>两侧镶有包扣,和立体木耳边装饰,大大提升时髦指数。而流线型喇叭袖设计,充满灵动质感,为造型平添活力。"}
45
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*版型#显瘦*裙下摆#荷叶边*裙腰型#高腰*裙长#半身裙", "response": "很简洁百搭的一款半裙,裙身荷叶边设计,飘逸灵动,上身更显层次感丰富。高腰造型,版型优良,衬显修长双腿。裙子做的比较宽松,包容性敲好,遮肉效果棒棒的。非常的显瘦哦,选用精品梭织面料,垂感好,肌理细致,上身敲舒服哟。"}
46
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#青春*风格#性感*图案#线条*裙下摆#开叉*裙长#连衣裙*裙领型#翻领*裙款式#腰带*裙款式#衬衫式", "response": "设计师以衬衫式的创作灵感,巧妙地搬运到连衣裙身上,中性又不失性感;时尚小翻领设计,巧妙衬托颈部线条,彰显青春派的艺术时尚,小资派的精彩演绎。耳目一新的双腰带设计,既突出了腰线又感觉很前卫;下面走心的大开叉设计,更能激发人的好奇心,营造出无人超越的高级性感,只需一眼就令人<UNK>。"}
47
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*风格#性感*图案#印花*裙下摆#荷叶边*裙长#连衣裙*裙袖型#灯笼袖", "response": "这款连衣裙走的是性感大方的风格路线,展现出你的大大咧咧的性情,非常的有趣。选用了宽松的版型,配合星空印花的图案,塑造出新颖有趣,不失活力四射的印象感。荷叶边的裙摆设计,突显出飘逸性感的一面。配合灯笼袖的袖型细节,体现出<UNK>的一面。"}
48
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*材质#水洗*颜色#浅色*风格#休闲*风格#性感*图案#线条*裙型#牛仔裙*裙型#直筒裙*裙下摆#开叉*裙下摆#毛边*裙腰型#高腰", "response": "浅色水洗效果牛仔裙,高腰设计融合修身直筒廓形,凸显纤细腰部和迷人翘臀,美化勾勒性感身材曲线。正面开叉细节有效拉长腿部线条,灵动性感。磨毛边下摆设计,带来休闲随性气息。"}
49
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*图案#条纹*图案#印花*衣样式#卫衣*衣款式#连帽*衣款式#罗纹", "response": "这款连帽卫衣自带休闲魅力,将杜嘉班纳的品牌标志以印花的形式装饰在衣身前幅,展现出华丽不失看点的视觉效果,每时每刻都在彰显不凡品味��罗纹条纹袖口和下摆,不仅能使卫衣更帅气惹眼,还能为整体增加一股前卫之风。"}
50
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#简约*图案#卡通*图案#蝴蝶结*图案#印花*衣样式#衬衫", "response": "大面积的卡通兔子印花,童趣满满,再加上领口的蝴蝶结装饰织带。充满童趣的同时又不失小女生的甜美气息,相当减龄。这款衬衫选用真丝面料,真丝面料不仅轻薄,而且柔滑、亲肤,就好像人的第二层肌肤般带给你清凉舒适的穿着感觉。合身的版型,裁剪得干净利落,简约又不失时尚气息,打造干练的气场。这款衬衫日常十分百搭,不仅可以与其他服饰搭配,作为一件单品也十分出彩。"}
51
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*材质#水洗*风格#复古*风格#简约*图案#复古*图案#线条*裤长#九分裤*裤款式#不规则*裤口#毛边", "response": "misssixty的这款单品延续经典的九分牛仔裤版型,结合贴合身形的剪裁,展现出柔美修长的腿部线条;不同的位置做了不同程度的水洗复古工艺,使得裤身更加立体厚重;此外,裤脚处采用了微微不规则的毛边剪裁,为简约的整体注入一丝随性之感;再加上<UNK>相互呼应的翅膀状图案点缀,瞬间带来一丝浪漫唯美的味道。"}
52
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#线条*衣样式#卫衣*衣领型#圆领*衣袖型#收口*衣门襟#套头*衣款式#螺纹", "response": "使用经典的螺纹圆领来展开设计,将衣型打造成套头卫衣的款式,穿着时轻松收口,将颈部线条修饰出挺拔优美的的效果,让穿着更加具有精气神。衣身以纯色作为主色调,配上经典的小企鹅logo,将正面点缀,它拥有一个俏皮的小蝴蝶领结,充满细节感,使得衣身吸睛耀眼。"}
53
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*裤长#九分裤*裤型#直筒裤", "response": "c小小的这样一条迷人的牛仔裤彰显出你的大气个性,它的别致直筒版型十分的高端迷人,让你吸睛十足。个性九分的版型展示出你的迷人小脚踝。它的大气牛仔材质,十分的舒适洒脱,迷人更有型。"}
54
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*颜色#纯色*风格#简约*图案#纯色*图案#线条*图案#蕾丝*衣样式#衬衫*衣领型#v领", "response": "一款简约的纯色衬衫,采用了个性的大v领,露出柔美的锁骨和颈部线条,散发出清爽迷人的气质;点缀精美的蕾丝花边装饰,波浪形的花边很有美感,增加了视觉亮点。"}
55
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#撞色*裙下摆#垂坠*裙长#连衣裙*裙袖长#无袖*裙袖型#收口*裙款式#拼接*裙款式#绑带*裙款式#波浪", "response": "来自奥芝国的推出的无袖连衣裙,精选弹力冰丝材质穿柔软垂坠性很好,适合春夏秋三季穿搭。腰部的撞色波浪纹弹力腰封拼接,并以交叉绑带式收口,修饰腰身轻松穿<UNK>人大长腿。"}
56
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*材质#针织*颜色#灰色*颜色#深蓝色*图案#线条*衣样式#毛衣", "response": "这是一款专为胖孩子设计的针织毛衣,加肥加大的立体版型,利落有型穿着合体不臃肿,穿着更加帅气显瘦;领口、袖口和下摆收紧处理使衣衣更加利落有型,久穿久洗也不易磨损和变形,颇具品质感;深蓝色的大身巧妙地加入一些灰色线条修饰活泼大方,孩子穿上它,洋溢着青春活力。"}
57
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#牛仔布*材质#网纱*风格#街头*衣样式#衬衫*衣款式#拼接*衣款式#勾花镂空*衣款式#钉珠", "response": "时髦又帅气的牛仔拼接裙,利用多材质拼接演绎刚柔并济的设计。硬朗的牛仔衬衫以镂空拼接,构造出深浅的色系变化,加上钉珠铆钉的装饰,更是玩味出十足的街头帅气。下身拼接的网纱半裙,层次细腻又丰富,两侧加入牛仔插袋呼应上身面料,带来一体感设计。"}
58
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#牛仔布*颜色#蓝色*颜色#浅蓝色*风格#性感*裙型#牛仔裙*裙型#包臀裙*裙下摆#开叉*裙款式#拼接*裙款式#纽扣", "response": "mm们<UNK>拼接风呢?这款牛仔裤是非常有趣的拼接风,浅蓝色和原蓝色的牛仔拼接在一起,非常吸引眼球。在左侧的裙摆处还做了开叉设计,微微露出腿部皮肤,展现性感姿态。包臀的设计,凸显圆润的臀部。前幅一排金属纽扣,增添细节感和精致度。喜欢的mm千万不要错过~"}
59
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#蕾丝*颜色#粉色*风格#清新*图案#碎花*图案#线条*图案#蕾丝*裙型#a字*裙下摆#花边*裙领型#圆领*裙款式#飘带", "response": "清新的小碎花缀满衣身,以淡雅的粉色调为底色,焕发出甜美温婉的少女气息。简洁的圆领设计,柔化脸部线条,加上蕾丝飘带点缀,更显娇俏减龄。下摆蕾丝花边分割裙裾,转身间将浪漫挥洒。散开的a字裙摆,恰到好处遮住了臀部和腿部粗的部分,有很好的修饰作用。"}
60
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#淑女*图案#植物*图案#印花*裙型#百褶*裙长#连衣裙*裙领型#娃娃领*裙款式#拼接*裙款式#腰带", "response": "法式浪漫情怀,由这款印花连衣裙为你抒写。蝴蝶花卉印花铺陈裙身,蝴蝶翩跹BRAND花丛,浪漫迷人美如画,法式风情呼之欲出。娃娃领的设计,凸显一身柔美的淑女气质。裙摆百褶的设计,<UNK>飞舞更添灵动飘逸的美。腰带拼接的设计,完美打造显瘦显高的身材比例。"}
61
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*风格#复古*风格#简约*风格#休闲*图案#复古*图案#线条*图案#印花*裙长#连衣裙", "response": "这一款雪纺连衣裙复古的小立领带来不一样的惊喜,不仅拉伸了脖颈的线条,同时衬托出娇小的脸型。衣身大大的印花很有质感,简约休闲中透露着复古精致的美丽。"}
62
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#文艺*风格#简约*风格#清新*衣样式#外套*衣门襟#拉链*衣款式#拉链", "response": "飘飘落落,暖色的布料上纷纷落落的铺着羽毛,灰常有意境的一款连衣裙。羽毛是这款连衣裙最大的亮点,色彩也丰富饱满,凸显的文艺感也灰常强烈,满满的文艺清新气息;简约大方的设计,有种不喧嚣的热烈感;凸显内敛的气质。搭大衣、棉服外套不仅保暖又灰常的有韵味,而且这款不仅做了开扣的设计,还做了隐形的小拉链!是可哺乳的款式,方便孕后哺乳穿,墙裂推荐!"}
63
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#针织*风格#简约*风格#青春*风格#清新*风格#性感*图案#条纹*图案#撞色*裙下摆#开叉*裙长#连衣裙*裙款式#拼接*裙款式#吊带", "response": "这款针织吊带连衣裙展现青春时尚的格调,双侧撞色条纹的拼接简约经典,散发出清新爽朗的气息,显得格外惹眼,营造出明媚动人的视觉吸引力。赋予简约的吊带裙满满的活力,开叉的剪裁性感别致,充满小女人的韵味。"}
64
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#街头*风格#青春*衣样式#t恤*衣领型#圆领", "response": "三叶草的这款体恤面料比较舒适,穿起来也能很好的透气排汗。整体的设计风格就是经典的款式,所以说是街头常年流行的必备。圆领的领口设计在穿脱时起到了方便。同时修饰脸部轮廓,更显小脸。三叶草的标志也是最为独特的品牌标识,穿出了个人的品味。"}
65
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#白色*风格#休闲*风格#清新*衣样式#外套*衣款式#连帽", "response": "春天家以清新白色为主基调打造的这款外套,整体采用了直筒的极简剪裁配合休闲感的连帽设计,穿着在身上的舒适度较高。设计师为这款上衣做了<UNK>口的袖子和下摆的处理,穿着后对于身形的修饰效果会更为出众,显得较为得体、大方。"}
66
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#军绿色*风格#复古*风格#文艺*风格#知性*风格#休闲*风格#潮*图案#复古*图案#撞色*衣样式#外套*衣样式#西装*衣领型#西装领*衣长#短款*衣袖型#插肩袖", "response": "短款西装小外套,结合了知性和休闲两种风格,在现代的潮流款式中又融入了淡淡的复古韵味。端庄典雅的军绿色衣身,带着自由舒畅的旅行感,款式上选用利落率性色西装领,宽松闲适的插肩袖,门襟选用撞色的两粒扣设计,复古文艺又简洁随性。"}
67
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#线条*裙款式#勾花镂空*裙款式#收腰", "response": "亮眼的橙红色展示出迎面而来的热情感,衬托肤色白皙红润,在宴会上气场十足。方形的镂空点缀着衣领下方,增加看点散发出小女人的妩媚感。独特的衣袖造型倾斜而下,修饰手臂线条非常修长,在举手投足间优雅又大气。收腰的版型设计修饰腰部线条更纤细,打褶的裙摆在行走时灵动十足,仿佛<UNK>的精灵一般。"}
68
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#线条*裙款式#勾花镂空*裙��式#收腰", "response": "裙子表面的镂空花网就使其充满了很强的设计美感,首先是肩部将落肩袖和镂空图案相结合,白皙的肌肤隐隐约约,而且能够很好的缩小肩宽比例。v型领口修饰拉长颈部线条和显得脸小。裙子做了收腰裁剪,并将腰线提高,轻松拉长下半身身材比例,裙摆也更加挺括,从而能够解决胯宽等身材烦恼。"}
69
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*材质#涤纶*衣样式#风衣*衣袖型#灯笼袖*衣款式#纽扣*衣款式#飘带", "response": "风衣在摒弃了传统的版型样式,将袖子设计成花苞型的灯笼袖,与春天搭配得恰到好处。并在袖子处装饰了四颗纽扣,采用飘带作为松紧调节,增添层次感更显个性别致。除此之外,风衣采用涤纶材质制成,垂顺感好挺括修身,结合小a字形轮廓,更显身形高挑秀美,并且让矮小个的女性也能撑起风衣的气场。"}
70
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#英伦*风格#简约*风格#休闲*图案#格子*图案#线条*衣样式#西装*衣领型#翻领*衣门襟#双排扣", "response": "这一款休闲西装简约利落的翻领,可以很好地修饰脸型和颈部线条,显脸小的同时又让脖子看上去更纤细。加上精致的格纹装饰,视觉美丽凸显英伦风。而且双排扣设计,时尚大气美观实用。"}
71
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#针织*风格#文艺*风格#休闲*风格#性感*裙长#半身裙*裙长#连衣裙*裙款式#拼接", "response": "连衣裙的灵感来自于<UNK>匠人穿着围裙的状态,设计师将针织上衣与半裙结合,整体松软舒适,且不失休闲随性感。裙摆不同材质的拼接,带来丰富的层次细节,让时髦度倍增。偏暗调的配色融入文艺田园气息,显随性姿态。"}
72
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#复古*图案#复古*裙下摆#荷叶边*裙长#连衣裙", "response": "对于女孩子来说,喜爱连衣裙是与生俱来的!几乎没有问题是一条裙纸<UNK>的~BRAND这款裙子整体的设计有点小复古的感觉,而且艳丽的枣红色也是复古色的代表,上身穿着十分衬肤显白哦。个性而时髦的挂脖式领口露出锁骨很是撩人,另外领口至腰间的衣身前片还加入了很有灵动感的荷叶边作为点缀,瞬间点亮了整体的造型感,由内而外散发的优雅而温柔的气质无人能挡。"}
73
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*风格#日系*风格#简约*图案#线条", "response": "很喜欢这款简单却不简约的时尚牛仔裤,在夏天可以穿出个性与时尚。整个风格比较偏向于日系的身体,任何妹子都能够轻松驾驭,最重要的是版型。穿上特别修饰腿部的线条,打造出了高挑的身材,让你看起来非常有自信的呢,这手工的工艺凸显出了无限的高级质感。"}
74
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#街头*风格#潮*裙型#a字", "response": "孕期就一定要穿的沉闷单调吗?热爱潮流的怎能束缚自己个性的心呢,这款裙子采用a字型设计,让你搭配更为轻松随意,飘逸的撞色织带设计,即刻将原本沉闷的空气也带动的活跃起来。从街头到<UNK>,尽显潮流个性时尚。"}
75
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*颜色#浅蓝色*风格#街头*风格#休闲*裤型#直筒裤*裤款式#破洞", "response": "破洞元素已变成彰显个性的元素,这款浅蓝色牛仔裤糅合磨白磨破设计,弥漫摩登个性格调,而且破洞设计,打破裤装闷热形象,休闲时髦;直筒款巧妙糅合酷帅感与时髦感,塑造街头潮人印象。"}
76
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#雪纺*风格#知性*风格#性感*图案#线条*裤长#连体裤*裤款式#木耳边", "response": "雪纺面料的一袭连体裤,舒适的手感,轻盈的穿着,宽松的版型,让上身穿着没有束缚感。一字肩的设计,木耳的花边,显露颈部柔美的线条,与性感的锁骨,展现女性知性的一面,木耳花边的<UNK>设计,显露穿着的甜美感,与少女味。高收腰的设计,拉伸腰部的曲线,提高腰线,显露穿着高挑的身姿。"}
77
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*风格#简约*图案#线条*裤款式#口袋*裤款式#拉链", "response": "侧缝处添置有立体拉链口袋作为装饰,实用性强且兼备美观性。净色的大体外观,简约低调,大方得体,易于搭配。裤腰处植入张弛有度的弹性带,贴合腰部,适合于大多数人穿着。衣身剪裁干净利落,线条流畅。"}
78
+ {"prompt": "请根据以下标签为商品编写一段广��\n类型#上衣*颜色#白色*图案#条纹*图案#线条*衣样式#衬衫", "response": "白色的衬衫采用了百褶的袖子设计,既修饰了手臂线条,又为整体增强了设计感。背带裤是永不过时的条纹款式,加上阔腿裤的设计,更显女性身材。"}
79
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#棉*材质#牛仔布*风格#简约*风格#休闲*裤长#短裤*裤款式#破洞", "response": "选用优质的纯棉面料打造出舒适的质感,而且上身不会扎身。同时,个性破洞细节设计,增加了牛仔短裤的细节感和吸睛度。此外,简约好搭的配色,柔和你的棱角,让你看起来温柔又平易近人。适合约会等休闲场景,是你衣柜里不可或缺的时髦单品之一。"}
80
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#水洗*风格#潮*裤款式#不规则*裤口#毛边", "response": "年轻潮流的设计品味,洋气又好穿。细节相当丰富有看点,融入水洗磨白,使其充满时尚不羁的气息。裤脚前后毛边处理,配上不规则脚口,更添青春活力。"}
81
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*颜色#蓝色*风格#简约*裤型#背带裤*裤款式#纽扣", "response": "背带裤的选用天蓝色的主题,远远看上去就像是蓝色<UNK>悬挂在活跃孩子的身上。简约的背带设计,可随时拆开的纽扣,让稚嫩孩子穿衣时不费吹灰之力。腰部更是搭配弹性材料缝制的腰带,不仅方便穿戴而且完美的起到了修饰作用。后背交叉背带,更是独特新颖的处理,更好更牢固的穿搭,不易滑落。"}
82
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#涤纶*裤款式#拼接*裤款式#口袋", "response": "前置的口袋盖拼接,为本来单调的设计布局增加了亮点,更突出了裤子的个性化特点。加上精致的涤纶梭织面料制作,具备更加亲肤不刺激的丝质般触感,给你带来更加柔软舒适的穿着体验。其良好的透气性,有效提升了裤子的吸湿排汗性能,为你提供更加清爽舒适的体感。"}
83
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙下摆#荷叶边*裙领型#圆领*裙袖型#收口*裙款式#螺纹", "response": "此款上衣采用了经典的圆领款式打造,贴合舒适并能修饰出完美的脸型。同时螺纹的收口贴合身材更完美,在前身处采用了可爱的小狮子造型,带<UNK>真的感觉,而狮子的毛发更是立体精致,显得真实又有丰富的层次。裙身的下摆处采用了荷叶边的设计,俏皮活泼更可爱。"}
84
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*版型#显瘦*材质#网纱*风格#青春*图案#印花*衣样式#衬衫*衣领型#v领*衣款式#拼接", "response": "这一款衬衫交叠v领的设计,修饰脖颈尽显女人味,宽松的廓形,穿上非常轻松有范毫不拘束,并很好的遮盖身材,非常显瘦。时尚的网纱拼接,自然美感特别出彩。精致印花,青春减龄特别活力。"}
85
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#运动*风格#性感*衣样式#西装*衣领型#一字领*衣款式#荷叶边", "response": "荷叶边能够表达出女性的优雅,BRAND的这款上衣,将荷叶边很好地运动起来。性感的一字肩设计,荷叶边从一侧手臂的手肘从前胸绕到另一侧,有着前短后长的感觉,自然垂坠很有层次感,举手投足之间,灵动而优雅。西装袖很好地融合,优雅之中透着小帅气。"}
86
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#运动*风格#休闲*风格#青春*图案#字母*图案#形状*图案#文字*图案#刺绣*图案#撞色*衣样式#卫衣*衣袖型#收口*衣款式#抽绳*衣款式#连帽", "response": "这款dolce&gabbana的连帽卫衣,撞色的<UNK>字母加上桃心形状的刺绣图案令人耳目一新,举手投足间散发阳光活力少女的青春气息;连帽款式尽显帅气利落风范,细节上采用抽绳处理实用又美观,洋溢满满的运动休闲范儿;加之袖口处的收口设计别出心裁,宽松的衣身烘托出慵懒率性的格调。"}
87
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#牛仔布*风格#休闲*图案#字母*图案#文字*图案#线条*图案#印花*图案#撞色*裤款式#拼接*裤口#小脚", "response": "上下<UNK>拼接撞色设计,吸睛十足,轻松聚焦视线,个性前卫。字母印花设计,巧添时尚细节看点,以鲜明撞色渲染,展现年轻活力气息。长袖套头轮廓,线条处理恰到好处,呼应休闲基调。宽松的版型,不挑身材,上身好看。连帽的设计美观实用,防风保暖。时尚百搭,可以搭配牛仔裤、紧身裤、休闲裤、束脚裤等。"}
88
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*颜色#灰色*风格#复古*风格#文艺*风格#青春*图案#卡通*图案#复古*衣样式#风衣*衣长#中长款", "response": "一款好看的风衣大概能为这个姹紫嫣红的春天多一份色彩,沉静的灰色上身具有非常好的效果,显得热更加内敛沉稳,有一股淡淡的复古文艺风格。而中长的版型自然下垂,修身显高又瞬间提升气场。后背的卡通图案别致可爱,更添青春气息。"}
89
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#性感*衣样式#针织衫*衣款式#露肩", "response": "这一款针织衫露肩设计,风情浪漫性感迷人。略微宽松的版型舒适随意,很好的掩饰身材小小的缺陷,看起来精致高挑。加上时尚的花边下摆,错落有致视觉美丽。精致袖口,修饰手臂特别出彩。"}
90
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙款式#松紧带*裙款式#飘带", "response": "<UNK>冷风的气质感,干净利落的feel,小露香肩有一种含蓄撩拨的趣味,袖口领口的飘带设计很是巧妙,让整个小衫更加优美,领子部分的两边肩部松紧带设计,大胆随意的穿出多种效果。让仙女们走在时尚<UNK>的道路上更加自信。"}
91
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#针织*衣样式#卫衣*衣领型#圆领", "response": "针织卫衣采用了简洁的圆领设计,非常百搭,免去了你<UNK>找不到搭配的烦恼。合体的剪裁设计,让你在跑步健身时轻巧灵便,活动自如,达到更好的锻炼效果。"}
92
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#白色*风格#简约*图案#蝴蝶结*图案#刺绣*衣样式#衬衫*衣袖型#喇叭袖", "response": "这一款很好穿的白色衬衫,利落的宽松版型几乎是不挑身材的,无门襟的设计也符合整体的气息。胸前做了绣花的点缀,为简约的衬衫增添了几分柔美的气质。七分的喇叭袖露出小臂,蝴蝶结的点缀显得气质更加的浪漫。"}
93
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*颜色#灰色*风格#英伦*风格#复古*图案#格子*图案#复古*裙型#百褶*裙长#半身裙*裙款式#波浪*裙款式#收腰", "response": "BRAND这款半身裙,用复古的灰色格纹,打造出十足英伦范儿。搭配百褶裙身,为整体增添层次感,穿出减龄风。同时,波浪边的收腰设计,不仅更好的修饰腰部曲线,还为整体气质增添了优雅美感。而雪纺面料,使你在夏日也能穿出清爽感。"}
94
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#复古*风格#文艺*风格#中国风*风格#性感*图案#复古*图案#刺绣*裙型#a字*裙领型#v领", "response": "超级具有中国风气息的一款裙子,带着古典的柔婉。花朵刺绣的运用,色彩缤纷靓丽,冲击视觉,演绎复古文艺范儿。经典的气质v领,既凸显了小性感与时尚,又起到点睛的效果。腰部系的设计,配上a字版型,显瘦又遮肚子。"}
95
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#清新*风格#性感*图案#线条*衣样式#马甲*衣领型#翻领*衣款式#露背*衣款式#绑带*衣款式#吊带*衣款式#收腰", "response": "小吊带马甲叠穿造型,年轻而不失时尚格调,有着绑带收腰设计,强调出纤细的腰肢,摩登帅气;小翻领露出纤细修长的脖颈线条,散发清爽利落的小清新气息;性感交叉露背设计,别致吸睛,女人味十足;高腰伞形裙摆自然撑开,上身塑造黄金比例,突显得腰更细,巧妙地修饰身型。"}
96
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*颜色#黑白*风格#英伦*风格#简约*图案#格子*图案#线条*衣样式#外套*衣样式#西装*衣门襟#一粒扣", "response": "这款西装外套,版型加长修身,能更好凸显成熟与稳重。细细密密的黑白图案,远远看形成自然的格纹,时髦英伦范儿。平整肩线将线条感拉伸,让身姿显得更挺拔有型。一粒扣设计,简约大气。"}
97
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#潮*图案#线条*图案#撞色*裙领型#圆领", "response": "采用经典的圆领设计,修饰颈部线条的同时,且上身穿着舒适不易变形,轻松演绎时髦造型。大面积撞色贴花装饰,无疑是点睛之笔,为简洁的款式轮廓带来了更多的视觉层次感与潮流气息。与众不同的你,不在畏惧撞衫的尴尬。"}
98
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#雪纺*颜色#纯色*风格#清新*图案#纯色*图案#碎花*衣样式#衬衫*衣款式#荷叶边", "response": "这件荷叶边雪纺碎花衬衫和其他的碎花衬衫相比整体的风格��更优雅柔美一些。颜色上也是比较<UNK>清新的花型配色和纯色的大身相结合,会让人看着很舒服,而且每个碎花之间都会<UNK>限视觉上不会觉得太紧密,更有法式的浪漫优雅。"}
99
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*风格#简约*风格#青春*风格#潮*风格#性感*图案#线条*图案#蕾丝*衣样式#雪纺衫*衣领型#圆领*衣款式#勾花镂空", "response": "这款时尚镂空雪纺衫,带有性感蕾丝工艺,精致百搭的圆领设计,彰显显独特的质感。背部线条流畅,笔挺而有型干练,彰显潮流时尚之风。走线十分笔直,针脚均匀,尽显裁缝之细致。简约时尚的透视蕾丝袖口,彰显成熟又不乏活力的青春气质。青春优雅的独特风格,流露出满满的潮流感。"}
100
+ {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*风格#休闲*裤长#短裤", "response": "来自英国<UNK>的这款儿童休闲短裤,以趣味的小恐龙图案满印裤身,可爱童真,彰显出宝宝的活泼天真范儿。柔软的全棉布料质地,手感细腻顺滑,亲和宝宝的肌肤,带来舒适自在的穿着体验。"}
limitations/factual_error.png ADDED
limitations/math_error.png ADDED
limitations/self-confusion_google.jpg ADDED
limitations/self-confusion_openai.jpg ADDED
limitations/self-confusion_tencent.jpg ADDED
ptuning/README.md ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ChatGLM-6B-PT
2
+ 本仓库实现了对于 ChatGLM-6B 模型基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。
3
+
4
+ 下面以 [ADGEN](https://aclanthology.org/D19-1321.pdf) (广告生成) 数据集为例介绍代码的使用方法。
5
+
6
+ *Read this in [English](README_en.md).
7
+
8
+ ## 软件依赖
9
+ 运行微调需要4.27.1版本的`transformers`。除 ChatGLM-6B 的依赖之外,还需要安装以下依赖
10
+ ```
11
+ pip install rouge_chinese nltk jieba datasets
12
+ ```
13
+ ## 使用方法
14
+
15
+ ### 下载数据集
16
+ ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。
17
+
18
+ ```json
19
+ {
20
+ "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
21
+ "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
22
+ }
23
+ ```
24
+
25
+ 从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 ADGEN 数据集,将解压后的 `AdvertiseGen` 目录放到本目录下。
26
+
27
+ ### 训练
28
+
29
+ #### P-Tuning v2
30
+
31
+ 运行以下指令进行训练:
32
+ ```shell
33
+ bash train.sh
34
+ ```
35
+ `train.sh` 中的 `PRE_SEQ_LEN` 和 `LR` 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 `quantization_bit` 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。
36
+
37
+ 在默认配置 `quantization_bit=4`、`per_device_train_batch_size=1`、`gradient_accumulation_steps=16` 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 `per_device_train_batch_size` 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。
38
+
39
+ 如果你想要[从本地加载模型](../README_en.md#load-the-model-locally),可以将 `train.sh` 中的 `THUDM/chatglm-6b` 改为你本地的模型路径。
40
+
41
+ #### Finetune
42
+
43
+ 如果需要进行全参数的 Finetune,需要安装 [Deepspeed](https://github.com/microsoft/DeepSpeed),然后运行以下指令:
44
+
45
+ ```shell
46
+ bash ds_train_finetune.sh
47
+ ```
48
+
49
+ ### 推理
50
+
51
+ 在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,所以在推理时需要同时加载原 ChatGLM-6B 模型以及 PrefixEncoder 的权重,因此需要指定 `evaluate.sh` 中的参数:
52
+
53
+ ```shell
54
+ --model_name_or_path THUDM/chatglm-6b
55
+ --ptuning_checkpoint $CHECKPOINT_PATH
56
+ ```
57
+
58
+ 仍然兼容旧版全参保存的 Checkpoint,只需要跟之前一样设定 `model_name_or_path`:
59
+
60
+ ```shell
61
+ --model_name_or_path $CHECKPOINT_PATH
62
+ ```
63
+
64
+ 评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在
65
+ `./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`。
66
+
67
+ ### 例子
68
+ #### 示例1
69
+ * Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
70
+ * Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
71
+ * Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
72
+ * Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
73
+
74
+ #### 示例2
75
+
76
+ * Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
77
+ * Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
78
+ * Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
79
+ * Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
80
+
81
+ ### 评估结果
82
+
83
+ | | Finetune | P-tuning v2 | LoRA |
84
+ | ------------- | ----------- | ----- | ------------- |
85
+ | BLEU-4 | 8.01 | 8.10 | 7.62 |
86
+ | Rouge-1 | 31.23 | 31.12 | 30.60 |
87
+ | Rouge-2 | 7.36 | 7.11 | 6.96 |
88
+ | Rouge-l | 25.08 | 24.97 | 24.80 |
89
+ | Training Loss | 3.00 | 3.74 | 3.32 |
90
+
91
+
92
+
93
+ #### 实验设置
94
+
95
+ ```
96
+ max_source_length=64
97
+ max_target_length=64
98
+ max_steps=3000
99
+ ```
100
+
101
+ ##### P-tuning v2
102
+
103
+ ```
104
+ pre_seq_len=128
105
+ learning_rate=2e-2
106
+ quantization_bit=4
107
+ per_device_train_batch_size=16
108
+ gradient_accumulation_steps=1
109
+ ```
110
+
111
+ ##### Finetune
112
+
113
+ ```
114
+ learning_rate=1e-4
115
+ fp16
116
+ num_gpus=4
117
+ per_device_train_batch_size=4
118
+ gradient_accumulation_steps=1
119
+ ```
120
+
121
+ ##### LoRA
122
+
123
+ 实现采用的是 [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)
124
+
125
+ ```
126
+ learning_rate=5e-4
127
+ per_device_train_batch_size=16
128
+ gradient_accumulation_steps=1
129
+ ```
130
+
131
+ ## 模型部署
132
+ 首先载入Tokenizer:
133
+
134
+ ```python
135
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
136
+
137
+ # 载入Tokenizer
138
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
139
+ ```
140
+
141
+ 1. 如果需要加载的是新 Checkpoint(只包含 PrefixEncoder 参数):
142
+
143
+ ```python
144
+ config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
145
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
146
+ prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
147
+ new_prefix_state_dict = {}
148
+ for k, v in prefix_state_dict.items():
149
+ if k.startswith("transformer.prefix_encoder."):
150
+ new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
151
+ model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
152
+ ```
153
+ 注意你可能需要将 `pre_seq_len` 改成你训练时的实际值。如果你是[从本地加载模型](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B)的话,需要将 `THUDM/chatglm-6b` 改成本地的模型路径(注意不是checkpoint路径)。
154
+
155
+ 2. 如果需要加载的是旧 Checkpoint(包含 ChatGLM-6B 以及 PrefixEncoder 参数),或者进行的是全参数微调,则直接加载整个 Checkpoint:
156
+
157
+ ```python
158
+ model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
159
+ ```
160
+
161
+ 之后根据需求可以进行量化,也可以直接使用:
162
+
163
+ ```python
164
+ # Comment out the following line if you don't use quantization
165
+ model = model.quantize(4)
166
+ model = model.half().cuda()
167
+ model.transformer.prefix_encoder.float()
168
+ model = model.eval()
169
+
170
+ response, history = model.chat(tokenizer, "你好", history=[])
171
+ ```
172
+
173
+ **[23/04/19]** 你也可以直接运行支持加载 P-Tuning v2 checkpoint 的 [web demo](./web_demo.py)
174
+ ```shell
175
+ bash web_demo.sh
176
+ ```
177
+ 可能需要修改 [web_demo.sh](./web_demo.sh) 的内容以符合你实际的 checkpoint 情况。
178
+
179
+ ## 使用自己的数据集
180
+ 修改 `train.sh` 和 `evaluate.sh` 中的 `train_file`、`validation_file`和`test_file`为你自己的 JSON 格式数据集路径,并将 `prompt_column` 和 `response_column` 改为 JSON 文件中输入文本和输出文本对应的 KEY。可能还需要增大 `max_source_length` 和 `max_target_length` 来匹配你自己的数据集中的最大输入输出长度。
181
+
182
+ ## 对话数据集
183
+
184
+ 如需要使用多轮对话数据对模型进行微调,可以提供聊天历史,例如以下是一个三轮对话的训练数据:
185
+
186
+ ```json lines
187
+ {"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
188
+ {"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
189
+ {"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上���水管温差怎么样啊?空气是不是都排干净了呢?"]]}
190
+ ```
191
+
192
+ 训练时需要指定 `--history_column` 为数据中聊天历史的 key(在此例子中是 `history`),将自动把聊天历史拼接。要注意超过输入长度 `max_source_length` 的内容会被截断。
193
+
194
+ 可以参考以下指令:
195
+
196
+ ```shell
197
+ bash train_chat.sh
198
+ ```
199
+
200
+ ## 引用
201
+
202
+ ```
203
+ @inproceedings{liu2022p,
204
+ title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
205
+ author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
206
+ booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
207
+ pages={61--68},
208
+ year={2022}
209
+ }
210
+ ```
211
+
212
+
213
+
ptuning/README_en.md ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ChatGLM-6B-PT
2
+ This repository implements tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the amount of parameters that need to be optimized to 0.1% of the full fine-tuning, and then through model quantization, Gradient Checkpoint and other methods, it only needs a minimum of 7GB of video memory to run.
3
+
4
+ The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertising generation) dataset as an example to introduce how to use the code.
5
+
6
+ ## Software dependencies
7
+ Running p-tuning requires version 4.27.1 of `transformers`. In addition to the dependencies of ChatGLM-6B, the following dependencies are required
8
+ ```
9
+ pip install rouge_chinese nltk jieba datasets
10
+ ```
11
+ ## Instructions
12
+
13
+ ### Download the dataset
14
+ The task of the ADGEN dataset is to generate an advertisement word (summary) based on the input (content).
15
+
16
+ ```json
17
+ {
18
+ "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
19
+ "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
20
+ }
21
+ ```
22
+
23
+ From [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) download the processed ADGEN dataset, and put the decompressed `AdvertiseGen` directory into this directory.
24
+
25
+ ### Training
26
+
27
+ #### P-Tuning v2
28
+
29
+ Run the following commands for training:
30
+ ```shell
31
+ bash train.sh
32
+ ```
33
+ `PRE_SEQ_LEN` and `LR` in `train.sh` are soft prompt length and training learning rate respectively, which can be adjusted to achieve the best results. The P-Tuning-v2 method will freeze all model parameters, and the quantization level of the original model can be adjusted by adjusting `quantization_bit`. If this option is not added, it will be loaded with FP16 precision.
34
+
35
+ Under the default configuration of `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the model parameters of INT4 are frozen, and a training iteration will perform 16 cumulative forward and backward propagations with a batch size of 1, which is equivalent to the total batch size of 16, and only 6.7G GPU memory is required at this time with `quantization_bit=4`. If you want to improve the training efficiency under the same batch size, you can increase the value of `per_device_train_batch_size` while keeping the product of the two unchanged, but it will also bring more GPU memory consumption, please adjust it according to the actual situation.
36
+
37
+ If you want to [load the model locally](../README_en.md#load-the-model-locally), you can change `THUDM/chatglm-6b` in `train.sh` to your local model path.
38
+
39
+ #### Finetune
40
+ To finetune the full parameters, you need to install [Deepspeed](https://github.com/microsoft/DeepSpeed), and then run the following command:
41
+
42
+ ```shell
43
+ bash ds_train_finetune.sh
44
+ ```
45
+
46
+ ### Inference
47
+
48
+ During P-tuning v2 training, the model only saves the parameters of the PrefixEncoder part, so the original ChatGLM-6B model and the weight of the PrefixEncoder need to be loaded at the same time during inference, and the arguments need to be specified in `evaluate.sh`:
49
+
50
+ ```shell
51
+ --model_name_or_path THUDM/chatglm-6b
52
+ --ptuning_checkpoint $CHECKPOINT_PATH
53
+ ```
54
+
55
+ It is still compatible with the old version of Checkpoint saved with full parameters, just set `model_name_or_path` as before:
56
+
57
+ ```shell
58
+ --model_name_or_path $CHECKPOINT_PATH
59
+ ```
60
+
61
+ The evaluation indicators are Chinese Rouge score and BLEU-4. The generated results are saved in
62
+ `./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`.
63
+
64
+ ### Example
65
+ #### Example 1
66
+ * Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
67
+ * Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
68
+ * Output[before tuning]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
69
+ * Output[after tuning]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
70
+
71
+ #### Example 2
72
+
73
+ * Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
74
+ * Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒���布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
75
+ * Output[before tuning]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
76
+ * Output[after tuning]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
77
+
78
+ ### evaluation result
79
+
80
+ | | Finetune | P-tuning v2 | LoRA |
81
+ | ------------- | ----------- | ----- | ------------- |
82
+ | BLEU-4 | 8.01 | 8.10 | 7.62 |
83
+ | Rouge-1 | 31.23 | 31.12 | 30.60 |
84
+ | Rouge-2 | 7.36 | 7.11 | 6.96 |
85
+ | Rouge-l | 25.08 | 24.97 | 24.80 |
86
+ | Training Loss | 3.00 | 3.74 | 3.32 |
87
+
88
+ #### Experiment Settings
89
+
90
+ ```
91
+ max_source_length=64
92
+ max_target_length=64
93
+ max_steps=3000
94
+ ```
95
+
96
+ ##### P-tuning v2
97
+
98
+ ```
99
+ pre_seq_len=128
100
+ learning_rate=2e-2
101
+ quantization_bit=4
102
+ per_device_train_batch_size=16
103
+ gradient_accumulation_steps=1
104
+ ```
105
+
106
+ ##### Finetune
107
+
108
+ ```
109
+ learning_rate=1e-4
110
+ fp16
111
+ num_gpus=4
112
+ per_device_train_batch_size=4
113
+ gradient_accumulation_steps=1
114
+ ```
115
+
116
+ ##### LoRA
117
+
118
+ The implementation uses [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)
119
+
120
+ ```
121
+ learning_rate=5e-4
122
+ per_device_train_batch_size=16
123
+ gradient_accumulation_steps=1
124
+ ```
125
+
126
+ ## Model Deployment
127
+ First load the tokenizer:
128
+
129
+ ```python
130
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
131
+
132
+ # Load Tokenizer
133
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
134
+ ```
135
+
136
+ 1. If a new Checkpoint needs to be loaded (only contains the PrefixEncoder parameter):
137
+
138
+ ```python
139
+ config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
140
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
141
+ prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
142
+ new_prefix_state_dict = {}
143
+ for k, v in prefix_state_dict.items():
144
+ if k.startswith("transformer.prefix_encoder."):
145
+ new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
146
+ model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
147
+ ```
148
+ Note that you may need to change `pre_seq_len` to the actual value of your training. If you [load model from local](../README_en.md#load-the-model-locally), you need to change `THUDM/chatglm-6b` to the local model path (not the checkpoint path).
149
+
150
+ 2. If you need to load the old checkpoint (including both ChatGLM-6B and PrefixEncoder parameters), or perform full parameter fine-tuning, then directly load the entire checkpoint:
151
+
152
+ ```python
153
+ model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
154
+ ```
155
+
156
+ Then it can be quantified according to the needs, or it can be used directly:
157
+
158
+ ```python
159
+ # Comment out the following line if you don't use quantization
160
+ model = model. quantize(4)
161
+ model = model.half().cuda()
162
+ model.transformer.prefix_encoder.float()
163
+ model = model.eval()
164
+
165
+ response, history = model.chat(tokenizer, "Hello", history=[])
166
+ ```
167
+
168
+ **[23/04/19]** You can also directly run [web demo](./web_demo.py) which supports loading P-Tuning v2 checkpoint
169
+ ```shell
170
+ bash web_demo.sh
171
+ ```
172
+ It may be necessary to modify the content of [web_demo.sh](./web_demo.sh) to match your actual checkpoint situation.
173
+
174
+ ## Use your own dataset
175
+ Modify `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to your own JSON format dataset paths, and change `prompt_column` and `response_column` to the keys in the JSON file corresponding to input text and output text.
176
+ You may also need to increase `max_source_length` and `max_target_length` to match the maximum input and output lengths in your own dataset.
177
+
178
+ ## Dialog Dataset
179
+
180
+ If you need to use multiple rounds of dialogue data to train the model, you can provide chat history. For example, the following is the training data for a three-round dialogue:
181
+
182
+ ```json lines
183
+ {"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
184
+ {"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好���传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
185
+ {"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
186
+ ```
187
+
188
+ During training, you need to specify `--history_column` as the key of the chat history in the data (`history` in this example), and the chat history will be stitched automatically. Note that content exceeding the input length `max_source_length` will be truncated.
189
+
190
+ You can refer to the following instructions:
191
+
192
+ ```shell
193
+ bash train_chat.sh
194
+ ```
195
+
196
+ ## Citation
197
+
198
+ ```
199
+ @inproceedings{liu2022p,
200
+ title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
201
+ author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
202
+ booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
203
+ pages={61--68},
204
+ year={2022}
205
+ }
206
+ ```
ptuning/arguments.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional
3
+
4
+
5
+ @dataclass
6
+ class ModelArguments:
7
+ """
8
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9
+ """
10
+
11
+ model_name_or_path: str = field(
12
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
13
+ )
14
+ ptuning_checkpoint: str = field(
15
+ default=None, metadata={"help": "Path to p-tuning v2 checkpoints"}
16
+ )
17
+ config_name: Optional[str] = field(
18
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
19
+ )
20
+ tokenizer_name: Optional[str] = field(
21
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
22
+ )
23
+ cache_dir: Optional[str] = field(
24
+ default=None,
25
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
26
+ )
27
+ use_fast_tokenizer: bool = field(
28
+ default=True,
29
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
30
+ )
31
+ model_revision: str = field(
32
+ default="main",
33
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
34
+ )
35
+ use_auth_token: bool = field(
36
+ default=False,
37
+ metadata={
38
+ "help": (
39
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
40
+ "with private models)."
41
+ )
42
+ },
43
+ )
44
+ resize_position_embeddings: Optional[bool] = field(
45
+ default=None,
46
+ metadata={
47
+ "help": (
48
+ "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
49
+ "the model's position embeddings."
50
+ )
51
+ },
52
+ )
53
+ quantization_bit: Optional[int] = field(
54
+ default=None
55
+ )
56
+ pre_seq_len: Optional[int] = field(
57
+ default=None
58
+ )
59
+ prefix_projection: bool = field(
60
+ default=False
61
+ )
62
+
63
+
64
+ @dataclass
65
+ class DataTrainingArguments:
66
+ """
67
+ Arguments pertaining to what data we are going to input our model for training and eval.
68
+ """
69
+
70
+ lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
71
+
72
+ dataset_name: Optional[str] = field(
73
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
74
+ )
75
+ dataset_config_name: Optional[str] = field(
76
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
77
+ )
78
+ prompt_column: Optional[str] = field(
79
+ default=None,
80
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
81
+ )
82
+ response_column: Optional[str] = field(
83
+ default=None,
84
+ metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
85
+ )
86
+ history_column: Optional[str] = field(
87
+ default=None,
88
+ metadata={"help": "The name of the column in the datasets containing the history of chat."},
89
+ )
90
+ train_file: Optional[str] = field(
91
+ default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
92
+ )
93
+ validation_file: Optional[str] = field(
94
+ default=None,
95
+ metadata={
96
+ "help": (
97
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
98
+ )
99
+ },
100
+ )
101
+ test_file: Optional[str] = field(
102
+ default=None,
103
+ metadata={
104
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
105
+ },
106
+ )
107
+ overwrite_cache: bool = field(
108
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
109
+ )
110
+ preprocessing_num_workers: Optional[int] = field(
111
+ default=None,
112
+ metadata={"help": "The number of processes to use for the preprocessing."},
113
+ )
114
+ max_source_length: Optional[int] = field(
115
+ default=1024,
116
+ metadata={
117
+ "help": (
118
+ "The maximum total input sequence length after tokenization. Sequences longer "
119
+ "than this will be truncated, sequences shorter will be padded."
120
+ )
121
+ },
122
+ )
123
+ max_target_length: Optional[int] = field(
124
+ default=128,
125
+ metadata={
126
+ "help": (
127
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
128
+ "than this will be truncated, sequences shorter will be padded."
129
+ )
130
+ },
131
+ )
132
+ val_max_target_length: Optional[int] = field(
133
+ default=None,
134
+ metadata={
135
+ "help": (
136
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
137
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
138
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
139
+ "during ``evaluate`` and ``predict``."
140
+ )
141
+ },
142
+ )
143
+ pad_to_max_length: bool = field(
144
+ default=False,
145
+ metadata={
146
+ "help": (
147
+ "Whether to pad all samples to model maximum sentence length. "
148
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
149
+ "efficient on GPU but very bad for TPU."
150
+ )
151
+ },
152
+ )
153
+ max_train_samples: Optional[int] = field(
154
+ default=None,
155
+ metadata={
156
+ "help": (
157
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
158
+ "value if set."
159
+ )
160
+ },
161
+ )
162
+ max_eval_samples: Optional[int] = field(
163
+ default=None,
164
+ metadata={
165
+ "help": (
166
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
167
+ "value if set."
168
+ )
169
+ },
170
+ )
171
+ max_predict_samples: Optional[int] = field(
172
+ default=None,
173
+ metadata={
174
+ "help": (
175
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
176
+ "value if set."
177
+ )
178
+ },
179
+ )
180
+ num_beams: Optional[int] = field(
181
+ default=None,
182
+ metadata={
183
+ "help": (
184
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
185
+ "which is used during ``evaluate`` and ``predict``."
186
+ )
187
+ },
188
+ )
189
+ ignore_pad_token_for_loss: bool = field(
190
+ default=True,
191
+ metadata={
192
+ "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
193
+ },
194
+ )
195
+ source_prefix: Optional[str] = field(
196
+ default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
197
+ )
198
+
199
+ forced_bos_token: Optional[str] = field(
200
+ default=None,
201
+ metadata={
202
+ "help": (
203
+ "The token to force as the first generated token after the decoder_start_token_id."
204
+ "Useful for multilingual models like mBART where the first generated token"
205
+ "needs to be the target language token (Usually it is the target language token)"
206
+ )
207
+ },
208
+ )
209
+
210
+
211
+
212
+ def __post_init__(self):
213
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None:
214
+ raise ValueError("Need either a dataset name or a training/validation/test file.")
215
+ else:
216
+ if self.train_file is not None:
217
+ extension = self.train_file.split(".")[-1]
218
+ assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
219
+ if self.validation_file is not None:
220
+ extension = self.validation_file.split(".")[-1]
221
+ assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
222
+ if self.val_max_target_length is None:
223
+ self.val_max_target_length = self.max_target_length
224
+
ptuning/deepspeed.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train_micro_batch_size_per_gpu": "auto",
3
+ "zero_allow_untested_optimizer": true,
4
+ "fp16": {
5
+ "enabled": "auto",
6
+ "loss_scale": 0,
7
+ "initial_scale_power": 16,
8
+ "loss_scale_window": 1000,
9
+ "hysteresis": 2,
10
+ "min_loss_scale": 1
11
+ },
12
+ "zero_optimization": {
13
+ "stage": 2,
14
+ "allgather_partitions": true,
15
+ "allgather_bucket_size": 5e8,
16
+ "overlap_comm": false,
17
+ "reduce_scatter": true,
18
+ "reduce_bucket_size": 5e8,
19
+ "contiguous_gradients" : true
20
+ }
21
+ }
ptuning/ds_train_finetune.sh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ LR=1e-4
3
+
4
+ MASTER_PORT=$(shuf -n 1 -i 10000-65535)
5
+
6
+ deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \
7
+ --deepspeed deepspeed.json \
8
+ --do_train \
9
+ --train_file AdvertiseGen/train.json \
10
+ --test_file AdvertiseGen/dev.json \
11
+ --prompt_column content \
12
+ --response_column summary \
13
+ --overwrite_cache \
14
+ --model_name_or_path THUDM/chatglm-6b \
15
+ --output_dir ./output/adgen-chatglm-6b-ft-$LR \
16
+ --overwrite_output_dir \
17
+ --max_source_length 64 \
18
+ --max_target_length 64 \
19
+ --per_device_train_batch_size 4 \
20
+ --per_device_eval_batch_size 1 \
21
+ --gradient_accumulation_steps 1 \
22
+ --predict_with_generate \
23
+ --max_steps 5000 \
24
+ --logging_steps 10 \
25
+ --save_steps 1000 \
26
+ --learning_rate $LR \
27
+ --fp16
28
+
ptuning/evaluate.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PRE_SEQ_LEN=128
2
+ CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
3
+ STEP=3000
4
+
5
+ CUDA_VISIBLE_DEVICES=0 python3 main.py \
6
+ --do_predict \
7
+ --validation_file AdvertiseGen/dev.json \
8
+ --test_file AdvertiseGen/dev.json \
9
+ --overwrite_cache \
10
+ --prompt_column content \
11
+ --response_column summary \
12
+ --model_name_or_path THUDM/chatglm-6b \
13
+ --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
14
+ --output_dir ./output/$CHECKPOINT \
15
+ --overwrite_output_dir \
16
+ --max_source_length 64 \
17
+ --max_target_length 64 \
18
+ --per_device_eval_batch_size 1 \
19
+ --predict_with_generate \
20
+ --pre_seq_len $PRE_SEQ_LEN \
21
+ --quantization_bit 4
ptuning/evaluate_finetune.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CHECKPOINT=adgen-chatglm-6b-ft-1e-4
2
+ STEP=3000
3
+
4
+ CUDA_VISIBLE_DEVICES=0 python3 main.py \
5
+ --do_predict \
6
+ --validation_file AdvertiseGen/dev.json \
7
+ --test_file AdvertiseGen/dev.json \
8
+ --overwrite_cache \
9
+ --prompt_column content \
10
+ --response_column summary \
11
+ --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \
12
+ --output_dir ./output/$CHECKPOINT \
13
+ --overwrite_output_dir \
14
+ --max_source_length 256 \
15
+ --max_target_length 256 \
16
+ --per_device_eval_batch_size 1 \
17
+ --predict_with_generate \
18
+ --fp16_full_eval
ptuning/main.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for sequence to sequence.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import os
23
+ import sys
24
+ import json
25
+
26
+ import numpy as np
27
+ from datasets import load_dataset
28
+ import jieba
29
+ from rouge_chinese import Rouge
30
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
31
+ import torch
32
+
33
+ import transformers
34
+ from transformers import (
35
+ AutoConfig,
36
+ AutoModel,
37
+ AutoTokenizer,
38
+ DataCollatorForSeq2Seq,
39
+ HfArgumentParser,
40
+ Seq2SeqTrainingArguments,
41
+ set_seed,
42
+ )
43
+ from trainer_seq2seq import Seq2SeqTrainer
44
+
45
+ from arguments import ModelArguments, DataTrainingArguments
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+ def main():
50
+
51
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
52
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
53
+ # If we pass only one argument to the script and it's the path to a json file,
54
+ # let's parse it to get our arguments.
55
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
56
+ else:
57
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
58
+
59
+ # Setup logging
60
+ logging.basicConfig(
61
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
62
+ datefmt="%m/%d/%Y %H:%M:%S",
63
+ handlers=[logging.StreamHandler(sys.stdout)],
64
+ )
65
+
66
+ if training_args.should_log:
67
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
68
+ transformers.utils.logging.set_verbosity_info()
69
+
70
+ log_level = training_args.get_process_log_level()
71
+ logger.setLevel(log_level)
72
+ # datasets.utils.logging.set_verbosity(log_level)
73
+ transformers.utils.logging.set_verbosity(log_level)
74
+ transformers.utils.logging.enable_default_handler()
75
+ transformers.utils.logging.enable_explicit_format()
76
+
77
+ # Log on each process the small summary:
78
+ logger.warning(
79
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
80
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
81
+ )
82
+ logger.info(f"Training/evaluation parameters {training_args}")
83
+
84
+ # Set seed before initializing model.
85
+ set_seed(training_args.seed)
86
+
87
+ # Load dataset
88
+ data_files = {}
89
+ if data_args.train_file is not None:
90
+ data_files["train"] = data_args.train_file
91
+ extension = data_args.train_file.split(".")[-1]
92
+ if data_args.validation_file is not None:
93
+ data_files["validation"] = data_args.validation_file
94
+ extension = data_args.validation_file.split(".")[-1]
95
+ if data_args.test_file is not None:
96
+ data_files["test"] = data_args.test_file
97
+ extension = data_args.test_file.split(".")[-1]
98
+
99
+ raw_datasets = load_dataset(
100
+ extension,
101
+ data_files=data_files,
102
+ cache_dir=model_args.cache_dir,
103
+ use_auth_token=True if model_args.use_auth_token else None,
104
+ )
105
+
106
+ # Load pretrained model and tokenizer
107
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
108
+ config.pre_seq_len = model_args.pre_seq_len
109
+ config.prefix_projection = model_args.prefix_projection
110
+
111
+ tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
112
+
113
+ if model_args.ptuning_checkpoint is not None:
114
+ # Evaluation
115
+ # Loading extra state dict of prefix encoder
116
+ model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
117
+ prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
118
+ new_prefix_state_dict = {}
119
+ for k, v in prefix_state_dict.items():
120
+ if k.startswith("transformer.prefix_encoder."):
121
+ new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
122
+ model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
123
+ else:
124
+ model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
125
+
126
+ if model_args.quantization_bit is not None:
127
+ print(f"Quantized to {model_args.quantization_bit} bit")
128
+ model = model.quantize(model_args.quantization_bit)
129
+ if model_args.pre_seq_len is not None:
130
+ # P-tuning v2
131
+ model = model.half()
132
+ model.transformer.prefix_encoder.float()
133
+ else:
134
+ # Finetune
135
+ model = model.float()
136
+
137
+ prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
138
+
139
+ # Preprocessing the datasets.
140
+ # We need to tokenize inputs and targets.
141
+ if training_args.do_train:
142
+ column_names = raw_datasets["train"].column_names
143
+ elif training_args.do_eval:
144
+ column_names = raw_datasets["validation"].column_names
145
+ elif training_args.do_predict:
146
+ column_names = raw_datasets["test"].column_names
147
+ else:
148
+ logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
149
+ return
150
+
151
+ # Get the column names for input/target.
152
+ prompt_column = data_args.prompt_column
153
+ response_column = data_args.response_column
154
+ history_column = data_args.history_column
155
+
156
+ # Temporarily set max_target_length for training.
157
+ max_target_length = data_args.max_target_length
158
+
159
+ def preprocess_function_eval(examples):
160
+ inputs, targets = [], []
161
+ for i in range(len(examples[prompt_column])):
162
+ if examples[prompt_column][i] and examples[response_column][i]:
163
+ query = examples[prompt_column][i]
164
+ if history_column is None or len(examples[history_column][i]) == 0:
165
+ prompt = query
166
+ else:
167
+ prompt = ""
168
+ history = examples[history_column][i]
169
+ for turn_idx, (old_query, response) in enumerate(history):
170
+ prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
171
+ prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
172
+ inputs.append(prompt)
173
+ targets.append(examples[response_column][i])
174
+
175
+ inputs = [prefix + inp for inp in inputs]
176
+ model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
177
+ labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
178
+
179
+ if data_args.ignore_pad_token_for_loss:
180
+ labels["input_ids"] = [
181
+ [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
182
+ ]
183
+ model_inputs["labels"] = labels["input_ids"]
184
+
185
+ return model_inputs
186
+
187
+ def preprocess_function_train(examples):
188
+ max_seq_length = data_args.max_source_length + data_args.max_target_length
189
+
190
+ model_inputs = {
191
+ "input_ids": [],
192
+ "labels": [],
193
+ }
194
+ for i in range(len(examples[prompt_column])):
195
+ if examples[prompt_column][i] and examples[response_column][i]:
196
+ query, answer = examples[prompt_column][i], examples[response_column][i]
197
+
198
+ if history_column is None:
199
+ prompt = query
200
+ else:
201
+ prompt = ""
202
+ history = examples[history_column][i]
203
+ for turn_idx, (old_query, response) in enumerate(history):
204
+ prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
205
+ prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
206
+
207
+ prompt = prefix + prompt
208
+ a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
209
+ b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
210
+
211
+ if len(a_ids) > data_args.max_source_length - 1:
212
+ a_ids = a_ids[: data_args.max_source_length - 1]
213
+
214
+ if len(b_ids) > data_args.max_target_length - 2:
215
+ b_ids = b_ids[: data_args.max_target_length - 2]
216
+
217
+ input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
218
+
219
+ context_length = input_ids.index(tokenizer.bos_token_id)
220
+ mask_position = context_length - 1
221
+ labels = [-100] * context_length + input_ids[mask_position+1:]
222
+
223
+ pad_len = max_seq_length - len(input_ids)
224
+ input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
225
+ labels = labels + [tokenizer.pad_token_id] * pad_len
226
+ if data_args.ignore_pad_token_for_loss:
227
+ labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
228
+
229
+ model_inputs["input_ids"].append(input_ids)
230
+ model_inputs["labels"].append(labels)
231
+
232
+ return model_inputs
233
+
234
+ def print_dataset_example(example):
235
+ print("input_ids",example["input_ids"])
236
+ print("inputs", tokenizer.decode(example["input_ids"]))
237
+ print("label_ids", example["labels"])
238
+ print("labels", tokenizer.decode(example["labels"]))
239
+
240
+ if training_args.do_train:
241
+ if "train" not in raw_datasets:
242
+ raise ValueError("--do_train requires a train dataset")
243
+ train_dataset = raw_datasets["train"]
244
+ if data_args.max_train_samples is not None:
245
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
246
+ train_dataset = train_dataset.select(range(max_train_samples))
247
+ with training_args.main_process_first(desc="train dataset map pre-processing"):
248
+ train_dataset = train_dataset.map(
249
+ preprocess_function_train,
250
+ batched=True,
251
+ num_proc=data_args.preprocessing_num_workers,
252
+ remove_columns=column_names,
253
+ load_from_cache_file=not data_args.overwrite_cache,
254
+ desc="Running tokenizer on train dataset",
255
+ )
256
+ print_dataset_example(train_dataset[0])
257
+
258
+ if training_args.do_eval:
259
+ max_target_length = data_args.val_max_target_length
260
+ if "validation" not in raw_datasets:
261
+ raise ValueError("--do_eval requires a validation dataset")
262
+ eval_dataset = raw_datasets["validation"]
263
+ if data_args.max_eval_samples is not None:
264
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
265
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
266
+ with training_args.main_process_first(desc="validation dataset map pre-processing"):
267
+ eval_dataset = eval_dataset.map(
268
+ preprocess_function_eval,
269
+ batched=True,
270
+ num_proc=data_args.preprocessing_num_workers,
271
+ remove_columns=column_names,
272
+ load_from_cache_file=not data_args.overwrite_cache,
273
+ desc="Running tokenizer on validation dataset",
274
+ )
275
+ print_dataset_example(eval_dataset[0])
276
+
277
+ if training_args.do_predict:
278
+ max_target_length = data_args.val_max_target_length
279
+ if "test" not in raw_datasets:
280
+ raise ValueError("--do_predict requires a test dataset")
281
+ predict_dataset = raw_datasets["test"]
282
+ if data_args.max_predict_samples is not None:
283
+ max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
284
+ predict_dataset = predict_dataset.select(range(max_predict_samples))
285
+ with training_args.main_process_first(desc="prediction dataset map pre-processing"):
286
+ predict_dataset = predict_dataset.map(
287
+ preprocess_function_eval,
288
+ batched=True,
289
+ num_proc=data_args.preprocessing_num_workers,
290
+ remove_columns=column_names,
291
+ load_from_cache_file=not data_args.overwrite_cache,
292
+ desc="Running tokenizer on prediction dataset",
293
+ )
294
+ print_dataset_example(predict_dataset[0])
295
+
296
+ # Data collator
297
+ label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
298
+ data_collator = DataCollatorForSeq2Seq(
299
+ tokenizer,
300
+ model=model,
301
+ label_pad_token_id=label_pad_token_id,
302
+ pad_to_multiple_of=None,
303
+ padding=False
304
+ )
305
+
306
+ # Metric
307
+ def compute_metrics(eval_preds):
308
+ preds, labels = eval_preds
309
+ if isinstance(preds, tuple):
310
+ preds = preds[0]
311
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
312
+ if data_args.ignore_pad_token_for_loss:
313
+ # Replace -100 in the labels as we can't decode them.
314
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
315
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
316
+
317
+ score_dict = {
318
+ "rouge-1": [],
319
+ "rouge-2": [],
320
+ "rouge-l": [],
321
+ "bleu-4": []
322
+ }
323
+ for pred, label in zip(decoded_preds, decoded_labels):
324
+ hypothesis = list(jieba.cut(pred))
325
+ reference = list(jieba.cut(label))
326
+ rouge = Rouge()
327
+ scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
328
+ result = scores[0]
329
+
330
+ for k, v in result.items():
331
+ score_dict[k].append(round(v["f"] * 100, 4))
332
+ bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
333
+ score_dict["bleu-4"].append(round(bleu_score * 100, 4))
334
+
335
+ for k, v in score_dict.items():
336
+ score_dict[k] = float(np.mean(v))
337
+ return score_dict
338
+
339
+ # Override the decoding parameters of Seq2SeqTrainer
340
+ training_args.generation_max_length = (
341
+ training_args.generation_max_length
342
+ if training_args.generation_max_length is not None
343
+ else data_args.val_max_target_length
344
+ )
345
+ training_args.generation_num_beams = (
346
+ data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
347
+ )
348
+ # Initialize our Trainer
349
+ trainer = Seq2SeqTrainer(
350
+ model=model,
351
+ args=training_args,
352
+ train_dataset=train_dataset if training_args.do_train else None,
353
+ eval_dataset=eval_dataset if training_args.do_eval else None,
354
+ tokenizer=tokenizer,
355
+ data_collator=data_collator,
356
+ compute_metrics=compute_metrics if training_args.predict_with_generate else None,
357
+ save_prefixencoder=model_args.pre_seq_len is not None
358
+ )
359
+
360
+ # Training
361
+ if training_args.do_train:
362
+ checkpoint = None
363
+ if training_args.resume_from_checkpoint is not None:
364
+ checkpoint = training_args.resume_from_checkpoint
365
+ # elif last_checkpoint is not None:
366
+ # checkpoint = last_checkpoint
367
+ model.gradient_checkpointing_enable()
368
+ model.enable_input_require_grads()
369
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
370
+ # trainer.save_model() # Saves the tokenizer too for easy upload
371
+
372
+ metrics = train_result.metrics
373
+ max_train_samples = (
374
+ data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
375
+ )
376
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
377
+
378
+ trainer.log_metrics("train", metrics)
379
+ trainer.save_metrics("train", metrics)
380
+ trainer.save_state()
381
+
382
+ # Evaluation
383
+ results = {}
384
+ max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
385
+ if training_args.do_eval:
386
+ logger.info("*** Evaluate ***")
387
+ metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
388
+ max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
389
+ metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
390
+
391
+ trainer.log_metrics("eval", metrics)
392
+ trainer.save_metrics("eval", metrics)
393
+
394
+ if training_args.do_predict:
395
+ logger.info("*** Predict ***")
396
+ predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
397
+ metrics = predict_results.metrics
398
+ max_predict_samples = (
399
+ data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
400
+ )
401
+ metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
402
+
403
+ trainer.log_metrics("predict", metrics)
404
+ trainer.save_metrics("predict", metrics)
405
+
406
+ if trainer.is_world_process_zero():
407
+ if training_args.predict_with_generate:
408
+ predictions = tokenizer.batch_decode(
409
+ predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
410
+ )
411
+ predictions = [pred.strip() for pred in predictions]
412
+ labels = tokenizer.batch_decode(
413
+ predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
414
+ )
415
+ labels = [label.strip() for label in labels]
416
+ output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
417
+ with open(output_prediction_file, "w", encoding="utf-8") as writer:
418
+ for p, l in zip(predictions, labels):
419
+ res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
420
+ writer.write(f"{res}\n")
421
+ return results
422
+
423
+
424
+ def _mp_fn(index):
425
+ # For xla_spawn (TPUs)
426
+ main()
427
+
428
+
429
+ if __name__ == "__main__":
430
+ main()
ptuning/train.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PRE_SEQ_LEN=128
2
+ LR=2e-2
3
+
4
+ CUDA_VISIBLE_DEVICES=0 python3 main.py \
5
+ --do_train \
6
+ --train_file AdvertiseGen/train.json \
7
+ --validation_file AdvertiseGen/dev.json \
8
+ --prompt_column content \
9
+ --response_column summary \
10
+ --overwrite_cache \
11
+ --model_name_or_path THUDM/chatglm-6b \
12
+ --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
13
+ --overwrite_output_dir \
14
+ --max_source_length 64 \
15
+ --max_target_length 64 \
16
+ --per_device_train_batch_size 1 \
17
+ --per_device_eval_batch_size 1 \
18
+ --gradient_accumulation_steps 16 \
19
+ --predict_with_generate \
20
+ --max_steps 3000 \
21
+ --logging_steps 10 \
22
+ --save_steps 1000 \
23
+ --learning_rate $LR \
24
+ --pre_seq_len $PRE_SEQ_LEN \
25
+ --quantization_bit 4
26
+
ptuning/train_chat.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PRE_SEQ_LEN=128
2
+ LR=1e-2
3
+
4
+ CUDA_VISIBLE_DEVICES=0 python3 main.py \
5
+ --do_train \
6
+ --train_file $CHAT_TRAIN_DATA \
7
+ --validation_file $CHAT_VAL_DATA \
8
+ --prompt_column prompt \
9
+ --response_column response \
10
+ --history_column history \
11
+ --overwrite_cache \
12
+ --model_name_or_path THUDM/chatglm-6b \
13
+ --output_dir $CHECKPOINT_NAME \
14
+ --overwrite_output_dir \
15
+ --max_source_length 256 \
16
+ --max_target_length 256 \
17
+ --per_device_train_batch_size 1 \
18
+ --per_device_eval_batch_size 1 \
19
+ --gradient_accumulation_steps 16 \
20
+ --predict_with_generate \
21
+ --max_steps 3000 \
22
+ --logging_steps 10 \
23
+ --save_steps 1000 \
24
+ --learning_rate $LR \
25
+ --pre_seq_len $PRE_SEQ_LEN \
26
+ --quantization_bit 4
27
+
ptuning/trainer.py ADDED
The diff for this file is too large to render. See raw diff
 
ptuning/trainer_seq2seq.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.utils.data import Dataset
20
+
21
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
22
+ from trainer import Trainer
23
+ from transformers.trainer_utils import PredictionOutput
24
+ from transformers.utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class Seq2SeqTrainer(Trainer):
31
+ def evaluate(
32
+ self,
33
+ eval_dataset: Optional[Dataset] = None,
34
+ ignore_keys: Optional[List[str]] = None,
35
+ metric_key_prefix: str = "eval",
36
+ **gen_kwargs
37
+ ) -> Dict[str, float]:
38
+ """
39
+ Run evaluation and returns metrics.
40
+
41
+ The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
42
+ (pass it to the init `compute_metrics` argument).
43
+
44
+ You can also subclass and override this method to inject custom behavior.
45
+
46
+ Args:
47
+ eval_dataset (`Dataset`, *optional*):
48
+ Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
49
+ not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
50
+ method.
51
+ ignore_keys (`List[str]`, *optional*):
52
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
53
+ gathering predictions.
54
+ metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
55
+ An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
56
+ "eval_bleu" if the prefix is `"eval"` (default)
57
+ max_length (`int`, *optional*):
58
+ The maximum target length to use when predicting with the generate method.
59
+ num_beams (`int`, *optional*):
60
+ Number of beams for beam search that will be used when predicting with the generate method. 1 means no
61
+ beam search.
62
+ gen_kwargs:
63
+ Additional `generate` specific kwargs.
64
+
65
+ Returns:
66
+ A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
67
+ dictionary also contains the epoch number which comes from the training state.
68
+ """
69
+
70
+ gen_kwargs = gen_kwargs.copy()
71
+ if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
72
+ gen_kwargs["max_length"] = self.args.generation_max_length
73
+ gen_kwargs["num_beams"] = (
74
+ gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
75
+ )
76
+ self._gen_kwargs = gen_kwargs
77
+
78
+ return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
79
+
80
+ def predict(
81
+ self,
82
+ test_dataset: Dataset,
83
+ ignore_keys: Optional[List[str]] = None,
84
+ metric_key_prefix: str = "test",
85
+ **gen_kwargs
86
+ ) -> PredictionOutput:
87
+ """
88
+ Run prediction and returns predictions and potential metrics.
89
+
90
+ Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
91
+ will also return metrics, like in `evaluate()`.
92
+
93
+ Args:
94
+ test_dataset (`Dataset`):
95
+ Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
96
+ `model.forward()` method are automatically removed. Has to implement the method `__len__`
97
+ ignore_keys (`List[str]`, *optional*):
98
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
99
+ gathering predictions.
100
+ metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
101
+ An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
102
+ "eval_bleu" if the prefix is `"eval"` (default)
103
+ max_length (`int`, *optional*):
104
+ The maximum target length to use when predicting with the generate method.
105
+ num_beams (`int`, *optional*):
106
+ Number of beams for beam search that will be used when predicting with the generate method. 1 means no
107
+ beam search.
108
+ gen_kwargs:
109
+ Additional `generate` specific kwargs.
110
+
111
+ <Tip>
112
+
113
+ If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
114
+ padding in a token classification task) the predictions will be padded (on the right) to allow for
115
+ concatenation into one array. The padding index is -100.
116
+
117
+ </Tip>
118
+
119
+ Returns: *NamedTuple* A namedtuple with the following keys:
120
+
121
+ - predictions (`np.ndarray`): The predictions on `test_dataset`.
122
+ - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
123
+ - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
124
+ labels).
125
+ """
126
+
127
+ gen_kwargs = gen_kwargs.copy()
128
+ if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
129
+ gen_kwargs["max_length"] = self.args.generation_max_length
130
+ gen_kwargs["num_beams"] = (
131
+ gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
132
+ )
133
+ self._gen_kwargs = gen_kwargs
134
+
135
+
136
+ return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
137
+
138
+ def prediction_step(
139
+ self,
140
+ model: nn.Module,
141
+ inputs: Dict[str, Union[torch.Tensor, Any]],
142
+ prediction_loss_only: bool,
143
+ ignore_keys: Optional[List[str]] = None,
144
+ ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
145
+ """
146
+ Perform an evaluation step on `model` using `inputs`.
147
+
148
+ Subclass and override to inject custom behavior.
149
+
150
+ Args:
151
+ model (`nn.Module`):
152
+ The model to evaluate.
153
+ inputs (`Dict[str, Union[torch.Tensor, Any]]`):
154
+ The inputs and targets of the model.
155
+
156
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
157
+ argument `labels`. Check your model's documentation for all accepted arguments.
158
+ prediction_loss_only (`bool`):
159
+ Whether or not to return the loss only.
160
+
161
+ Return:
162
+ Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
163
+ labels (each being optional).
164
+ """
165
+
166
+ if not self.args.predict_with_generate or prediction_loss_only:
167
+ return super().prediction_step(
168
+ model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
169
+ )
170
+
171
+ has_labels = "labels" in inputs
172
+ inputs = self._prepare_inputs(inputs)
173
+
174
+ # XXX: adapt synced_gpus for fairscale as well
175
+ gen_kwargs = self._gen_kwargs.copy()
176
+ if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
177
+ gen_kwargs["max_length"] = self.model.config.max_length
178
+ gen_kwargs["num_beams"] = (
179
+ gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
180
+ )
181
+ default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
182
+ gen_kwargs["synced_gpus"] = (
183
+ gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
184
+ )
185
+
186
+ if "attention_mask" in inputs:
187
+ gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
188
+ if "position_ids" in inputs:
189
+ gen_kwargs["position_ids"] = inputs.get("position_ids", None)
190
+ if "global_attention_mask" in inputs:
191
+ gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
192
+
193
+ # prepare generation inputs
194
+ # some encoder-decoder models can have varying encoder's and thus
195
+ # varying model input names
196
+ if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
197
+ generation_inputs = inputs[self.model.encoder.main_input_name]
198
+ else:
199
+ generation_inputs = inputs[self.model.main_input_name]
200
+
201
+ gen_kwargs["input_ids"] = generation_inputs
202
+ generated_tokens = self.model.generate(**gen_kwargs)
203
+ generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:]
204
+
205
+ # in case the batch is shorter than max length, the output should be padded
206
+ if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
207
+ generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
208
+ elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
209
+ gen_kwargs["max_new_tokens"] + 1
210
+ ):
211
+ generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
212
+
213
+ loss = None
214
+
215
+ if self.args.prediction_loss_only:
216
+ return (loss, None, None)
217
+
218
+ if has_labels:
219
+ labels = inputs["labels"]
220
+ if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
221
+ labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
222
+ elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
223
+ gen_kwargs["max_new_tokens"] + 1
224
+ ):
225
+ labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
226
+ else:
227
+ labels = None
228
+
229
+ return (loss, generated_tokens, labels)
230
+
231
+ def _pad_tensors_to_max_len(self, tensor, max_length):
232
+ if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
233
+ # If PAD token is not defined at least EOS token has to be defined
234
+ pad_token_id = (
235
+ self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
236
+ )
237
+ else:
238
+ if self.model.config.pad_token_id is not None:
239
+ pad_token_id = self.model.config.pad_token_id
240
+ else:
241
+ raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
242
+
243
+ padded_tensor = pad_token_id * torch.ones(
244
+ (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
245
+ )
246
+ padded_tensor[:, : tensor.shape[-1]] = tensor
247
+ return padded_tensor
ptuning/web_demo.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ import gradio as gr
4
+ import mdtex2html
5
+
6
+ import torch
7
+ import transformers
8
+ from transformers import (
9
+ AutoConfig,
10
+ AutoModel,
11
+ AutoTokenizer,
12
+ AutoTokenizer,
13
+ DataCollatorForSeq2Seq,
14
+ HfArgumentParser,
15
+ Seq2SeqTrainingArguments,
16
+ set_seed,
17
+ )
18
+
19
+ from arguments import ModelArguments, DataTrainingArguments
20
+
21
+
22
+ model = None
23
+ tokenizer = None
24
+
25
+ """Override Chatbot.postprocess"""
26
+
27
+
28
+ def postprocess(self, y):
29
+ if y is None:
30
+ return []
31
+ for i, (message, response) in enumerate(y):
32
+ y[i] = (
33
+ None if message is None else mdtex2html.convert((message)),
34
+ None if response is None else mdtex2html.convert(response),
35
+ )
36
+ return y
37
+
38
+
39
+ gr.Chatbot.postprocess = postprocess
40
+
41
+
42
+ def parse_text(text):
43
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
44
+ lines = text.split("\n")
45
+ lines = [line for line in lines if line != ""]
46
+ count = 0
47
+ for i, line in enumerate(lines):
48
+ if "```" in line:
49
+ count += 1
50
+ items = line.split('`')
51
+ if count % 2 == 1:
52
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
53
+ else:
54
+ lines[i] = f'<br></code></pre>'
55
+ else:
56
+ if i > 0:
57
+ if count % 2 == 1:
58
+ line = line.replace("`", "\`")
59
+ line = line.replace("<", "&lt;")
60
+ line = line.replace(">", "&gt;")
61
+ line = line.replace(" ", "&nbsp;")
62
+ line = line.replace("*", "&ast;")
63
+ line = line.replace("_", "&lowbar;")
64
+ line = line.replace("-", "&#45;")
65
+ line = line.replace(".", "&#46;")
66
+ line = line.replace("!", "&#33;")
67
+ line = line.replace("(", "&#40;")
68
+ line = line.replace(")", "&#41;")
69
+ line = line.replace("$", "&#36;")
70
+ lines[i] = "<br>"+line
71
+ text = "".join(lines)
72
+ return text
73
+
74
+
75
+ def predict(input, chatbot, max_length, top_p, temperature, history):
76
+ chatbot.append((parse_text(input), ""))
77
+ for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
78
+ temperature=temperature):
79
+ chatbot[-1] = (parse_text(input), parse_text(response))
80
+
81
+ yield chatbot, history
82
+
83
+
84
+ def reset_user_input():
85
+ return gr.update(value='')
86
+
87
+
88
+ def reset_state():
89
+ return [], []
90
+
91
+
92
+ with gr.Blocks() as demo:
93
+ gr.HTML("""<h1 align="center">ChatGLM</h1>""")
94
+
95
+ chatbot = gr.Chatbot()
96
+ with gr.Row():
97
+ with gr.Column(scale=4):
98
+ with gr.Column(scale=12):
99
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
100
+ container=False)
101
+ with gr.Column(min_width=32, scale=1):
102
+ submitBtn = gr.Button("Submit", variant="primary")
103
+ with gr.Column(scale=1):
104
+ emptyBtn = gr.Button("Clear History")
105
+ max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
106
+ top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
107
+ temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
108
+
109
+ history = gr.State([])
110
+
111
+ submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
112
+ show_progress=True)
113
+ submitBtn.click(reset_user_input, [], [user_input])
114
+
115
+ emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
116
+
117
+
118
+
119
+ def main():
120
+ global model, tokenizer
121
+
122
+ parser = HfArgumentParser((
123
+ ModelArguments))
124
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
125
+ # If we pass only one argument to the script and it's the path to a json file,
126
+ # let's parse it to get our arguments.
127
+ model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
128
+ else:
129
+ model_args = parser.parse_args_into_dataclasses()[0]
130
+
131
+ tokenizer = AutoTokenizer.from_pretrained(
132
+ model_args.model_name_or_path, trust_remote_code=True)
133
+ config = AutoConfig.from_pretrained(
134
+ model_args.model_name_or_path, trust_remote_code=True)
135
+
136
+ config.pre_seq_len = model_args.pre_seq_len
137
+ config.prefix_projection = model_args.prefix_projection
138
+
139
+ if model_args.ptuning_checkpoint is not None:
140
+ print(f"Loading prefix_encoder weight from {model_args.ptuning_checkpoint}")
141
+ model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
142
+ prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
143
+ new_prefix_state_dict = {}
144
+ for k, v in prefix_state_dict.items():
145
+ if k.startswith("transformer.prefix_encoder."):
146
+ new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
147
+ model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
148
+ else:
149
+ model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
150
+
151
+ if model_args.quantization_bit is not None:
152
+ print(f"Quantized to {model_args.quantization_bit} bit")
153
+ model = model.quantize(model_args.quantization_bit)
154
+
155
+ if model_args.pre_seq_len is not None:
156
+ # P-tuning v2
157
+ model = model.half().cuda()
158
+ model.transformer.prefix_encoder.float().cuda()
159
+
160
+ model = model.eval()
161
+ demo.queue().launch(share=False, inbrowser=True)
162
+
163
+
164
+
165
+ if __name__ == "__main__":
166
+ main()
ptuning/web_demo.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ PRE_SEQ_LEN=128
2
+
3
+ CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
4
+ --model_name_or_path THUDM/chatglm-6b \
5
+ --ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \
6
+ --pre_seq_len $PRE_SEQ_LEN
7
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ protobuf
2
+ transformers==4.27.1
3
+ cpm_kernels
4
+ torch>=1.10
5
+ gradio
6
+ mdtex2html
7
+ sentencepiece
8
+ accelerate
resources/WECHAT.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <img src=wechat.jpg width="60%"/>
3
+
4
+ <p> 扫码关注公众号,加入「ChatGLM交流群」 </p>
5
+ <p> Scan the QR code to follow the official account and join the "ChatGLM Discussion Group" </p>
6
+ </div>
7
+
resources/cli-demo.png ADDED