agent404 committed on
Commit fe781a6 · 1 Parent(s): f090575

upload files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. RepCodec/.gitignore +160 -0
  2. RepCodec/LICENSE +428 -0
  3. RepCodec/README.md +273 -0
  4. RepCodec/examples/data2vec_audio.py +541 -0
  5. RepCodec/examples/data2vec_feature_reader.py +87 -0
  6. RepCodec/examples/dump_feature.py +142 -0
  7. RepCodec/examples/feature_utils.py +70 -0
  8. RepCodec/examples/hubert_feature_reader.py +64 -0
  9. RepCodec/examples/tokens/data2vec_base_l6_dev-clean.tokens +0 -0
  10. RepCodec/examples/tokens/data2vec_large_l18_dev-clean.tokens +0 -0
  11. RepCodec/examples/tokens/hubert_base_l9_dev-clean.tokens +0 -0
  12. RepCodec/examples/tokens/hubert_large_l18_dev-clean.tokens +0 -0
  13. RepCodec/examples/tokens/whisper_large_l32_dev-clean.tokens +0 -0
  14. RepCodec/examples/tokens/whisper_medium_l24_dev-clean.tokens +0 -0
  15. RepCodec/examples/whisper_feature_reader.py +110 -0
  16. RepCodec/examples/whisper_model.py +58 -0
  17. RepCodec/repcodec/RepCodec.py +84 -0
  18. RepCodec/repcodec/configs/repcodec_dim1024.yaml +18 -0
  19. RepCodec/repcodec/configs/repcodec_dim1280.yaml +18 -0
  20. RepCodec/repcodec/configs/repcodec_dim768.yaml +18 -0
  21. RepCodec/repcodec/layers/conv_layer.py +95 -0
  22. RepCodec/repcodec/layers/vq_module.py +155 -0
  23. RepCodec/repcodec/modules/decoder.py +109 -0
  24. RepCodec/repcodec/modules/encoder.py +89 -0
  25. RepCodec/repcodec/modules/projector.py +32 -0
  26. RepCodec/repcodec/modules/quantizer.py +46 -0
  27. RepCodec/repcodec/modules/residual_unit.py +39 -0
  28. RepCodec/repcodec/tokenize.py +212 -0
  29. RepCodec/setup.py +31 -0
  30. RepCodec/train.py +228 -0
  31. RepCodec/train_configs/ex_dim768_mse.yaml +74 -0
  32. RepCodec/trainer/autoencoder.py +287 -0
  33. __pycache__/post_process_audio.cpython-310.pyc +0 -0
  34. __pycache__/vocoder.cpython-310.pyc +0 -0
  35. decoders/config.yaml +15 -0
  36. decoders/decoder_131000.pth +3 -0
  37. decoders/decoder_151000.pth +3 -0
  38. descriptaudiocodec/dac/__init__.py +16 -0
  39. descriptaudiocodec/dac/__main__.py +36 -0
  40. descriptaudiocodec/dac/__pycache__/__init__.cpython-310.pyc +0 -0
  41. descriptaudiocodec/dac/__pycache__/__init__.cpython-38.pyc +0 -0
  42. descriptaudiocodec/dac/__pycache__/__init__.cpython-39.pyc +0 -0
  43. descriptaudiocodec/dac/compare/__init__.py +0 -0
  44. descriptaudiocodec/dac/compare/encodec.py +54 -0
  45. descriptaudiocodec/dac/model/__init__.py +4 -0
  46. descriptaudiocodec/dac/model/__pycache__/__init__.cpython-310.pyc +0 -0
  47. descriptaudiocodec/dac/model/__pycache__/__init__.cpython-39.pyc +0 -0
  48. descriptaudiocodec/dac/model/__pycache__/base.cpython-310.pyc +0 -0
  49. descriptaudiocodec/dac/model/__pycache__/base.cpython-39.pyc +0 -0
  50. descriptaudiocodec/dac/model/__pycache__/dac.cpython-310.pyc +0 -0
RepCodec/.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ .idea/
RepCodec/LICENSE ADDED
@@ -0,0 +1,428 @@
1
+ MIT License
2
+
3
+ Copyright (c) ByteDance, Inc. and its affiliates.
4
+ Copyright (c) Chutong Meng
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
23
+
24
+
25
+
26
+
27
+
28
+
29
+
30
+ Attribution-NonCommercial 4.0 International
31
+
32
+ =======================================================================
33
+
34
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
35
+ does not provide legal services or legal advice. Distribution of
36
+ Creative Commons public licenses does not create a lawyer-client or
37
+ other relationship. Creative Commons makes its licenses and related
38
+ information available on an "as-is" basis. Creative Commons gives no
39
+ warranties regarding its licenses, any material licensed under their
40
+ terms and conditions, or any related information. Creative Commons
41
+ disclaims all liability for damages resulting from their use to the
42
+ fullest extent possible.
43
+
44
+ Using Creative Commons Public Licenses
45
+
46
+ Creative Commons public licenses provide a standard set of terms and
47
+ conditions that creators and other rights holders may use to share
48
+ original works of authorship and other material subject to copyright
49
+ and certain other rights specified in the public license below. The
50
+ following considerations are for informational purposes only, are not
51
+ exhaustive, and do not form part of our licenses.
52
+
53
+ Considerations for licensors: Our public licenses are
54
+ intended for use by those authorized to give the public
55
+ permission to use material in ways otherwise restricted by
56
+ copyright and certain other rights. Our licenses are
57
+ irrevocable. Licensors should read and understand the terms
58
+ and conditions of the license they choose before applying it.
59
+ Licensors should also secure all rights necessary before
60
+ applying our licenses so that the public can reuse the
61
+ material as expected. Licensors should clearly mark any
62
+ material not subject to the license. This includes other CC-
63
+ licensed material, or material used under an exception or
64
+ limitation to copyright. More considerations for licensors:
65
+ wiki.creativecommons.org/Considerations_for_licensors
66
+
67
+ Considerations for the public: By using one of our public
68
+ licenses, a licensor grants the public permission to use the
69
+ licensed material under specified terms and conditions. If
70
+ the licensor's permission is not necessary for any reason--for
71
+ example, because of any applicable exception or limitation to
72
+ copyright--then that use is not regulated by the license. Our
73
+ licenses grant only permissions under copyright and certain
74
+ other rights that a licensor has authority to grant. Use of
75
+ the licensed material may still be restricted for other
76
+ reasons, including because others have copyright or other
77
+ rights in the material. A licensor may make special requests,
78
+ such as asking that all changes be marked or described.
79
+ Although not required by our licenses, you are encouraged to
80
+ respect those requests where reasonable. More_considerations
81
+ for the public:
82
+ wiki.creativecommons.org/Considerations_for_licensees
83
+
84
+ =======================================================================
85
+
86
+ Creative Commons Attribution-NonCommercial 4.0 International Public
87
+ License
88
+
89
+ By exercising the Licensed Rights (defined below), You accept and agree
90
+ to be bound by the terms and conditions of this Creative Commons
91
+ Attribution-NonCommercial 4.0 International Public License ("Public
92
+ License"). To the extent this Public License may be interpreted as a
93
+ contract, You are granted the Licensed Rights in consideration of Your
94
+ acceptance of these terms and conditions, and the Licensor grants You
95
+ such rights in consideration of benefits the Licensor receives from
96
+ making the Licensed Material available under these terms and
97
+ conditions.
98
+
99
+ Section 1 -- Definitions.
100
+
101
+ a. Adapted Material means material subject to Copyright and Similar
102
+ Rights that is derived from or based upon the Licensed Material
103
+ and in which the Licensed Material is translated, altered,
104
+ arranged, transformed, or otherwise modified in a manner requiring
105
+ permission under the Copyright and Similar Rights held by the
106
+ Licensor. For purposes of this Public License, where the Licensed
107
+ Material is a musical work, performance, or sound recording,
108
+ Adapted Material is always produced where the Licensed Material is
109
+ synched in timed relation with a moving image.
110
+
111
+ b. Adapter's License means the license You apply to Your Copyright
112
+ and Similar Rights in Your contributions to Adapted Material in
113
+ accordance with the terms and conditions of this Public License.
114
+
115
+ c. Copyright and Similar Rights means copyright and/or similar rights
116
+ closely related to copyright including, without limitation,
117
+ performance, broadcast, sound recording, and Sui Generis Database
118
+ Rights, without regard to how the rights are labeled or
119
+ categorized. For purposes of this Public License, the rights
120
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
121
+ Rights.
122
+ d. Effective Technological Measures means those measures that, in the
123
+ absence of proper authority, may not be circumvented under laws
124
+ fulfilling obligations under Article 11 of the WIPO Copyright
125
+ Treaty adopted on December 20, 1996, and/or similar international
126
+ agreements.
127
+
128
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
129
+ any other exception or limitation to Copyright and Similar Rights
130
+ that applies to Your use of the Licensed Material.
131
+
132
+ f. Licensed Material means the artistic or literary work, database,
133
+ or other material to which the Licensor applied this Public
134
+ License.
135
+
136
+ g. Licensed Rights means the rights granted to You subject to the
137
+ terms and conditions of this Public License, which are limited to
138
+ all Copyright and Similar Rights that apply to Your use of the
139
+ Licensed Material and that the Licensor has authority to license.
140
+
141
+ h. Licensor means the individual(s) or entity(ies) granting rights
142
+ under this Public License.
143
+
144
+ i. NonCommercial means not primarily intended for or directed towards
145
+ commercial advantage or monetary compensation. For purposes of
146
+ this Public License, the exchange of the Licensed Material for
147
+ other material subject to Copyright and Similar Rights by digital
148
+ file-sharing or similar means is NonCommercial provided there is
149
+ no payment of monetary compensation in connection with the
150
+ exchange.
151
+
152
+ j. Share means to provide material to the public by any means or
153
+ process that requires permission under the Licensed Rights, such
154
+ as reproduction, public display, public performance, distribution,
155
+ dissemination, communication, or importation, and to make material
156
+ available to the public including in ways that members of the
157
+ public may access the material from a place and at a time
158
+ individually chosen by them.
159
+
160
+ k. Sui Generis Database Rights means rights other than copyright
161
+ resulting from Directive 96/9/EC of the European Parliament and of
162
+ the Council of 11 March 1996 on the legal protection of databases,
163
+ as amended and/or succeeded, as well as other essentially
164
+ equivalent rights anywhere in the world.
165
+
166
+ l. You means the individual or entity exercising the Licensed Rights
167
+ under this Public License. Your has a corresponding meaning.
168
+
169
+ Section 2 -- Scope.
170
+
171
+ a. License grant.
172
+
173
+ 1. Subject to the terms and conditions of this Public License,
174
+ the Licensor hereby grants You a worldwide, royalty-free,
175
+ non-sublicensable, non-exclusive, irrevocable license to
176
+ exercise the Licensed Rights in the Licensed Material to:
177
+
178
+ a. reproduce and Share the Licensed Material, in whole or
179
+ in part, for NonCommercial purposes only; and
180
+
181
+ b. produce, reproduce, and Share Adapted Material for
182
+ NonCommercial purposes only.
183
+
184
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
185
+ Exceptions and Limitations apply to Your use, this Public
186
+ License does not apply, and You do not need to comply with
187
+ its terms and conditions.
188
+
189
+ 3. Term. The term of this Public License is specified in Section
190
+ 6(a).
191
+
192
+ 4. Media and formats; technical modifications allowed. The
193
+ Licensor authorizes You to exercise the Licensed Rights in
194
+ all media and formats whether now known or hereafter created,
195
+ and to make technical modifications necessary to do so. The
196
+ Licensor waives and/or agrees not to assert any right or
197
+ authority to forbid You from making technical modifications
198
+ necessary to exercise the Licensed Rights, including
199
+ technical modifications necessary to circumvent Effective
200
+ Technological Measures. For purposes of this Public License,
201
+ simply making modifications authorized by this Section 2(a)
202
+ (4) never produces Adapted Material.
203
+
204
+ 5. Downstream recipients.
205
+
206
+ a. Offer from the Licensor -- Licensed Material. Every
207
+ recipient of the Licensed Material automatically
208
+ receives an offer from the Licensor to exercise the
209
+ Licensed Rights under the terms and conditions of this
210
+ Public License.
211
+
212
+ b. No downstream restrictions. You may not offer or impose
213
+ any additional or different terms or conditions on, or
214
+ apply any Effective Technological Measures to, the
215
+ Licensed Material if doing so restricts exercise of the
216
+ Licensed Rights by any recipient of the Licensed
217
+ Material.
218
+
219
+ 6. No endorsement. Nothing in this Public License constitutes or
220
+ may be construed as permission to assert or imply that You
221
+ are, or that Your use of the Licensed Material is, connected
222
+ with, or sponsored, endorsed, or granted official status by,
223
+ the Licensor or others designated to receive attribution as
224
+ provided in Section 3(a)(1)(A)(i).
225
+
226
+ b. Other rights.
227
+
228
+ 1. Moral rights, such as the right of integrity, are not
229
+ licensed under this Public License, nor are publicity,
230
+ privacy, and/or other similar personality rights; however, to
231
+ the extent possible, the Licensor waives and/or agrees not to
232
+ assert any such rights held by the Licensor to the limited
233
+ extent necessary to allow You to exercise the Licensed
234
+ Rights, but not otherwise.
235
+
236
+ 2. Patent and trademark rights are not licensed under this
237
+ Public License.
238
+
239
+ 3. To the extent possible, the Licensor waives any right to
240
+ collect royalties from You for the exercise of the Licensed
241
+ Rights, whether directly or through a collecting society
242
+ under any voluntary or waivable statutory or compulsory
243
+ licensing scheme. In all other cases the Licensor expressly
244
+ reserves any right to collect such royalties, including when
245
+ the Licensed Material is used other than for NonCommercial
246
+ purposes.
247
+
248
+ Section 3 -- License Conditions.
249
+
250
+ Your exercise of the Licensed Rights is expressly made subject to the
251
+ following conditions.
252
+
253
+ a. Attribution.
254
+
255
+ 1. If You Share the Licensed Material (including in modified
256
+ form), You must:
257
+
258
+ a. retain the following if it is supplied by the Licensor
259
+ with the Licensed Material:
260
+
261
+ i. identification of the creator(s) of the Licensed
262
+ Material and any others designated to receive
263
+ attribution, in any reasonable manner requested by
264
+ the Licensor (including by pseudonym if
265
+ designated);
266
+
267
+ ii. a copyright notice;
268
+
269
+ iii. a notice that refers to this Public License;
270
+
271
+ iv. a notice that refers to the disclaimer of
272
+ warranties;
273
+
274
+ v. a URI or hyperlink to the Licensed Material to the
275
+ extent reasonably practicable;
276
+
277
+ b. indicate if You modified the Licensed Material and
278
+ retain an indication of any previous modifications; and
279
+
280
+ c. indicate the Licensed Material is licensed under this
281
+ Public License, and include the text of, or the URI or
282
+ hyperlink to, this Public License.
283
+
284
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
285
+ reasonable manner based on the medium, means, and context in
286
+ which You Share the Licensed Material. For example, it may be
287
+ reasonable to satisfy the conditions by providing a URI or
288
+ hyperlink to a resource that includes the required
289
+ information.
290
+
291
+ 3. If requested by the Licensor, You must remove any of the
292
+ information required by Section 3(a)(1)(A) to the extent
293
+ reasonably practicable.
294
+
295
+ 4. If You Share Adapted Material You produce, the Adapter's
296
+ License You apply must not prevent recipients of the Adapted
297
+ Material from complying with this Public License.
298
+
299
+ Section 4 -- Sui Generis Database Rights.
300
+
301
+ Where the Licensed Rights include Sui Generis Database Rights that
302
+ apply to Your use of the Licensed Material:
303
+
304
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
305
+ to extract, reuse, reproduce, and Share all or a substantial
306
+ portion of the contents of the database for NonCommercial purposes
307
+ only;
308
+
309
+ b. if You include all or a substantial portion of the database
310
+ contents in a database in which You have Sui Generis Database
311
+ Rights, then the database in which You have Sui Generis Database
312
+ Rights (but not its individual contents) is Adapted Material; and
313
+
314
+ c. You must comply with the conditions in Section 3(a) if You Share
315
+ all or a substantial portion of the contents of the database.
316
+
317
+ For the avoidance of doubt, this Section 4 supplements and does not
318
+ replace Your obligations under this Public License where the Licensed
319
+ Rights include other Copyright and Similar Rights.
320
+
321
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
322
+
323
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
324
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
325
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
326
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
327
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
328
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
329
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
330
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
331
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
332
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
333
+
334
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
335
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
336
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
337
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
338
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
339
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
340
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
341
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
342
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
343
+
344
+ c. The disclaimer of warranties and limitation of liability provided
345
+ above shall be interpreted in a manner that, to the extent
346
+ possible, most closely approximates an absolute disclaimer and
347
+ waiver of all liability.
348
+
349
+ Section 6 -- Term and Termination.
350
+
351
+ a. This Public License applies for the term of the Copyright and
352
+ Similar Rights licensed here. However, if You fail to comply with
353
+ this Public License, then Your rights under this Public License
354
+ terminate automatically.
355
+
356
+ b. Where Your right to use the Licensed Material has terminated under
357
+ Section 6(a), it reinstates:
358
+
359
+ 1. automatically as of the date the violation is cured, provided
360
+ it is cured within 30 days of Your discovery of the
361
+ violation; or
362
+
363
+ 2. upon express reinstatement by the Licensor.
364
+
365
+ For the avoidance of doubt, this Section 6(b) does not affect any
366
+ right the Licensor may have to seek remedies for Your violations
367
+ of this Public License.
368
+
369
+ c. For the avoidance of doubt, the Licensor may also offer the
370
+ Licensed Material under separate terms or conditions or stop
371
+ distributing the Licensed Material at any time; however, doing so
372
+ will not terminate this Public License.
373
+
374
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
375
+ License.
376
+
377
+ Section 7 -- Other Terms and Conditions.
378
+
379
+ a. The Licensor shall not be bound by any additional or different
380
+ terms or conditions communicated by You unless expressly agreed.
381
+
382
+ b. Any arrangements, understandings, or agreements regarding the
383
+ Licensed Material not stated herein are separate from and
384
+ independent of the terms and conditions of this Public License.
385
+
386
+ Section 8 -- Interpretation.
387
+
388
+ a. For the avoidance of doubt, this Public License does not, and
389
+ shall not be interpreted to, reduce, limit, restrict, or impose
390
+ conditions on any use of the Licensed Material that could lawfully
391
+ be made without permission under this Public License.
392
+
393
+ b. To the extent possible, if any provision of this Public License is
394
+ deemed unenforceable, it shall be automatically reformed to the
395
+ minimum extent necessary to make it enforceable. If the provision
396
+ cannot be reformed, it shall be severed from this Public License
397
+ without affecting the enforceability of the remaining terms and
398
+ conditions.
399
+
400
+ c. No term or condition of this Public License will be waived and no
401
+ failure to comply consented to unless expressly agreed to by the
402
+ Licensor.
403
+
404
+ d. Nothing in this Public License constitutes or may be interpreted
405
+ as a limitation upon, or waiver of, any privileges and immunities
406
+ that apply to the Licensor or You, including from the legal
407
+ processes of any jurisdiction or authority.
408
+
409
+ =======================================================================
410
+
411
+ Creative Commons is not a party to its public
412
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
413
+ its public licenses to material it publishes and in those instances
414
+ will be considered the “Licensor.” The text of the Creative Commons
415
+ public licenses is dedicated to the public domain under the CC0 Public
416
+ Domain Dedication. Except for the limited purpose of indicating that
417
+ material is shared under a Creative Commons public license or as
418
+ otherwise permitted by the Creative Commons policies published at
419
+ creativecommons.org/policies, Creative Commons does not authorize the
420
+ use of the trademark "Creative Commons" or any other trademark or logo
421
+ of Creative Commons without its prior written consent including,
422
+ without limitation, in connection with any unauthorized modifications
423
+ to any of its public licenses or any other arrangements,
424
+ understandings, or agreements concerning use of licensed material. For
425
+ the avoidance of doubt, this paragraph does not form part of the
426
+ public licenses.
427
+
428
+ Creative Commons may be contacted at creativecommons.org.
RepCodec/README.md ADDED
@@ -0,0 +1,273 @@
1
+ # RepCodec: A Speech Representation Codec for Speech Tokenization
2
+
3
+ > [**RepCodec: A Speech Representation Codec for Speech Tokenization**](https://arxiv.org/abs/2309.00169)
4
+
5
+ ## Introduction
6
+
7
+ **RepCodec** is a speech tokenization method for converting a speech waveform into a sequence of discrete semantic
8
+ tokens.
9
+ The main idea is to train a representation codec which learns a vector quantization codebook through reconstructing the
10
+ input speech representations from speech encoders like HuBERT or data2vec.
11
+ Extensive experiments show that RepCodec significantly outperforms the widely used k-means clustering approach in both
12
+ speech understanding and generation.
13
+ Also, RepCodec generalizes well across various speech encoders and languages.
14
+
15
+ <img src="images/RepCodec.png" alt="se" width="1000" />
16
+
17
+ ## RepCodec Models
18
+
19
+ | Feature Type | Speech Data | RepCodec Model |
20
+ |-----------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------|----------------------------------------------------------------------------------------------------------|
21
+ | [HuBERT base](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models) layer 9 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [hubert_base_l9](https://drive.google.com/file/d/1XD0HKl607FFjri2-VJT7lHQeSpxsCCFO/view?usp=sharing) |
22
+ | [HuBERT large](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models) layer 18 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [hubert_large_l18](https://drive.google.com/file/d/1mTbm5GeJ7gp_5L3QLP-JGXdf8RnRw5n6/view?usp=sharing) |
23
+ | [data2vec base](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2) layer 6 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [data2vec_base_l6](https://drive.google.com/file/d/1d8sf3Ko_fYM9zlaiwxK_4xusLRKV5EMd/view?usp=sharing) |
24
+ | [data2vec large](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2) layer 18 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [data2vec_large_l18](https://drive.google.com/file/d/1nuRIHaejT-uVi4cluftbT8o_JZqar5SU/view?usp=sharing) |
25
+ | [Whisper medium](https://github.com/openai/whisper/tree/main#available-models-and-languages) layer 24 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [whisper_medium_l24](https://drive.google.com/file/d/1V6YJSA2V4iywXrecJAN0oqsa3aHowexZ/view?usp=sharing) |
26
+ | [Whisper large-v2](https://github.com/openai/whisper/tree/main#available-models-and-languages) layer 32 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [whisper_large_l32](https://drive.google.com/file/d/1k_X7ZMPg8iOeDrIJe70v6CHfFygzufXC/view?usp=sharing) |
27
+
28
+ ## Speech Tokenization Using Pre-Trained Models
29
+
30
+ ### Installation
31
+
32
+ Please first install RepCodec by
33
+
34
+ ```
35
+ git clone https://github.com/mct10/RepCodec.git
36
+ cd RepCodec
37
+ pip install .
38
+ ```
39
+
40
+ We used Python 3.9.18 and PyTorch 1.12.1 to test the usage, but the code should be compatible with other recent Python
41
+ and PyTorch versions.
42
+
43
+ ### Representation Preparation
44
+
45
+ We adapt the `dump_hubert_feature.py` script
46
+ from [fairseq](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert/simple_kmeans#hubert-feature)
47
+ to support dumping representations from **data2vec**, **HuBERT**, or **Whisper** encoders.
48
+
49
+ If you use our script (`examples/dump_feature.py`), please also install the following packages:
50
+
51
+ ```
52
+ pip install npy_append_array soundfile
53
+ ```
54
+
55
+ Additionally, if you want to dump representations from
56
+
57
+ - **data2vec** or **HuBERT**: please
58
+ follow [fairseq's instructions](https://github.com/facebookresearch/fairseq#requirements-and-installation) to install
59
+ the latest fairseq.
60
+
61
+ - **Whisper**: please follow [Whisper's instructions](https://github.com/openai/whisper/tree/main#setup) to install the
62
+ latest
63
+ Whisper.
64
+
65
+ Then, you can follow the given examples to dump representations:
66
+
67
+ ```
68
+ # Example 1: dump from HuBERT base layer 9
69
+ # (for data2vec, simply change "model_type" to data2vec and "ckpt_path" to the path of data2vec model)
70
+
71
+ layer=9
72
+
73
+ python3 examples/dump_feature.py \
74
+ --model_type hubert \
75
+ --tsv_path /path/to/tsv/file \
76
+ --ckpt_path /path/to/HuBERT/model \
77
+ --layer ${layer} \
78
+ --feat_dir /dir/to/save/representations
79
+
80
+
81
+ # Example 2: dump from Whisper medium layer 24
82
+
83
+ layer=24
84
+
85
+ python3 examples/dump_feature.py \
86
+ --model_type whisper \
87
+ --tsv_path /path/to/tsv/file \
88
+ --whisper_root /directory/to/save/whisper/model \
89
+ --whisper_name medium \
90
+ --layer ${layer} \
91
+ --feat_dir /dir/to/save/representations
92
+ ```
93
+
94
+ Explanations about the args:
95
+
96
+ - **model_type:** choose from `data2vec`, `hubert`, and `whisper`.
97
+
98
+ - **tsv_path:** path of the tsv file.
99
+ It should have the following format:
100
+
101
+ ```
102
+ /dir/to/dataset
103
+ path_of_utterance_1 number_of_frames
104
+ path_of_utterance_2 number_of_frames
105
+ ```
106
+
107
+ You can follow [this script](https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/wav2vec_manifest.py)
108
+ to generate the tsv file.
109
+
110
+ For example, by running
111
+
112
+ ```
113
+ python wav2vec_manifest.py \
114
+ /dir/to/LibriSpeech/dev-clean \
115
+ --dest /dir/to/manifest \
116
+ --ext flac \
117
+ --valid-percent 0
118
+ ```
119
+
120
+ you can obtain the `dev-clean.tsv` in `/dir/to/manifest` for LibriSpeech. (By default, the output file name
121
+ is `train.tsv`. Remember to rename the file.)
122
+
123
+ It should be similar to:
124
+
125
+ ```
126
+ /dir/to/LibriSpeech/dev-clean
127
+ 2277/149896/2277-149896-0026.flac 78720
128
+ 2277/149896/2277-149896-0005.flac 89600
129
+ 2277/149896/2277-149896-0033.flac 45520
130
+ ```
131
+
132
+ - **ckpt_path**:
133
+ required for data2vec and HuBERT.
134
+ You need to download the model
135
+ from [data2vec website](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2)
136
+ or [HuBERT website](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models)
137
+ yourself.
138
+ `--ckpt_path` is the path of the data2vec/HuBERT model.
139
+ - **whisper_root** and **whisper_name**:
140
+ you must provide **BOTH** `--whisper_root` and `--whisper_name` for Whisper.
141
+ If there is no corresponding model in `--whisper_root`, the script will download it for you.
142
+
143
+ - **layer**:
144
+ the Transformer encoder layer from which the representations are extracted.
145
+ It is **1-based**.
146
+ For example, if layer=9, then the outputs from the 9<sup>th</sup> Transformer encoder layer are dumped.
147
+ Range: [1, number of Transformer encoder layers]
148
+
149
+ - **feat_dir**: The output representations will be saved to `${feat_dir}/0_1.npy`
150
+ and `${feat_dir}/0_1.len` (see the loading sketch below).
151
+
152
+ For other useful functionalities (e.g., sharding), please check the argument list in `examples/dump_feature.py`.
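+
+ If you want to load the dumped representations back into Python (for inspection or custom processing), the minimal sketch below may help. It assumes the fairseq-style dump layout produced by `examples/dump_feature.py`: `0_1.npy` holds the frames of all utterances concatenated along the time axis, and `0_1.len` holds one frame count per utterance (one integer per line, in tsv order). The path is a placeholder.
+
+ ```python
+ import numpy as np
+
+ feat_dir = "/dir/to/save/representations"  # placeholder path
+
+ # all utterances concatenated along time: shape (total_frames, hidden_dim)
+ feats = np.load(f"{feat_dir}/0_1.npy", mmap_mode="r")
+ with open(f"{feat_dir}/0_1.len") as f:
+     lengths = [int(line) for line in f]
+
+ # frame offsets of each utterance, in the same order as the tsv file
+ offsets = np.cumsum([0] + lengths)
+ i = 0  # index of the utterance to inspect
+ utt_feat = feats[offsets[i]:offsets[i + 1]]  # shape (lengths[i], hidden_dim)
+ print(utt_feat.shape)
+ ```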
153
+
154
+ ### Command Line Usage
155
+
156
+ We expect to have `${feat_dir}/0_1.npy` and `${feat_dir}/0_1.len` in the provided
157
+ directory `/dir/to/representations`.
158
+
159
+ Also, the tsv file should be the **same** as the one used in [Representation Preparation](#representation-preparation).
160
+
161
+ ```
162
+ repcodec /dir/to/representations \
163
+ --model /path/to/repcodec/model \
164
+ --tsv_path /path/to/tsv/file \
165
+ [--model_config_path /path/to/train/config] \
166
+ [--use_gpu] \
167
+ [--out_dir /path/to/output]
168
+ ```
169
+
170
+ If you trained the model yourself following [Training New RepCodec Models](#training-new-repcodec-models),
171
+ please provide the training config file using `--model_config_path`.
172
+ If you use the model we provide [here](#repcodec-models), then you do not have to provide that.
173
+
174
+ This command will tokenize the representations and the output discrete tokens will be saved to `${out_dir}/tokens`.
175
+ The tokens are in the same order as the provided tsv file.
176
+
177
+ An example of the output file:
178
+
179
+ ```
180
+ /dir/to/LibriSpeech/dev-clean
181
+ 2277/149896/2277-149896-0026.flac 696 696 198 198 198 498 ...
182
+ 2277/149896/2277-149896-0005.flac 696 696 198 198 198 907 ...
183
+ 2277/149896/2277-149896-0033.flac 696 696 198 198 198 696 ...
184
+ ```
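+
+ If you need the tokens programmatically, a minimal parsing sketch is given below. It assumes the layout shown above (first line: dataset root; each following line: relative audio path followed by space-separated token IDs); the file path is a placeholder.
+
+ ```python
+ token_file = "/path/to/out_dir/tokens"  # placeholder path
+
+ utt2tokens = {}
+ with open(token_file) as f:
+     root = f.readline().strip()  # e.g. /dir/to/LibriSpeech/dev-clean
+     for line in f:
+         path, *toks = line.split()
+         utt2tokens[path] = [int(t) for t in toks]
+
+ print(root, len(utt2tokens))
+ ```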
185
+
186
+ Under `examples/tokens`, we provide some token files for reference. They are obtained from the LibriSpeech dev-clean subset
187
+ using the 6 types of representations and corresponding [RepCodec Models](#repcodec-models).
188
+ Your results should be very similar to ours.
189
+
190
+ ### Python Usage
191
+
192
+ ```python
193
+ import torch
194
+ import yaml
195
+
196
+ from repcodec.RepCodec import RepCodec
197
+
198
+ # for feature types of HuBERT base & data2vec base, please use repcodec_dim768.yaml;
199
+ # for feature types of HuBERT large & data2vec large & Whisper medium, please use repcodec_dim1024.yaml;
200
+ # for feature types of Whisper large-v2, please use repcodec_dim1280.yaml
201
+ config = "repcodec/configs/repcodec_dim768.yaml"
202
+ with open(config) as fp:
203
+ conf = yaml.load(fp, Loader=yaml.FullLoader)
204
+
205
+ model = RepCodec(**conf)
206
+ model.load_state_dict(torch.load("./hubert_base_l9.pkl", map_location="cpu")["model"]["repcodec"])
207
+ model.quantizer.initial()
208
+ model.eval()
209
+
210
+ # input shape: (batch size, hidden dim, sequence length)
211
+ random_features = torch.randn(size=(1, 768, 100))
212
+ with torch.no_grad():
213
+ x = model.encoder(random_features)
214
+ z = model.projector(x)
215
+ _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
216
+ tokens = idx.cpu().data.numpy().tolist()[0]
217
+ ```
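+
+ To tokenize real dumped representations instead of random features, note that the model expects inputs of shape (batch size, hidden dim, sequence length), whereas each dumped utterance is (sequence length, hidden dim). The sketch below is an assumption-laden example rather than part of the official API: it reuses `model` (and `torch`) from the snippet above, and stands in a random array for `utt_feat`, which in practice would be one utterance's slice of `0_1.npy` (see [Representation Preparation](#representation-preparation)).
+
+ ```python
+ import numpy as np
+
+ # placeholder for one utterance's features, e.g. a slice of 0_1.npy: (T, 768)
+ utt_feat = np.random.randn(100, 768).astype(np.float32)
+
+ x_in = torch.from_numpy(utt_feat).T.unsqueeze(0)  # -> (1, 768, T)
+ with torch.no_grad():
+     x = model.encoder(x_in)
+     z = model.projector(x)
+     _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
+     tokens = idx.cpu().data.numpy().tolist()[0]  # one token id per frame
+ ```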
218
+
219
+ ## Training New RepCodec Models
220
+
221
+ We use a config file to set up all the training configurations, e.g., data, model architecture,
222
+ optimizer, and scheduler.
223
+ We provide an example [here](./train_configs/ex_dim768_mse.yaml).
224
+
225
+ Please first install required packages following [Installation](#installation)
226
+ and prepare the representations following [Representation Preparation](#representation-preparation).
227
+
228
+ The input data directory is expected to have the following structure
229
+ ```
230
+ /dir/to/representations/
231
+ train_set_name/
232
+ 0_1.npy
233
+ 0_1.len
234
+ valid_set_name/
235
+ 0_1.npy
236
+ 0_1.len
237
+ test_set_name/
238
+ 0_1.npy
239
+ 0_1.len
240
+ ```
241
+
242
+ The names of subsets should be the same as the fields in the config file.
243
+
244
+ Then, you can run training by
245
+ ```
246
+ python train.py \
247
+ -c /path/to/config/file \
248
+ --tag $tag \
249
+ --exp_root exp
250
+ ```
251
+
252
+ `tag` is the name of the output folder.
253
+ All outputs will be saved to `exp_root/tag/`.
254
+
255
+ ## Acknowledgements
256
+
257
+ Our implementation is based on [facebookresearch/AudioDec](https://github.com/facebookresearch/AudioDec).
258
+ We thank them for open-sourcing their code!
259
+
260
+ ## Citation
261
+
262
+ If you find our work useful, please cite the following article.
263
+
264
+ ```
265
+ @misc{huang2023repcodec,
266
+ title={RepCodec: A Speech Representation Codec for Speech Tokenization},
267
+ author={Zhichao Huang and Chutong Meng and Tom Ko},
268
+ year={2023},
269
+ eprint={2309.00169},
270
+ archivePrefix={arXiv},
271
+ primaryClass={eess.AS}
272
+ }
273
+ ```
RepCodec/examples/data2vec_audio.py ADDED
@@ -0,0 +1,541 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ # ref: https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py
9
+
10
+ import logging
11
+ import math
12
+ from dataclasses import dataclass, field
13
+ from typing import Optional
14
+
15
+ from omegaconf import II
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import torch.distributed as dist
21
+
22
+ from fairseq.modules import EMAModule, EMAModuleConfig
23
+ from fairseq.data.data_utils import compute_mask_indices
24
+ from fairseq.models import BaseFairseqModel, register_model
25
+ from fairseq.models.wav2vec import (
26
+ ConvFeatureExtractionModel,
27
+ Wav2Vec2Config,
28
+ TransformerEncoder,
29
+ )
30
+ from fairseq.modules import (
31
+ GradMultiply,
32
+ LayerNorm,
33
+ )
34
+ from fairseq.utils import index_put
35
+
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ @dataclass
41
+ class Data2VecAudioConfig(Wav2Vec2Config):
42
+
43
+ loss_beta: float = field(
44
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
45
+ )
46
+ loss_scale: Optional[float] = field(
47
+ default=None,
48
+ metadata={
49
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
50
+ },
51
+ )
52
+ average_top_k_layers: int = field(
53
+ default=8, metadata={"help": "how many layers to average"}
54
+ )
55
+
56
+ layer_norm_target_layer: bool = False
57
+ instance_norm_target_layer: bool = False
58
+ instance_norm_targets: bool = False
59
+ layer_norm_targets: bool = False
60
+ batch_norm_target_layer: bool = False
61
+ group_norm_target_layer: bool = False
62
+
63
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
64
+ ema_end_decay: float = field(
65
+ default=0.9999, metadata={"help": "final ema decay rate"}
66
+ )
67
+
68
+ # when to finish annealing ema decay rate
69
+ ema_anneal_end_step: int = II("optimization.max_update")
70
+
71
+ ema_transformer_only: bool = field(
72
+ default=True,
73
+ metadata={"help": "whether to momentum update only the transformer"},
74
+ )
75
+ ema_layers_only: bool = field(
76
+ default=True,
77
+ metadata={"help": "whether to momentum update only the transformer layers"},
78
+ )
79
+
80
+ max_update: int = II("optimization.max_update")
81
+
82
+ min_target_var: float = field(
83
+ default=0.1, metadata={"help": "stop training if target var falls below this"}
84
+ )
85
+ min_pred_var: float = field(
86
+ default=0.01,
87
+ metadata={"help": "stop training if prediction var falls below this"},
88
+ )
89
+
90
+
91
+ def get_annealed_rate(start, end, curr_step, total_steps):
92
+ r = end - start
93
+ pct_remaining = 1 - curr_step / total_steps
94
+ return end - r * pct_remaining
95
+
96
+
97
+ @register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
98
+ class Data2VecAudioModel(BaseFairseqModel):
99
+ def __init__(self, cfg: Data2VecAudioConfig):
100
+ super().__init__()
101
+ self.cfg = cfg
102
+
103
+ feature_enc_layers = eval(cfg.conv_feature_layers)
104
+ self.extractor_embed = feature_enc_layers[-1][0]
105
+
106
+ self.ema = None
107
+ self.embed = cfg.encoder_embed_dim
108
+
109
+ self.average_top_k_layers = cfg.average_top_k_layers
110
+ self.loss_beta = cfg.loss_beta
111
+ self.loss_scale = cfg.loss_scale
112
+
113
+ self.feature_extractor = ConvFeatureExtractionModel(
114
+ conv_layers=feature_enc_layers,
115
+ dropout=0.0,
116
+ mode=cfg.extractor_mode,
117
+ conv_bias=cfg.conv_bias,
118
+ )
119
+
120
+ self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
121
+
122
+ self.mask_prob = cfg.mask_prob
123
+ self.mask_selection = cfg.mask_selection
124
+ self.mask_other = cfg.mask_other
125
+ self.mask_length = cfg.mask_length
126
+ self.no_mask_overlap = cfg.no_mask_overlap
127
+ self.mask_min_space = cfg.mask_min_space
128
+
129
+ self.mask_channel_prob = cfg.mask_channel_prob
130
+ self.mask_channel_before = cfg.mask_channel_before
131
+ self.mask_channel_selection = cfg.mask_channel_selection
132
+ self.mask_channel_other = cfg.mask_channel_other
133
+ self.mask_channel_length = cfg.mask_channel_length
134
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
135
+ self.mask_channel_min_space = cfg.mask_channel_min_space
136
+
137
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
138
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
139
+
140
+ self.feature_grad_mult = cfg.feature_grad_mult
141
+
142
+ self.mask_emb = nn.Parameter(
143
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
144
+ )
145
+
146
+ self.encoder = TransformerEncoder(cfg)
147
+ self.layer_norm = LayerNorm(self.extractor_embed)
148
+
149
+ self.final_proj = nn.Linear(self.embed, self.embed)
150
+
151
+ self.num_updates = 0
152
+
153
+ def make_ema_teacher(self):
154
+ ema_config = EMAModuleConfig(
155
+ ema_decay=self.cfg.ema_decay,
156
+ ema_fp32=True,
157
+ )
158
+ skip_keys = set()
159
+ if self.cfg.ema_layers_only:
160
+ self.cfg.ema_transformer_only = True
161
+ for k, _ in self.encoder.pos_conv.named_parameters():
162
+ skip_keys.add(f"pos_conv.{k}")
163
+
164
+ self.ema = EMAModule(
165
+ self.encoder if self.cfg.ema_transformer_only else self,
166
+ ema_config,
167
+ skip_keys=skip_keys,
168
+ )
169
+
170
+ def set_num_updates(self, num_updates):
171
+ super().set_num_updates(num_updates)
172
+
173
+ if self.ema is None and self.final_proj is not None:
174
+ logger.info(f"making ema teacher")
175
+ self.make_ema_teacher()
176
+ elif self.training and self.ema is not None:
177
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
178
+ if num_updates >= self.cfg.ema_anneal_end_step:
179
+ decay = self.cfg.ema_end_decay
180
+ else:
181
+ decay = get_annealed_rate(
182
+ self.cfg.ema_decay,
183
+ self.cfg.ema_end_decay,
184
+ num_updates,
185
+ self.cfg.ema_anneal_end_step,
186
+ )
187
+ self.ema.set_decay(decay)
188
+ if self.ema.get_decay() < 1:
189
+ self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
190
+
191
+ self.num_updates = num_updates
192
+
193
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
194
+ state = super().state_dict(destination, prefix, keep_vars)
195
+
196
+ if self.ema is not None:
197
+ state[prefix + "_ema"] = self.ema.fp32_params
198
+
199
+ return state
200
+
201
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
202
+ if self.ema is not None:
203
+ k = prefix + "_ema"
204
+ assert k in state_dict
205
+ self.ema.restore(state_dict[k], True)
206
+ del state_dict[k]
207
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
208
+
209
+ @classmethod
210
+ def build_model(cls, cfg: Data2VecAudioConfig, task=None):
211
+ """Build a new model instance."""
212
+
213
+ return cls(cfg)
214
+
215
+ def apply_mask(
216
+ self,
217
+ x,
218
+ padding_mask,
219
+ mask_indices=None,
220
+ mask_channel_indices=None,
221
+ ):
222
+ B, T, C = x.shape
223
+
224
+ if self.mask_channel_prob > 0 and self.mask_channel_before:
225
+ mask_channel_indices = compute_mask_indices(
226
+ (B, C),
227
+ None,
228
+ self.mask_channel_prob,
229
+ self.mask_channel_length,
230
+ self.mask_channel_selection,
231
+ self.mask_channel_other,
232
+ no_overlap=self.no_mask_channel_overlap,
233
+ min_space=self.mask_channel_min_space,
234
+ )
235
+ mask_channel_indices = (
236
+ torch.from_numpy(mask_channel_indices)
237
+ .to(x.device)
238
+ .unsqueeze(1)
239
+ .expand(-1, T, -1)
240
+ )
241
+ x[mask_channel_indices] = 0
242
+
243
+ if self.mask_prob > 0:
244
+ if mask_indices is None:
245
+ mask_indices = compute_mask_indices(
246
+ (B, T),
247
+ padding_mask,
248
+ self.mask_prob,
249
+ self.mask_length,
250
+ self.mask_selection,
251
+ self.mask_other,
252
+ min_masks=1,
253
+ no_overlap=self.no_mask_overlap,
254
+ min_space=self.mask_min_space,
255
+ require_same_masks=self.cfg.require_same_masks,
256
+ mask_dropout=self.cfg.mask_dropout,
257
+ )
258
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
259
+ x = index_put(x, mask_indices, self.mask_emb)
260
+ else:
261
+ mask_indices = None
262
+
263
+ if self.mask_channel_prob > 0 and not self.mask_channel_before:
264
+ if mask_channel_indices is None:
265
+ mask_channel_indices = compute_mask_indices(
266
+ (B, C),
267
+ None,
268
+ self.mask_channel_prob,
269
+ self.mask_channel_length,
270
+ self.mask_channel_selection,
271
+ self.mask_channel_other,
272
+ no_overlap=self.no_mask_channel_overlap,
273
+ min_space=self.mask_channel_min_space,
274
+ )
275
+ mask_channel_indices = (
276
+ torch.from_numpy(mask_channel_indices)
277
+ .to(x.device)
278
+ .unsqueeze(1)
279
+ .expand(-1, T, -1)
280
+ )
281
+ x = index_put(x, mask_channel_indices, 0)
282
+
283
+ return x, mask_indices
284
+
285
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
286
+ """
287
+ Computes the output length of the convolutional layers
288
+ """
289
+
290
+ def _conv_out_length(input_length, kernel_size, stride):
291
+ return torch.floor((input_length - kernel_size) / stride + 1)
292
+
293
+ conv_cfg_list = eval(self.cfg.conv_feature_layers)
294
+
295
+ for i in range(len(conv_cfg_list)):
296
+ input_lengths = _conv_out_length(
297
+ input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
298
+ )
299
+
300
+ return input_lengths.to(torch.long)
301
+
302
+ def forward(
303
+ self,
304
+ source,
305
+ padding_mask=None,
306
+ mask=True,
307
+ features_only=False,
308
+ layer=None,
309
+ mask_indices=None,
310
+ mask_channel_indices=None,
311
+ padding_count=None,
312
+ ):
313
+ features = source
314
+
315
+ if self.feature_grad_mult > 0:
316
+ features = self.feature_extractor(features)
317
+ if self.feature_grad_mult != 1.0:
318
+ features = GradMultiply.apply(features, self.feature_grad_mult)
319
+ else:
320
+ with torch.no_grad():
321
+ features = self.feature_extractor(features)
322
+
323
+ features = features.transpose(1, 2)
324
+
325
+ features = self.layer_norm(features)
326
+
327
+ orig_padding_mask = padding_mask
328
+
329
+ if padding_mask is not None and padding_mask.any():
330
+ input_lengths = (1 - padding_mask.long()).sum(-1)
331
+ # apply conv formula to get real output_lengths
332
+ output_lengths = self._get_feat_extract_output_lengths(input_lengths)
333
+
334
+ padding_mask = torch.zeros(
335
+ features.shape[:2], dtype=features.dtype, device=features.device
336
+ )
337
+
338
+ # these two operations makes sure that all values
339
+ # before the output lengths indices are attended to
340
+ padding_mask[
341
+ (
342
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
343
+ output_lengths - 1,
344
+ )
345
+ ] = 1
346
+ padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
347
+ else:
348
+ padding_mask = None
349
+
350
+ if self.post_extract_proj is not None:
351
+ features = self.post_extract_proj(features)
352
+
353
+ pre_encoder_features = None
354
+ if self.cfg.ema_transformer_only:
355
+ pre_encoder_features = features.clone()
356
+
357
+ features = self.dropout_input(features)
358
+
359
+ if mask:
360
+ x, mask_indices = self.apply_mask(
361
+ features,
362
+ padding_mask,
363
+ mask_indices=mask_indices,
364
+ mask_channel_indices=mask_channel_indices,
365
+ )
366
+ else:
367
+ x = features
368
+ mask_indices = None
369
+
370
+ x, layer_results = self.encoder(
371
+ x,
372
+ padding_mask=padding_mask,
373
+ layer=layer,
374
+ )
375
+
376
+ if features_only:
377
+ return {
378
+ "x": x,
379
+ "padding_mask": padding_mask,
380
+ "layer_results": layer_results,
381
+ }
382
+
383
+ result = {
384
+ "losses": {},
385
+ }
386
+
387
+ with torch.no_grad():
388
+ self.ema.model.eval()
389
+
390
+ if self.cfg.ema_transformer_only:
391
+ y, layer_results = self.ema.model.extract_features(
392
+ pre_encoder_features,
393
+ padding_mask=padding_mask,
394
+ min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
395
+ )
396
+ y = {
397
+ "x": y,
398
+ "padding_mask": padding_mask,
399
+ "layer_results": layer_results,
400
+ }
401
+ else:
402
+ y = self.ema.model.extract_features(
403
+ source=source,
404
+ padding_mask=orig_padding_mask,
405
+ mask=False,
406
+ )
407
+
408
+ target_layer_results = [l[2] for l in y["layer_results"]]
409
+
410
+ permuted = False
411
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
412
+ target_layer_results = [
413
+ tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
414
+ ]
415
+ permuted = True
416
+
417
+ if self.cfg.batch_norm_target_layer:
418
+ target_layer_results = [
419
+ F.batch_norm(
420
+ tl.float(), running_mean=None, running_var=None, training=True
421
+ )
422
+ for tl in target_layer_results
423
+ ]
424
+
425
+ if self.cfg.instance_norm_target_layer:
426
+ target_layer_results = [
427
+ F.instance_norm(tl.float()) for tl in target_layer_results
428
+ ]
429
+
430
+ if permuted:
431
+ target_layer_results = [
432
+ tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
433
+ ]
434
+
435
+ if self.cfg.group_norm_target_layer:
436
+ target_layer_results = [
437
+ F.layer_norm(tl.float(), tl.shape[-2:])
438
+ for tl in target_layer_results
439
+ ]
440
+
441
+ if self.cfg.layer_norm_target_layer:
442
+ target_layer_results = [
443
+ F.layer_norm(tl.float(), tl.shape[-1:])
444
+ for tl in target_layer_results
445
+ ]
446
+
447
+ y = sum(target_layer_results) / len(target_layer_results)
448
+
449
+ if self.cfg.layer_norm_targets:
450
+ y = F.layer_norm(y.float(), y.shape[-1:])
451
+
452
+ if self.cfg.instance_norm_targets:
453
+ y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
454
+
455
+ if not permuted:
456
+ y = y.transpose(0, 1)
457
+
458
+ y = y[mask_indices]
459
+
460
+ x = x[mask_indices]
461
+ x = self.final_proj(x)
462
+
463
+ sz = x.size(-1)
464
+
465
+ if self.loss_beta == 0:
466
+ loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
467
+ else:
468
+ loss = F.smooth_l1_loss(
469
+ x.float(), y.float(), reduction="none", beta=self.loss_beta
470
+ ).sum(dim=-1)
471
+
472
+ if self.loss_scale is not None:
473
+ scale = self.loss_scale
474
+ else:
475
+ scale = 1 / math.sqrt(sz)
476
+
477
+ result["losses"]["regression"] = loss.sum() * scale
478
+
479
+ if "sample_size" not in result:
480
+ result["sample_size"] = loss.numel()
481
+
482
+ with torch.no_grad():
483
+ result["target_var"] = self.compute_var(y)
484
+ result["pred_var"] = self.compute_var(x.float())
485
+
486
+ if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
487
+ logger.error(
488
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
489
+ )
490
+ raise Exception(
491
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
492
+ )
493
+ if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
494
+ logger.error(
495
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
496
+ )
497
+ raise Exception(
498
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
499
+ )
500
+
501
+ if self.ema is not None:
502
+ result["ema_decay"] = self.ema.get_decay() * 1000
503
+
504
+ return result
505
+
506
+ @staticmethod
507
+ def compute_var(y):
508
+ y = y.view(-1, y.size(-1))
509
+ if dist.is_initialized():
510
+ zc = torch.tensor(y.size(0)).cuda()
511
+ zs = y.sum(dim=0)
512
+ zss = (y ** 2).sum(dim=0)
513
+
514
+ dist.all_reduce(zc)
515
+ dist.all_reduce(zs)
516
+ dist.all_reduce(zss)
517
+
518
+ var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
519
+ return torch.sqrt(var + 1e-6).mean()
520
+ else:
521
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
522
+
523
+ def extract_features(
524
+ self, source, padding_mask, mask=False, layer=None
525
+ ):
526
+ res = self.forward(
527
+ source,
528
+ padding_mask,
529
+ mask=mask,
530
+ features_only=True,
531
+ layer=layer,
532
+ )
533
+ return res
534
+
535
+ def remove_pretraining_modules(self, last_layer=None):
536
+ self.final_proj = None
537
+ self.ema = None
538
+ if last_layer is not None:
539
+ self.encoder.layers = nn.ModuleList(
540
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
541
+ )
RepCodec/examples/data2vec_feature_reader.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ import logging
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from fairseq import tasks
13
+ from fairseq.checkpoint_utils import load_checkpoint_to_cpu
14
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
15
+ from omegaconf import OmegaConf
16
+
17
+ from data2vec_audio import Data2VecAudioModel
18
+
19
+ logger = logging.getLogger("dump_feature")
20
+
21
+
22
+ class Data2vecFeatureReader(object):
23
+ def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
24
+ state = load_checkpoint_to_cpu(ckpt_path)
25
+ cfg = state["cfg"]
26
+ # load task
27
+ task = tasks.setup_task(cfg.task, from_checkpoint=True)
28
+ task.load_state_dict(state["task_state"])
29
+ # load model config
30
+ if "layer_type" not in cfg.model:
31
+ # fix a missing key
32
+ model_config = {k: v for k, v in cfg.model.items()}
33
+ model_config["layer_type"] = "transformer"
34
+ model_config = OmegaConf.create(model_config)
35
+ else:
36
+ model_config = cfg.model
37
+
38
+ # fix param name in the state
39
+ state["model"]["final_proj.weight"] = state["model"].pop("final_proj.0.weight")
40
+ state["model"]["final_proj.bias"] = state["model"].pop("final_proj.0.bias")
41
+ del state["model"]["_ema"]
42
+
43
+ # load model
44
+ model = Data2VecAudioModel.build_model(model_config)
45
+ model.load_state_dict(
46
+ state["model"], strict=True, model_cfg=model_config
47
+ )
48
+
49
+ self.device = device
50
+ logger.info(f"device = {self.device}")
51
+
52
+ self.model = model.eval().to(self.device)
53
+ self.task = task
54
+ self.layer = layer - 1 # make it 1-based
55
+ self.max_chunk = max_chunk
56
+ logger.info(f"TASK CONFIG:\n{self.task.cfg}")
57
+ logger.info(f" max_chunk = {self.max_chunk}")
58
+
59
+ def read_audio(self, path, ref_len=None):
60
+ wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
61
+ if wav.ndim == 2:
62
+ wav = wav.mean(-1)
63
+ assert wav.ndim == 1, wav.ndim
64
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
65
+ logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
66
+ return wav
67
+
68
+ def get_feats(self, path, ref_len=None):
69
+ x = self.read_audio(path, ref_len=ref_len)
70
+ with torch.no_grad():
71
+ x = torch.from_numpy(x).float().to(self.device)
72
+ if self.task.cfg.normalize:
73
+ x = F.layer_norm(x, x.shape)
74
+ x = x.view(1, -1)
75
+
76
+ feat = []
77
+ for start in range(0, x.size(1), self.max_chunk):
78
+ x_chunk = x[:, start: start + self.max_chunk]
79
+ res = self.model.extract_features(
80
+ source=x_chunk,
81
+ padding_mask=None,
82
+ mask=False,
83
+ layer=self.layer,
84
+ )
85
+ feat_chunk = res["x"]
86
+ feat.append(feat_chunk)
87
+ return torch.cat(feat, 1).squeeze(0)
RepCodec/examples/dump_feature.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ import logging
9
+ import os
10
+ import sys
11
+
12
+ from feature_utils import get_path_iterator, dump_feature
13
+
14
+ logging.basicConfig(
15
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
16
+ datefmt="%Y-%m-%d %H:%M:%S",
17
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
18
+ stream=sys.stdout,
19
+ )
20
+ logger = logging.getLogger("dump_feature")
21
+
22
+
23
+ def main(
24
+ model_type: str,
25
+ tsv_path: str,
26
+ ckpt_path: str,
27
+ whisper_root: str,
28
+ whisper_name: str,
29
+ layer: int,
30
+ nshard: int,
31
+ rank: int,
32
+ feat_dir: str,
33
+ max_chunk: int,
34
+ use_cpu: bool = False
35
+ ):
36
+ device = "cpu" if use_cpu else "cuda"
37
+
38
+ # some checks
39
+ if model_type in ["hubert", "data2vec"]:
40
+ assert ckpt_path and os.path.exists(ckpt_path)
41
+ elif model_type in ["whisper"]:
42
+ assert whisper_name and whisper_root
43
+ else:
44
+ raise ValueError(f"Unsupported model type {model_type}")
45
+
46
+ reader = None
47
+ if model_type == "hubert":
48
+ from hubert_feature_reader import HubertFeatureReader
49
+ reader = HubertFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
50
+ elif model_type == "data2vec":
51
+ from data2vec_feature_reader import Data2vecFeatureReader
52
+ reader = Data2vecFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
53
+ elif model_type == "whisper":
54
+ from whisper_feature_reader import WhisperFeatureReader
55
+ reader = WhisperFeatureReader(whisper_root, whisper_name, layer, device=device)
56
+
57
+ assert reader is not None
58
+
59
+ generator, num = get_path_iterator(tsv_path, nshard, rank)
60
+ dump_feature(reader, generator, num, nshard, rank, feat_dir)
61
+
62
+
63
+ if __name__ == "__main__":
64
+ import argparse
65
+
66
+ parser = argparse.ArgumentParser()
67
+ parser.add_argument(
68
+ "--model_type",
69
+ required=True,
70
+ type=str,
71
+ choices=["data2vec", "hubert", "whisper"],
72
+ help="the type of the speech encoder."
73
+ )
74
+ parser.add_argument(
75
+ "--tsv_path",
76
+ required=True,
77
+ type=str,
78
+ help="the path to the tsv file."
79
+ )
80
+ parser.add_argument(
81
+ "--ckpt_path",
82
+ required=False,
83
+ type=str,
84
+ default=None,
85
+ help="path to the speech model. must provide for HuBERT and data2vec"
86
+ )
87
+ parser.add_argument(
88
+ "--whisper_root",
89
+ required=False,
90
+ type=str,
91
+ default=None,
92
+ help="root dir to download/store whisper model. must provide for whisper model."
93
+ )
94
+ parser.add_argument(
95
+ "--whisper_name",
96
+ required=False,
97
+ type=str,
98
+ default=None,
99
+ help="name of whisper model. e.g., large-v2. must provide for whisper model."
100
+ )
101
+ parser.add_argument(
102
+ "--layer",
103
+ required=True,
104
+ type=int,
105
+ help="which layer of the model. this is 1-based."
106
+ )
107
+ parser.add_argument(
108
+ "--feat_dir",
109
+ required=True,
110
+ type=str,
111
+ help="the output dir to save the representations."
112
+ )
113
+ parser.add_argument(
114
+ "--nshard",
115
+ required=False,
116
+ type=int,
117
+ default=1,
118
+ help="total number of shards."
119
+ )
120
+ parser.add_argument(
121
+ "--rank",
122
+ required=False,
123
+ type=int,
124
+ default=0,
125
+ help="shard id of this process."
126
+ )
127
+ parser.add_argument(
128
+ "--max_chunk",
129
+ type=int,
130
+ default=1600000,
131
+ help="max number of frames of each batch."
132
+ )
133
+ parser.add_argument(
134
+ "--use_cpu",
135
+ default=False,
136
+ action="store_true",
137
+ help="whether use cpu instead of gpu."
138
+ )
139
+ args = parser.parse_args()
140
+ logger.info(args)
141
+
142
+ main(**vars(args))
RepCodec/examples/feature_utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ # ref: https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/feature_utils.py
9
+
10
+ import logging
11
+ import os
12
+ import sys
13
+
14
+ import tqdm
15
+ from npy_append_array import NpyAppendArray
16
+
17
+
18
+ logging.basicConfig(
19
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
20
+ datefmt="%Y-%m-%d %H:%M:%S",
21
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
22
+ stream=sys.stdout,
23
+ )
24
+ logger = logging.getLogger("feature_utils")
25
+
26
+
27
+ def get_shard_range(tot, nshard, rank):
28
+ assert rank < nshard and rank >= 0, f"invaid rank/nshard {rank}/{nshard}"
29
+ start = round(tot / nshard * rank)
30
+ end = round(tot / nshard * (rank + 1))
31
+ assert start < end, f"start={start}, end={end}"
32
+ logger.info(
33
+ f"rank {rank} of {nshard}, process {end-start} "
34
+ f"({start}-{end}) out of {tot}"
35
+ )
36
+ return start, end
37
+
38
+
39
+ def get_path_iterator(tsv, nshard, rank):
40
+ with open(tsv, "r") as f:
41
+ root = f.readline().rstrip()
42
+ lines = [line.rstrip() for line in f]
43
+ start, end = get_shard_range(len(lines), nshard, rank)
44
+ lines = lines[start:end]
45
+ def iterate():
46
+ for line in lines:
47
+ subpath, nsample = line.split("\t")
48
+ yield f"{root}/{subpath}", int(nsample)
49
+ return iterate, len(lines)
50
+
51
+
52
+ def dump_feature(reader, generator, num, nshard, rank, feat_dir):
53
+ iterator = generator()
54
+
55
+ feat_path = f"{feat_dir}/{rank}_{nshard}.npy"
56
+ leng_path = f"{feat_dir}/{rank}_{nshard}.len"
57
+
58
+ os.makedirs(feat_dir, exist_ok=True)
59
+ if os.path.exists(feat_path):
60
+ os.remove(feat_path)
61
+
62
+ feat_f = NpyAppendArray(feat_path)
63
+ with open(leng_path, "w") as leng_f:
64
+ for path, nsample in tqdm.tqdm(iterator, total=num):
65
+ feat = reader.get_feats(path, nsample)
66
+ feat_f.append(feat.cpu().numpy())
67
+ leng_f.write(f"{len(feat)}\n")
68
+ logger.info("finished successfully")
69
+
70
+
RepCodec/examples/hubert_feature_reader.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ import logging
9
+
10
+ import fairseq
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
15
+
16
+ logger = logging.getLogger("dump_feature")
17
+
18
+
19
+ class HubertFeatureReader(object):
20
+ def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
21
+ (
22
+ model,
23
+ cfg,
24
+ task,
25
+ ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
26
+
27
+ self.device = device
28
+ logger.info(f"device = {self.device}")
29
+
30
+ self.model = model[0].eval().to(self.device)
31
+ self.task = task
32
+ self.layer = layer
33
+ self.max_chunk = max_chunk
34
+ logger.info(f"TASK CONFIG:\n{self.task.cfg}")
35
+ logger.info(f" max_chunk = {self.max_chunk}")
36
+
37
+ def read_audio(self, path, ref_len=None):
38
+ wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
39
+ if wav.ndim == 2:
40
+ wav = wav.mean(-1)
41
+ assert wav.ndim == 1, wav.ndim
42
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
43
+ logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
44
+ return wav
45
+
46
+ def get_feats(self, path, ref_len=None):
47
+ x = self.read_audio(path, ref_len=ref_len)
48
+ with torch.no_grad():
49
+ x = torch.from_numpy(x).float().to(self.device)
50
+ if self.task.cfg.normalize:
51
+ x = F.layer_norm(x, x.shape)
52
+ x = x.view(1, -1)
53
+
54
+ feat = []
55
+ for start in range(0, x.size(1), self.max_chunk):
56
+ x_chunk = x[:, start: start + self.max_chunk]
57
+ feat_chunk, _ = self.model.extract_features(
58
+ source=x_chunk,
59
+ padding_mask=None,
60
+ mask=False,
61
+ output_layer=self.layer,
62
+ )
63
+ feat.append(feat_chunk)
64
+ return torch.cat(feat, 1).squeeze(0)
RepCodec/examples/tokens/data2vec_base_l6_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
RepCodec/examples/tokens/data2vec_large_l18_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
RepCodec/examples/tokens/hubert_base_l9_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
RepCodec/examples/tokens/hubert_large_l18_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
RepCodec/examples/tokens/whisper_large_l32_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
RepCodec/examples/tokens/whisper_medium_l24_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
RepCodec/examples/whisper_feature_reader.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq) and
7
+ # Whisper (https://github.com/openai/whisper/)
8
+
9
+ import io
10
+ import logging
11
+ import os
12
+ from typing import Optional, Union
13
+
14
+ import soundfile as sf
15
+ import torch
16
+ from whisper import _MODELS, _download, _ALIGNMENT_HEADS, available_models
17
+ from whisper.audio import log_mel_spectrogram
18
+ from whisper.model import ModelDimensions
19
+
20
+ from whisper_model import Whisper_
21
+
22
+ logger = logging.getLogger("dump_feature")
23
+
24
+
25
+ def load_model(
26
+ name: str,
27
+ device: Optional[Union[str, torch.device]] = None,
28
+ download_root: str = None,
29
+ in_memory: bool = False,
30
+ ) -> Whisper_:
31
+ """
32
+ Reference: https://github.com/openai/whisper/blob/main/whisper/__init__.py#L97
33
+ But we will load a `Whisper_` model for feature extraction.
34
+
35
+ Parameters
36
+ ----------
37
+ name : str
38
+ one of the official model names listed by `whisper.available_models()`, or
39
+ path to a model checkpoint containing the model dimensions and the model state_dict.
40
+ device : Union[str, torch.device]
41
+ the PyTorch device to put the model into
42
+ download_root: str
43
+ path to download the model files; by default, it uses "~/.cache/whisper"
44
+ in_memory: bool
45
+ whether to preload the model weights into host memory
46
+
47
+ Returns
48
+ -------
49
+ model : Whisper
50
+ The Whisper ASR model instance
51
+ """
52
+
53
+ if device is None:
54
+ device = "cuda" if torch.cuda.is_available() else "cpu"
55
+ if download_root is None:
56
+ default = os.path.join(os.path.expanduser("~"), ".cache")
57
+ download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
58
+
59
+ if name in _MODELS:
60
+ checkpoint_file = _download(_MODELS[name], download_root, in_memory)
61
+ alignment_heads = _ALIGNMENT_HEADS[name]
62
+ elif os.path.isfile(name):
63
+ checkpoint_file = open(name, "rb").read() if in_memory else name
64
+ alignment_heads = None
65
+ else:
66
+ raise RuntimeError(
67
+ f"Model {name} not found; available models = {available_models()}"
68
+ )
69
+
70
+ with (
71
+ io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
72
+ ) as fp:
73
+ checkpoint = torch.load(fp, map_location=device)
74
+ del checkpoint_file
75
+
76
+ dims = ModelDimensions(**checkpoint["dims"])
77
+ model = Whisper_(dims)
78
+ model.load_state_dict(checkpoint["model_state_dict"])
79
+
80
+ if alignment_heads is not None:
81
+ model.set_alignment_heads(alignment_heads)
82
+
83
+ return model.to(device)
84
+
85
+
86
+ class WhisperFeatureReader(object):
87
+ def __init__(self, root, ckpt, layer, device):
88
+ self.device = device
89
+ logger.info(f"device = {self.device}")
90
+
91
+ self.model: Whisper_ = load_model(name=ckpt, device=self.device, download_root=root).eval()
92
+ self.model.decoder = None # to save some memory by deleting the decoder
93
+ self.layer = layer # one-based
94
+
95
+ def read_audio(self, path, ref_len=None):
96
+ wav, sample_rate = sf.read(path)
97
+ assert sample_rate == 16000, sample_rate
98
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
99
+ logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
100
+ return wav
101
+
102
+ def get_feats(self, path, ref_len=None):
103
+ wav = self.read_audio(path, ref_len)
104
+ audio_length = len(wav)
105
+ with torch.no_grad():
106
+ mel = log_mel_spectrogram(torch.from_numpy(wav).float().to(self.device))
107
+ hidden = self.model.extract_features(mel.unsqueeze(0), target_layer=self.layer)
108
+ feature_length = audio_length // 320
109
+ hidden = hidden[0, :feature_length]
110
+ return hidden.contiguous()
RepCodec/examples/whisper_model.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq) and
7
+ # Whisper (https://github.com/openai/whisper/)
8
+
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+ from whisper.model import AudioEncoder, sinusoids, Whisper, ModelDimensions
15
+
16
+
17
+ class AudioEncoder_(AudioEncoder):
18
+ def __init__(self, *args, **kwargs):
19
+ super(AudioEncoder_, self).__init__(*args, **kwargs)
20
+
21
+ def extract_feature(self, x: Tensor, target_layer: Optional[int] = None):
22
+ """
23
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
24
+ the mel spectrogram of the audio
25
+ """
26
+ x = F.gelu(self.conv1(x))
27
+ x = F.gelu(self.conv2(x))
28
+ x = x.permute(0, 2, 1)
29
+
30
+ length_x = x.shape[1]
31
+ if length_x > self.positional_embedding.shape[0]:
32
+ self.register_buffer("positional_embedding", sinusoids(length_x, self.positional_embedding.shape[1]))
33
+ self.positional_embedding = self.positional_embedding.to(x.device)
34
+ x = (x + self.positional_embedding[:length_x, :]).to(x.dtype)
35
+
36
+ if target_layer is None:
37
+ target_layer = len(self.blocks)
38
+
39
+ for block in self.blocks[:target_layer]:
40
+ x = block(x)
41
+
42
+ return x
43
+
44
+
45
+ class Whisper_(Whisper):
46
+ def __init__(self, dims: ModelDimensions):
47
+ super(Whisper_, self).__init__(dims)
48
+ # replace audio encoder with our audio encoder
49
+ self.encoder = AudioEncoder_(
50
+ self.dims.n_mels,
51
+ self.dims.n_audio_ctx,
52
+ self.dims.n_audio_state,
53
+ self.dims.n_audio_head,
54
+ self.dims.n_audio_layer,
55
+ )
56
+
57
+ def extract_features(self, mel: torch.Tensor, target_layer: Optional[int] = None):
58
+ return self.encoder.extract_feature(mel, target_layer)
RepCodec/repcodec/RepCodec.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+ from repcodec.modules.decoder import Decoder
11
+ from repcodec.modules.encoder import Encoder
12
+ from repcodec.modules.projector import Projector
13
+ from repcodec.modules.quantizer import Quantizer
14
+
15
+
16
+ class RepCodec(nn.Module):
17
+ def __init__(
18
+ self,
19
+ input_channels=768,
20
+ output_channels=768,
21
+ encode_channels=768,
22
+ decode_channels=768,
23
+ code_dim=768,
24
+ codebook_num=1,
25
+ codebook_size=1024,
26
+ bias=True,
27
+ enc_ratios=(1, 1),
28
+ dec_ratios=(1, 1),
29
+ enc_strides=(1, 1),
30
+ dec_strides=(1, 1),
31
+ enc_kernel_size=3,
32
+ dec_kernel_size=3,
33
+ enc_block_dilations=(1, 1),
34
+ enc_block_kernel_size=3,
35
+ dec_block_dilations=(1, 1),
36
+ dec_block_kernel_size=3
37
+ ):
38
+ super().__init__()
39
+
40
+ self.input_channels = input_channels
41
+
42
+ self.encoder = Encoder(
43
+ input_channels=input_channels,
44
+ encode_channels=encode_channels,
45
+ channel_ratios=enc_ratios,
46
+ strides=enc_strides,
47
+ kernel_size=enc_kernel_size,
48
+ bias=bias,
49
+ block_dilations=enc_block_dilations,
50
+ unit_kernel_size=enc_block_kernel_size
51
+ )
52
+
53
+ self.decoder = Decoder(
54
+ code_dim=code_dim,
55
+ output_channels=output_channels,
56
+ decode_channels=decode_channels,
57
+ channel_ratios=dec_ratios,
58
+ strides=dec_strides,
59
+ kernel_size=dec_kernel_size,
60
+ bias=bias,
61
+ block_dilations=dec_block_dilations,
62
+ unit_kernel_size=dec_block_kernel_size
63
+ )
64
+
65
+ self.projector = Projector(
66
+ input_channels=self.encoder.out_channels,
67
+ code_dim=code_dim,
68
+ kernel_size=3,
69
+ stride=1,
70
+ bias=False
71
+ )
72
+
73
+ self.quantizer = Quantizer(
74
+ code_dim=code_dim,
75
+ codebook_num=codebook_num,
76
+ codebook_size=codebook_size
77
+ )
78
+
79
+ def forward(self, x):
80
+ x = self.encoder(x)
81
+ z = self.projector(x)
82
+ zq, vqloss, perplexity = self.quantizer(z)
83
+ y = self.decoder(zq)
84
+ return y, zq, z, vqloss, perplexity
RepCodec/repcodec/configs/repcodec_dim1024.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_channels: 1024
2
+ output_channels: 1024
3
+ encode_channels: 1024
4
+ decode_channels: 1024
5
+ code_dim: 1024
6
+ codebook_num: 1
7
+ codebook_size: 1024
8
+ bias: true
9
+ enc_ratios: [ 1, 1 ]
10
+ dec_ratios: [ 1, 1 ]
11
+ enc_strides: [ 1, 1 ] # no downsampling
12
+ dec_strides: [ 1, 1 ]
13
+ enc_kernel_size: 3
14
+ dec_kernel_size: 3
15
+ enc_block_dilations: [ 1, 1 ]
16
+ enc_block_kernel_size: 3
17
+ dec_block_dilations: [ 1, 1 ]
18
+ dec_block_kernel_size: 3
RepCodec/repcodec/configs/repcodec_dim1280.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_channels: 1280
2
+ output_channels: 1280
3
+ encode_channels: 1280
4
+ decode_channels: 1280
5
+ code_dim: 1280
6
+ codebook_num: 1
7
+ codebook_size: 1024
8
+ bias: true
9
+ enc_ratios: [ 1, 1 ]
10
+ dec_ratios: [ 1, 1 ]
11
+ enc_strides: [ 1, 1 ] # no downsampling
12
+ dec_strides: [ 1, 1 ]
13
+ enc_kernel_size: 3
14
+ dec_kernel_size: 3
15
+ enc_block_dilations: [ 1, 1 ]
16
+ enc_block_kernel_size: 3
17
+ dec_block_dilations: [ 1, 1 ]
18
+ dec_block_kernel_size: 3
RepCodec/repcodec/configs/repcodec_dim768.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_channels: 768
2
+ output_channels: 768
3
+ encode_channels: 768
4
+ decode_channels: 768
5
+ code_dim: 768
6
+ codebook_num: 1
7
+ codebook_size: 1024
8
+ bias: true
9
+ enc_ratios: [ 1, 1 ]
10
+ dec_ratios: [ 1, 1 ]
11
+ enc_strides: [ 1, 1 ] # no downsampling
12
+ dec_strides: [ 1, 1 ]
13
+ enc_kernel_size: 3
14
+ dec_kernel_size: 3
15
+ enc_block_dilations: [ 1, 1 ]
16
+ enc_block_kernel_size: 3
17
+ dec_block_dilations: [ 1, 1 ]
18
+ dec_block_kernel_size: 3
RepCodec/repcodec/layers/conv_layer.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+
11
+ class Conv1d1x1(nn.Conv1d):
12
+ """1x1 Conv1d."""
13
+
14
+ def __init__(self, in_channels, out_channels, bias=True):
15
+ super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)
16
+
17
+
18
+ class Conv1d(nn.Module):
19
+ def __init__(
20
+ self,
21
+ in_channels: int,
22
+ out_channels: int,
23
+ kernel_size: int,
24
+ stride: int = 1,
25
+ padding: int = -1,
26
+ dilation: int = 1,
27
+ groups: int = 1,
28
+ bias: bool = True
29
+ ):
30
+ super().__init__()
31
+ self.in_channels = in_channels
32
+ self.out_channels = out_channels
33
+ self.kernel_size = kernel_size
34
+ if padding < 0:
35
+ padding = (kernel_size - 1) // 2 * dilation
36
+ self.dilation = dilation
37
+ self.conv = nn.Conv1d(
38
+ in_channels=in_channels,
39
+ out_channels=out_channels,
40
+ kernel_size=kernel_size,
41
+ stride=stride,
42
+ padding=padding,
43
+ dilation=dilation,
44
+ groups=groups,
45
+ bias=bias,
46
+ )
47
+
48
+ def forward(self, x):
49
+ """
50
+ Args:
51
+ x (Tensor): Float tensor variable with the shape (B, C, T).
52
+ Returns:
53
+ Tensor: Float tensor variable with the shape (B, C, T).
54
+ """
55
+ x = self.conv(x)
56
+ return x
57
+
58
+
59
+ class ConvTranspose1d(nn.Module):
60
+ def __init__(
61
+ self,
62
+ in_channels: int,
63
+ out_channels: int,
64
+ kernel_size: int,
65
+ stride: int,
66
+ padding=-1,
67
+ output_padding=-1,
68
+ groups=1,
69
+ bias=True,
70
+ ):
71
+ super().__init__()
72
+ if padding < 0:
73
+ padding = (stride + 1) // 2
74
+ if output_padding < 0:
75
+ output_padding = 1 if stride % 2 else 0
76
+ self.deconv = nn.ConvTranspose1d(
77
+ in_channels=in_channels,
78
+ out_channels=out_channels,
79
+ kernel_size=kernel_size,
80
+ stride=stride,
81
+ padding=padding,
82
+ output_padding=output_padding,
83
+ groups=groups,
84
+ bias=bias,
85
+ )
86
+
87
+ def forward(self, x):
88
+ """
89
+ Args:
90
+ x (Tensor): Float tensor variable with the shape (B, C, T).
91
+ Returns:
92
+ Tensor: Float tensor variable with the shape (B, C', T').
93
+ """
94
+ x = self.deconv(x)
95
+ return x
RepCodec/repcodec/layers/vq_module.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """Vector quantization w/ exponential moving averages (EMA)"""
15
+
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ codebook_size: int,
20
+ decay=0.8,
21
+ commitment=1.,
22
+ eps=1e-5,
23
+ n_embed=None,
24
+ ):
25
+ super().__init__()
26
+ n_embed = self.default(n_embed, codebook_size)
27
+
28
+ self.dim = dim
29
+ self.n_embed = n_embed
30
+ self.decay = decay
31
+ self.eps = eps
32
+ self.commitment = commitment
33
+
34
+ embed = torch.randn(dim, n_embed)
35
+ self.register_buffer('embed', embed)
36
+ self.register_buffer('cluster_size', torch.zeros(n_embed))
37
+ self.register_buffer('embed_avg', embed.clone())
38
+
39
+ @property
40
+ def codebook(self):
41
+ return self.embed.transpose(0, 1)
42
+
43
+ def exists(self, val):
44
+ return val is not None
45
+
46
+ def default(self, val, d):
47
+ return val if self.exists(val) else d
48
+
49
+ def ema_inplace(self, moving_avg, new, decay):
50
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
51
+
52
+ def laplace_smoothing(self, x, n_categories, eps=1e-5):
53
+ return (x + eps) / (x.sum() + n_categories * eps)
54
+
55
+ def forward(self, input):
56
+ dtype = input.dtype
57
+ flatten = input.reshape(-1, self.dim)
58
+ dist = (
59
+ flatten.pow(2).sum(1, keepdim=True)
60
+ - 2 * flatten @ self.embed
61
+ + self.embed.pow(2).sum(0, keepdim=True)
62
+ )
63
+ _, embed_ind = (-dist).max(1)
64
+ embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
65
+ embed_ind = embed_ind.view(*input.shape[:-1])
66
+ quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
67
+
68
+ if self.training:
69
+ self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
70
+ embed_sum = flatten.transpose(0, 1) @ embed_onehot
71
+ self.ema_inplace(self.embed_avg, embed_sum, self.decay)
72
+ cluster_size = self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
73
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
74
+ self.embed.data.copy_(embed_normalized)
75
+
76
+ loss = F.mse_loss(quantize.detach(), input) * self.commitment
77
+ quantize = input + (quantize - input).detach()
78
+
79
+ avg_probs = torch.mean(embed_onehot, dim=0)
80
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
81
+
82
+ return quantize, loss, perplexity
83
+
84
+ def forward_index(self, input):
85
+ dtype = input.dtype
86
+ flatten = input.reshape(-1, self.dim)
87
+ dist = (
88
+ flatten.pow(2).sum(1, keepdim=True)
89
+ - 2 * flatten @ self.embed
90
+ + self.embed.pow(2).sum(0, keepdim=True)
91
+ )
92
+ _, embed_ind = (-dist).max(1)
93
+ embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
94
+ embed_ind = embed_ind.view(*input.shape[:-1])
95
+ quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
96
+ quantize = input + (quantize - input).detach()
97
+
98
+ return quantize, embed_ind
99
+
100
+
101
+ class ResidualVQ(nn.Module):
102
+ """ Residual VQ following algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
103
+
104
+ def __init__(
105
+ self,
106
+ *,
107
+ num_quantizers,
108
+ **kwargs
109
+ ):
110
+ super().__init__()
111
+ self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)])
112
+
113
+ def forward(self, x):
114
+ quantized_out = 0.
115
+ residual = x
116
+ all_losses = []
117
+ all_perplexities = []
118
+ for layer in self.layers:
119
+ quantized, loss, perplexity = layer(residual)
120
+ # Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33
121
+ # We found considering only the 1st layer VQ's graident results in better performance
122
+ # residual = residual - quantized.detach() # considering all layers' graidents
123
+ residual = residual - quantized # considering only the first layer's graident
124
+ quantized_out = quantized_out + quantized
125
+ all_losses.append(loss)
126
+ all_perplexities.append(perplexity)
127
+ all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities))
128
+ return quantized_out, all_losses, all_perplexities
129
+
130
+ def forward_index(self, x, flatten_idx=False):
131
+ quantized_out = 0.
132
+ residual = x
133
+ all_indices = []
134
+ for i, layer in enumerate(self.layers):
135
+ quantized, indices = layer.forward_index(residual)
136
+ # residual = residual - quantized.detach()
137
+ residual = residual - quantized
138
+ quantized_out = quantized_out + quantized
139
+ if flatten_idx:
140
+ indices += (self.codebook_size * i)
141
+ all_indices.append(indices)
142
+ all_indices = torch.stack(all_indices)
143
+ return quantized_out, all_indices.squeeze(1)
144
+
145
+ def initial(self):
146
+ self.codebook = []
147
+ for layer in self.layers:
148
+ self.codebook.append(layer.codebook)
149
+ self.codebook_size = self.codebook[0].size(0)
150
+ self.codebook = torch.stack(self.codebook)
151
+ self.codebook = self.codebook.reshape(-1, self.codebook.size(-1))
152
+
153
+ def lookup(self, indices):
154
+ quantized_out = F.embedding(indices, self.codebook) # Num x T x C
155
+ return torch.sum(quantized_out, dim=0, keepdim=True)
RepCodec/repcodec/modules/decoder.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from RepCodec.repcodec.layers.conv_layer import Conv1d, ConvTranspose1d
12
+ from RepCodec.repcodec.modules.residual_unit import ResidualUnit
13
+
14
+
15
+ class DecoderBlock(nn.Module):
16
+ """ Decoder block (no up-sampling) """
17
+
18
+ def __init__(
19
+ self,
20
+ in_channels: int,
21
+ out_channels: int,
22
+ stride: int,
23
+ dilations=(1, 1),
24
+ unit_kernel_size=3,
25
+ bias=True
26
+ ):
27
+ super().__init__()
28
+
29
+ if stride == 1:
30
+ self.conv = Conv1d(
31
+ in_channels=in_channels,
32
+ out_channels=out_channels,
33
+ kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape
34
+ stride=stride,
35
+ bias=bias,
36
+ )
37
+ else:
38
+ self.conv = ConvTranspose1d(
39
+ in_channels=in_channels,
40
+ out_channels=out_channels,
41
+ kernel_size=(2 * stride),
42
+ stride=stride,
43
+ bias=bias,
44
+ )
45
+
46
+ self.res_units = torch.nn.ModuleList()
47
+ for idx, dilation in enumerate(dilations):
48
+ self.res_units += [
49
+ ResidualUnit(out_channels, out_channels,
50
+ kernel_size=unit_kernel_size,
51
+ dilation=dilation)
52
+ ]
53
+ self.num_res = len(self.res_units)
54
+
55
+ def forward(self, x):
56
+ x = self.conv(x)
57
+ for idx in range(self.num_res):
58
+ x = self.res_units[idx](x)
59
+ return x
60
+
61
+
62
+ class Decoder(nn.Module):
63
+ def __init__(
64
+ self,
65
+ code_dim: int,
66
+ output_channels: int,
67
+ decode_channels: int,
68
+ channel_ratios=(1, 1),
69
+ strides=(1, 1),
70
+ kernel_size=3,
71
+ bias=True,
72
+ block_dilations=(1, 1),
73
+ unit_kernel_size=3,
74
+ ):
75
+ super().__init__()
76
+ assert len(channel_ratios) == len(strides)
77
+
78
+ self.conv1 = Conv1d(
79
+ in_channels=code_dim,
80
+ out_channels=int(decode_channels * channel_ratios[0]),
81
+ kernel_size=kernel_size,
82
+ stride=1,
83
+ bias=False
84
+ )
85
+
86
+ self.conv_blocks = torch.nn.ModuleList()
87
+ for idx, stride in enumerate(strides):
88
+ in_channels = int(decode_channels * channel_ratios[idx])
89
+ if idx < (len(channel_ratios) - 1):
90
+ out_channels = int(decode_channels * channel_ratios[idx + 1])
91
+ else:
92
+ out_channels = decode_channels
93
+ self.conv_blocks += [
94
+ DecoderBlock(
95
+ in_channels, out_channels, stride,
96
+ dilations=block_dilations, unit_kernel_size=unit_kernel_size,
97
+ bias=bias
98
+ )
99
+ ]
100
+ self.num_blocks = len(self.conv_blocks)
101
+
102
+ self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
103
+
104
+ def forward(self, z):
105
+ x = self.conv1(z)
106
+ for i in range(self.num_blocks):
107
+ x = self.conv_blocks[i](x)
108
+ x = self.conv2(x)
109
+ return x
RepCodec/repcodec/modules/encoder.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from RepCodec.repcodec.layers.conv_layer import Conv1d
12
+ from RepCodec.repcodec.modules.residual_unit import ResidualUnit
13
+
14
+
15
+ class EncoderBlock(nn.Module):
16
+ def __init__(
17
+ self,
18
+ in_channels: int,
19
+ out_channels: int,
20
+ stride: int,
21
+ dilations=(1, 1),
22
+ unit_kernel_size=3,
23
+ bias=True
24
+ ):
25
+ super().__init__()
26
+ self.res_units = torch.nn.ModuleList()
27
+ for dilation in dilations:
28
+ self.res_units += [
29
+ ResidualUnit(in_channels, in_channels,
30
+ kernel_size=unit_kernel_size,
31
+ dilation=dilation)
32
+ ]
33
+ self.num_res = len(self.res_units)
34
+
35
+ self.conv = Conv1d(
36
+ in_channels=in_channels,
37
+ out_channels=out_channels,
38
+ kernel_size=3 if stride == 1 else (2 * stride), # special case: stride=1, do not use kernel=2
39
+ stride=stride,
40
+ bias=bias,
41
+ )
42
+
43
+ def forward(self, x):
44
+ for idx in range(self.num_res):
45
+ x = self.res_units[idx](x)
46
+ x = self.conv(x)
47
+ return x
48
+
49
+
50
+ class Encoder(nn.Module):
51
+ def __init__(
52
+ self,
53
+ input_channels: int,
54
+ encode_channels: int,
55
+ channel_ratios=(1, 1),
56
+ strides=(1, 1),
57
+ kernel_size=3,
58
+ bias=True,
59
+ block_dilations=(1, 1),
60
+ unit_kernel_size=3
61
+ ):
62
+ super().__init__()
63
+ assert len(channel_ratios) == len(strides)
64
+
65
+ self.conv = Conv1d(
66
+ in_channels=input_channels,
67
+ out_channels=encode_channels,
68
+ kernel_size=kernel_size,
69
+ stride=1,
70
+ bias=False
71
+ )
72
+ self.conv_blocks = torch.nn.ModuleList()
73
+ in_channels = encode_channels
74
+ for idx, stride in enumerate(strides):
75
+ out_channels = int(encode_channels * channel_ratios[idx]) # could be float
76
+ self.conv_blocks += [
77
+ EncoderBlock(in_channels, out_channels, stride,
78
+ dilations=block_dilations, unit_kernel_size=unit_kernel_size,
79
+ bias=bias)
80
+ ]
81
+ in_channels = out_channels
82
+ self.num_blocks = len(self.conv_blocks)
83
+ self.out_channels = out_channels
84
+
85
+ def forward(self, x):
86
+ x = self.conv(x)
87
+ for i in range(self.num_blocks):
88
+ x = self.conv_blocks[i](x)
89
+ return x
RepCodec/repcodec/modules/projector.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+ from repcodec.layers.conv_layer import Conv1d
11
+
12
+
13
+ class Projector(nn.Module):
14
+ def __init__(
15
+ self,
16
+ input_channels: int,
17
+ code_dim: int,
18
+ kernel_size=3,
19
+ stride=1,
20
+ bias=False
21
+ ):
22
+ super().__init__()
23
+ self.project = Conv1d(
24
+ input_channels,
25
+ code_dim,
26
+ kernel_size=kernel_size,
27
+ stride=stride,
28
+ bias=bias
29
+ )
30
+
31
+ def forward(self, x):
32
+ return self.project(x)
RepCodec/repcodec/modules/quantizer.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+ from repcodec.layers.vq_module import ResidualVQ
11
+
12
+
13
+ class Quantizer(nn.Module):
14
+ def __init__(
15
+ self,
16
+ code_dim: int,
17
+ codebook_num: int,
18
+ codebook_size: int,
19
+ ):
20
+ super().__init__()
21
+ self.codebook = ResidualVQ(
22
+ dim=code_dim,
23
+ num_quantizers=codebook_num,
24
+ codebook_size=codebook_size
25
+ )
26
+
27
+ def initial(self):
28
+ self.codebook.initial()
29
+
30
+ def forward(self, z):
31
+ zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
32
+ zq = zq.transpose(2, 1)
33
+ return zq, vqloss, perplexity
34
+
35
+ def inference(self, z):
36
+ zq, indices = self.codebook.forward_index(z.transpose(2, 1))
37
+ zq = zq.transpose(2, 1)
38
+ return zq, indices
39
+
40
+ def encode(self, z):
41
+ zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
42
+ return zq, indices
43
+
44
+ def decode(self, indices):
45
+ z = self.codebook.lookup(indices)
46
+ return z
RepCodec/repcodec/modules/residual_unit.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+ from RepCodec.repcodec.layers.conv_layer import Conv1d, Conv1d1x1
11
+
12
+
13
+ class ResidualUnit(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_channels: int,
17
+ out_channels: int,
18
+ kernel_size=3,
19
+ dilation=1,
20
+ bias=False,
21
+ nonlinear_activation="ELU",
22
+ nonlinear_activation_params={},
23
+ ):
24
+ super().__init__()
25
+ self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
26
+ self.conv1 = Conv1d(
27
+ in_channels=in_channels,
28
+ out_channels=out_channels,
29
+ kernel_size=kernel_size,
30
+ stride=1,
31
+ dilation=dilation,
32
+ bias=bias,
33
+ )
34
+ self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
35
+
36
+ def forward(self, x):
37
+ y = self.conv1(self.activation(x))
38
+ y = self.conv2(self.activation(y))
39
+ return x + y
RepCodec/repcodec/tokenize.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Tuple, List, Optional
11
+
12
+ import numpy as np
13
+ import torch
14
+ import yaml
15
+ from tqdm import tqdm
16
+
17
+ from repcodec.RepCodec import RepCodec
18
+
19
+ ALL_MODELS = {
20
+ "data2vec_base_l6": 768,
21
+ "data2vec_large_l18": 1024,
22
+ "hubert_base_l9": 768,
23
+ "hubert_large_l18": 1024,
24
+ "whisper_medium_l24": 1024,
25
+ "whisper_large_l32": 1280
26
+ }
27
+
28
+
29
+ def parse_args():
30
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
31
+ parser.add_argument(
32
+ "in_dir",
33
+ type=str,
34
+ help="directory of representations to be tokenized."
35
+ )
36
+ parser.add_argument(
37
+ "--model",
38
+ required=True,
39
+ type=str,
40
+ help="path of the RepCodec model."
41
+ )
42
+ parser.add_argument(
43
+ "--tsv_path",
44
+ required=True,
45
+ type=str,
46
+ help="path of the tsv file."
47
+ )
48
+ parser.add_argument(
49
+ "--model_config_path",
50
+ default=None,
51
+ type=str,
52
+ help="please provide this training config if you are using the model you trained yourself."
53
+ )
54
+ parser.add_argument(
55
+ "--n_shard",
56
+ required=False,
57
+ type=int,
58
+ default=1,
59
+ help="number of shards of representations."
60
+ )
61
+ parser.add_argument(
62
+ "--use_gpu",
63
+ default=False,
64
+ action="store_true",
65
+ help="whether use gpu for inference."
66
+ )
67
+ parser.add_argument(
68
+ "--batch_size",
69
+ default=1,
70
+ type=int,
71
+ help="number of utterances for each mini batch."
72
+ )
73
+ parser.add_argument(
74
+ "--out_dir",
75
+ type=str,
76
+ default=".",
77
+ help="the directory to save the output."
78
+ )
79
+ return parser.parse_args()
80
+
81
+
82
+ def load_model(model_path: str, config_path: Optional[str] = None):
83
+ if config_path is None:
84
+ name = os.path.basename(model_path).strip(".pkl")
85
+ assert name in ALL_MODELS.keys(), f"Cannot find configs for {model_path}. " \
86
+ f"Please provide the config file you used for training."
87
+ config = os.path.join(os.path.dirname(__file__), "configs", f"repcodec_dim{ALL_MODELS[name]}.yaml")
88
+ with open(config) as fp:
89
+ conf = yaml.load(fp, Loader=yaml.FullLoader)
90
+ else:
91
+ with open(config_path) as fp:
92
+ conf = yaml.load(fp, Loader=yaml.FullLoader)["model_params"]
93
+
94
+ model = RepCodec(**conf)
95
+ model.load_state_dict(torch.load(model_path, map_location="cpu")["model"]["repcodec"])
96
+ model.quantizer.initial()
97
+ model.eval()
98
+ return model
99
+
100
+
101
+ def load_shard(in_dir: Path, rank: int, n_shard: int) -> Tuple[np.ndarray, List[int]]:
102
+ feat_path = in_dir / f"{rank}_{n_shard}.npy"
103
+ len_path = in_dir / f"{rank}_{n_shard}.len"
104
+
105
+ with open(len_path) as fp:
106
+ lengths = [int(line.strip()) for line in fp]
107
+
108
+ return np.load(feat_path.as_posix(), mmap_mode="r"), lengths
109
+
110
+
111
+ def pad_data(data: List[np.ndarray]) -> List[np.ndarray]:
112
+ max_len = max([d.shape[0] for d in data])
113
+ data = [
114
+ np.pad(d, [(0, max_len - d.shape[0]), (0, 0)], "constant", constant_values=0.0)
115
+ for d in data
116
+ ]
117
+ return data
118
+
119
+
120
+ def make_batch_data(data: np.ndarray, shard_lengths: List[int], batch_size: int):
121
+ batch_data = []
122
+ batch_lens = []
123
+ offsets = np.cumsum([0] + shard_lengths)
124
+ assert len(data) == offsets[-1], f"{len(data)} {offsets[-1]}"
125
+
126
+ # from longest to shortest
127
+ for i in range(len(shard_lengths)):
128
+ if batch_size > len(batch_data):
129
+ batch_data.append(data[offsets[i]: offsets[i + 1]])
130
+ batch_lens.append(shard_lengths[i])
131
+ else:
132
+ yield {
133
+ "data": torch.tensor(np.stack(pad_data(batch_data)), dtype=torch.float), # (bsz, seq len, hidden dim)
134
+ "lengths": batch_lens
135
+ }
136
+ batch_data = [data[offsets[i]: offsets[i + 1]]]
137
+ batch_lens = [shard_lengths[i]]
138
+ if len(batch_data) > 0:
139
+ yield {
140
+ "data": torch.tensor(np.stack(pad_data(batch_data)), dtype=torch.float),
141
+ "lengths": batch_lens
142
+ }
143
+
144
+
145
+ def tokenize_batch(model: RepCodec, batch: dict, device: str) -> List[List[int]]:
146
+ with torch.no_grad():
147
+ data = batch["data"].transpose(1, 2).to(device) # (bsz, hidden dim, seq len)
148
+ x = model.encoder(data)
149
+ z = model.projector(x)
150
+ _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
151
+
152
+ # when bsz=1: (1, seq len)
153
+ if idx.dim() == 2:
154
+ return idx.cpu().data.numpy().tolist()
155
+ # when bsz>1: (1, bsz, seq len)
156
+ tokens = idx.cpu().data.numpy().tolist()[0]
157
+ res = []
158
+ batch_lens = batch["lengths"]
159
+ for i in range(len(tokens)):
160
+ n_tokens = batch_lens[i]
161
+ res.append(tokens[i][:n_tokens])
162
+ return res
163
+
164
+
165
+ def load_tsv(path: str):
166
+ with open(path) as fp:
167
+ root = fp.readline().strip()
168
+ names = []
169
+ for line in fp:
170
+ names.append(line.strip().split("\t")[0])
171
+ return root, names
172
+
173
+
174
+ def cli():
175
+ args = parse_args()
176
+ device = "cuda" if args.use_gpu else "cpu"
177
+
178
+ model = load_model(model_path=args.model, config_path=args.model_config_path)
179
+ model.to(device)
180
+
181
+ in_dir = Path(args.in_dir)
182
+ n_shard = args.n_shard
183
+ batch_size = args.batch_size
184
+
185
+ root_dir, file_names = load_tsv(args.tsv_path)
186
+
187
+ output_dir = args.out_dir
188
+ os.makedirs(output_dir, exist_ok=True)
189
+
190
+ processed_cnt = 0
191
+ pbar = tqdm(total=len(file_names))
192
+ with open(os.path.join(output_dir, "tokens"), mode="w+") as fp:
193
+ fp.write(f"{root_dir}\n")
194
+
195
+ for rank in range(n_shard):
196
+ shard_data, shard_lengths = load_shard(in_dir, rank, n_shard)
197
+ for batch in make_batch_data(shard_data, shard_lengths, batch_size=batch_size):
198
+ batch_tokens = tokenize_batch(model, batch, device)
199
+
200
+ for tokens in batch_tokens:
201
+ fp.write(f"{file_names[processed_cnt]}\t{' '.join(map(str, tokens))}\n")
202
+ processed_cnt += 1
203
+
204
+ pbar.update(len(batch_tokens))
205
+ assert processed_cnt == len(file_names), f"# lines of tsv do not match # of representations!"
206
+
207
+ pbar.close()
208
+ print("Tokenize successfully!")
209
+
210
+
211
+ if __name__ == '__main__':
212
+ cli()
RepCodec/setup.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from setuptools import setup
8
+
9
+ try:
10
+ with open("README.md") as fp:
11
+ long_description = fp.read()
12
+ except Exception:
13
+ long_description = ""
14
+
15
+ setup(
16
+ name="RepCodec",
17
+ version="v1.0.0",
18
+ description="A Speech Representation Codec for Speech Tokenization",
19
+ long_description=long_description,
20
+ url="https://github.com/mct10/RepCodec",
21
+ packages=["repcodec", "repcodec.modules", "repcodec.layers"],
22
+ package_data={
23
+ "repcodec": ["configs/*.yaml"]
24
+ },
25
+ install_requires=["numpy", "tqdm", "torch", "tensorboardX", "PyYAML"],
26
+ entry_points={
27
+ 'console_scripts': [
28
+ "repcodec=repcodec.tokenize:cli"
29
+ ]
30
+ }
31
+ )
RepCodec/train.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import argparse
9
+ import logging
10
+
11
+ import os
12
+
13
+ logging.basicConfig(
14
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
15
+ datefmt="%Y-%m-%d %H:%M:%S",
16
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
17
+ )
18
+ logger = logging.getLogger("repcodec_train") # init logger before other modules
19
+
20
+ import random
21
+
22
+ import numpy as np
23
+ import torch
24
+ import yaml
25
+ from torch.utils.data import DataLoader
26
+
27
+ from dataloader import ReprDataset, ReprCollater
28
+ from losses.repr_reconstruct_loss import ReprReconstructLoss
29
+ from repcodec.RepCodec import RepCodec
30
+ from trainer.autoencoder import Trainer
31
+
32
+
33
+ class TrainMain:
34
+ def __init__(self, args):
35
+ # Fix seed and make backends deterministic
36
+ random.seed(args.seed)
37
+ np.random.seed(args.seed)
38
+ torch.manual_seed(args.seed)
39
+ if not torch.cuda.is_available():
40
+ self.device = torch.device('cpu')
41
+ logger.info(f"device: cpu")
42
+ else:
43
+ self.device = torch.device('cuda:0') # only supports single gpu for now
44
+ logger.info(f"device: gpu")
45
+ torch.cuda.manual_seed_all(args.seed)
46
+ if args.disable_cudnn == "False":
47
+ torch.backends.cudnn.benchmark = True
48
+
49
+ # initialize config
50
+ with open(args.config, 'r') as f:
51
+ self.config = yaml.load(f, Loader=yaml.FullLoader)
52
+ self.config.update(vars(args))
53
+
54
+ # initialize model folder
55
+ expdir = os.path.join(args.exp_root, args.tag)
56
+ os.makedirs(expdir, exist_ok=True)
57
+ self.config["outdir"] = expdir
58
+
59
+ # save config
60
+ with open(os.path.join(expdir, "config.yml"), "w") as f:
61
+ yaml.dump(self.config, f, Dumper=yaml.Dumper)
62
+ for key, value in self.config.items():
63
+ logger.info(f"{key} = {value}")
64
+
65
+ # initialize attribute
66
+ self.resume: str = args.resume
67
+ self.data_loader = None
68
+ self.model = None
69
+ self.optimizer = None
70
+ self.scheduler = None
71
+ self.criterion = None
72
+ self.trainer = None
73
+
74
+ # initialize batch_length
75
+ self.batch_length: int = self.config['batch_length']
76
+ self.data_path: str = self.config['data']['path']
77
+
78
+ def initialize_data_loader(self):
79
+ train_set = self._build_dataset("train")
80
+ valid_set = self._build_dataset("valid")
81
+ collater = ReprCollater()
82
+
83
+ logger.info(f"The number of training files = {len(train_set)}.")
84
+ logger.info(f"The number of validation files = {len(valid_set)}.")
85
+ dataset = {"train": train_set, "dev": valid_set}
86
+ self._set_data_loader(dataset, collater)
87
+
88
+ def define_model_optimizer_scheduler(self):
89
+ # model arch
90
+ self.model = {
91
+ "repcodec": RepCodec(**self.config["model_params"]).to(self.device)
92
+ }
93
+ logger.info(f"Model Arch:\n{self.model['repcodec']}")
94
+
95
+ # opt
96
+ optimizer_class = getattr(
97
+ torch.optim,
98
+ self.config["model_optimizer_type"]
99
+ )
100
+ self.optimizer = {
101
+ "repcodec": optimizer_class(
102
+ self.model["repcodec"].parameters(),
103
+ **self.config["model_optimizer_params"]
104
+ )
105
+ }
106
+
107
+ # scheduler
108
+ scheduler_class = getattr(
109
+ torch.optim.lr_scheduler,
110
+ self.config.get("model_scheduler_type", "StepLR"),
111
+ )
112
+ self.scheduler = {
113
+ "repcodec": scheduler_class(
114
+ optimizer=self.optimizer["repcodec"],
115
+ **self.config["model_scheduler_params"]
116
+ )
117
+ }
118
+
119
+ def define_criterion(self):
120
+ self.criterion = {
121
+ "repr_reconstruct_loss": ReprReconstructLoss(
122
+ **self.config.get("repr_reconstruct_loss_params", {}),
123
+ ).to(self.device)
124
+ }
125
+
126
+ def define_trainer(self):
127
+ self.trainer = Trainer(
128
+ steps=0,
129
+ epochs=0,
130
+ data_loader=self.data_loader,
131
+ model=self.model,
132
+ criterion=self.criterion,
133
+ optimizer=self.optimizer,
134
+ scheduler=self.scheduler,
135
+ config=self.config,
136
+ device=self.device
137
+ )
138
+
139
+ def initialize_model(self):
140
+ initial = self.config.get("initial", "")
141
+ if os.path.exists(self.resume): # resume from trained model
142
+ self.trainer.load_checkpoint(self.resume)
143
+ logger.info(f"Successfully resumed from {self.resume}.")
144
+ elif os.path.exists(initial): # initialize a new model from a pre-trained model
145
+ self.trainer.load_checkpoint(initial, load_only_params=True)
146
+ logger.info(f"Successfully initialized parameters from {initial}.")
147
+ else:
148
+ logger.info("Train from scratch.")
149
+
150
+ def run(self):
151
+ assert self.trainer is not None
152
+ self.trainer: Trainer
153
+ try:
154
+ logger.info(f"The current training step: {self.trainer.steps}")
155
+ self.trainer.train_max_steps = self.config["train_max_steps"]
156
+ if not self.trainer._check_train_finish():
157
+ self.trainer.run()
158
+ finally:
159
+ self.trainer.save_checkpoint(
160
+ os.path.join(self.config["outdir"], f"checkpoint-{self.trainer.steps}steps.pkl")
161
+ )
162
+ logger.info(f"Successfully saved checkpoint @ {self.trainer.steps}steps.")
163
+
164
+ def _build_dataset(
165
+ self, subset: str
166
+ ) -> ReprDataset:
167
+ data_dir = os.path.join(
168
+ self.data_path, self.config['data']['subset'][subset]
169
+ )
170
+ params = {
171
+ "data_dir": data_dir,
172
+ "batch_len": self.batch_length
173
+ }
174
+ return ReprDataset(**params)
175
+
176
+ def _set_data_loader(self, dataset, collater):
177
+ self.data_loader = {
178
+ "train": DataLoader(
179
+ dataset=dataset["train"],
180
+ shuffle=True,
181
+ collate_fn=collater,
182
+ batch_size=self.config["batch_size"],
183
+ num_workers=self.config["num_workers"],
184
+ pin_memory=self.config["pin_memory"],
185
+ ),
186
+ "dev": DataLoader(
187
+ dataset=dataset["dev"],
188
+ shuffle=False,
189
+ collate_fn=collater,
190
+ batch_size=self.config["batch_size"],
191
+ num_workers=0,
192
+ pin_memory=False, # save some memory. set to True if you have enough memory.
193
+ ),
194
+ }
195
+
196
+
197
+ def train():
198
+ parser = argparse.ArgumentParser()
199
+ parser.add_argument(
200
+ "-c", "--config", type=str, required=True,
201
+ help="the path of config yaml file."
202
+ )
203
+ parser.add_argument(
204
+ "--tag", type=str, required=True,
205
+ help="the outputs will be saved to exp_root/tag/"
206
+ )
207
+ parser.add_argument(
208
+ "--exp_root", type=str, default="exp"
209
+ )
210
+ parser.add_argument(
211
+ "--resume", default="", type=str, nargs="?",
212
+ help='checkpoint file path to resume training. (default="")',
213
+ )
214
+ parser.add_argument("--seed", default=1337, type=int)
215
+ parser.add_argument("--disable_cudnn", choices=("True", "False"), default="False", help="Disable CUDNN")
216
+ args = parser.parse_args()
217
+
218
+ train_main = TrainMain(args)
219
+ train_main.initialize_data_loader()
220
+ train_main.define_model_optimizer_scheduler()
221
+ train_main.define_criterion()
222
+ train_main.define_trainer()
223
+ train_main.initialize_model()
224
+ train_main.run()
225
+
226
+
227
+ if __name__ == '__main__':
228
+ train()
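Usage sketch: given the argument parser above, a typical single-GPU run looks roughly like python RepCodec/train.py -c RepCodec/train_configs/ex_dim768_mse.yaml --tag ex_dim768_mse --exp_root exp (the tag and experiment root here are placeholders); checkpoints and a copy of the resolved config are then written under exp_root/tag/.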
RepCodec/train_configs/ex_dim768_mse.yaml ADDED
@@ -0,0 +1,74 @@
1
+ ###########################################################
2
+ # DATA SETTING #
3
+ ###########################################################
4
+ data:
5
+ path: "/dir/to/representations/"
6
+ subset:
7
+ train: "train_set_name"
8
+ valid: "valid_set_name"
9
+ test: "test_set_name"
10
+
11
+ ###########################################################
12
+ # MODEL SETTING #
13
+ ###########################################################
14
+ model_params:
15
+ input_channels: 768
16
+ output_channels: 768
17
+ encode_channels: 768
18
+ decode_channels: 768
19
+ code_dim: 768
20
+ codebook_num: 1
21
+ codebook_size: 1024
22
+ bias: true
23
+ enc_ratios: [1, 1]
24
+ dec_ratios: [1, 1]
25
+ enc_strides: [1, 1] # no downsampling
26
+ dec_strides: [1, 1]
27
+ enc_kernel_size: 3
28
+ dec_kernel_size: 3
29
+ enc_block_dilations: [1, 1]
30
+ enc_block_kernel_size: 3
31
+ dec_block_dilations: [1, 1]
32
+ dec_block_kernel_size: 3
33
+
34
+ ###########################################################
35
+ # METRIC LOSS SETTING #
36
+ ###########################################################
37
+ repr_reconstruct_loss_params:
38
+ loss_type: l2
39
+
40
+ ###########################################################
41
+ # LOSS WEIGHT SETTING #
42
+ ###########################################################
43
+ lambda_vq_loss: 1.0 # Loss weight of the vector quantization loss.
44
+ lambda_repr_reconstruct_loss: 45.0
45
+
46
+ ###########################################################
47
+ # DATA LOADER SETTING #
48
+ ###########################################################
49
+ batch_size: 32 # Batch size.
50
+ batch_length: 96 # Number of representation frames per training segment.
51
+ pin_memory: true # Whether to pin memory in PyTorch DataLoader.
52
+ num_workers: 4 # Number of workers in PyTorch DataLoader.
53
+
54
+ ###########################################################
55
+ # OPTIMIZER & SCHEDULER SETTING #
56
+ ###########################################################
57
+ model_optimizer_type: Adam
58
+ model_optimizer_params:
59
+ lr: 1.0e-4
60
+ betas: [0.5, 0.9]
61
+ weight_decay: 0.0
62
+ model_scheduler_type: StepLR
63
+ model_scheduler_params:
64
+ step_size: 200000 # Model's scheduler step size.
65
+ gamma: 1.0
66
+ grad_norm: -1
67
+
68
+ ###########################################################
69
+ # INTERVAL SETTING #
70
+ ###########################################################
71
+ train_max_steps: 200000 # Total number of training steps.
72
+ save_interval_steps: 20000 # Interval steps to save checkpoint.
73
+ eval_interval_steps: 2000 # Interval steps to evaluate the network.
74
+ log_interval_steps: 100 # Interval steps to record the training log.
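A rough sketch of how this config is consumed by train.py above; the config path is assumed to be relative to the repository root.

import yaml
import torch
from repcodec.RepCodec import RepCodec

with open("RepCodec/train_configs/ex_dim768_mse.yaml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# model_params is unpacked directly into RepCodec; the optimizer section maps onto torch.optim.Adam.
model = RepCodec(**config["model_params"])
optimizer_cls = getattr(torch.optim, config["model_optimizer_type"])
optimizer = optimizer_cls(model.parameters(), **config["model_optimizer_params"])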
RepCodec/trainer/autoencoder.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import logging
9
+ import os
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ from tensorboardX import SummaryWriter
14
+ from tqdm import tqdm
15
+
16
+ logger = logging.getLogger("repcodec_train")
17
+
18
+
19
+ class Trainer:
20
+ def __init__(
21
+ self,
22
+ steps: int,
23
+ epochs: int,
24
+ data_loader: dict,
25
+ model: dict,
26
+ criterion: dict,
27
+ optimizer: dict,
28
+ scheduler: dict,
29
+ config: dict,
30
+ device=torch.device("cpu"),
31
+ ):
32
+ self.steps = steps
33
+ self.epochs = epochs
34
+ self.data_loader = data_loader
35
+ self.model = model
36
+ self.criterion = criterion
37
+ self.optimizer = optimizer
38
+ self.scheduler = scheduler
39
+ self.config = config
40
+ self.device = device
41
+ self.writer = SummaryWriter(config["outdir"])
42
+ self.total_train_loss = defaultdict(float)
43
+ self.total_eval_loss = defaultdict(float)
44
+ self.train_max_steps = config.get("train_max_steps", 0)
45
+
46
+ def _train_step(self, batch):
47
+ """Single step of training."""
48
+ mode = "train"
49
+ x = batch
50
+ x = x.to(self.device)
51
+
52
+ codec_loss = 0.0
53
+ y_, zq, z, vqloss, perplexity = self.model["repcodec"](x)
54
+ self._perplexity(perplexity, mode=mode)
55
+ codec_loss += self._vq_loss(vqloss, mode=mode)
56
+ codec_loss += self._metric_loss(y_, x, mode=mode)
57
+
58
+ self._record_loss("codec_loss", codec_loss, mode=mode)
59
+ self._update_repcodec(codec_loss)
60
+
61
+ self.steps += 1
62
+ self.tqdm.update(1)
63
+ self._check_train_finish()
64
+
65
+ @torch.no_grad()
66
+ def _eval_step(self, batch):
67
+ """Single step of evaluation."""
68
+ mode = "eval"
69
+ x = batch
70
+ x = x.to(self.device)
71
+
72
+ codec_loss = 0.0
73
+ y_, zq, z, vqloss, perplexity = self.model["repcodec"](x)
74
+ self._perplexity(perplexity, mode=mode)
75
+ codec_loss += self._vq_loss(vqloss, mode=mode)
76
+ codec_loss += self._metric_loss(y_, x, mode=mode)
77
+
78
+ self._record_loss("codec_loss", codec_loss, mode=mode)
79
+
80
+ def run(self):
81
+ """Run training."""
82
+ self.finish_train = False
83
+ self.tqdm = tqdm(
84
+ initial=self.steps, total=self.train_max_steps, desc="[train]"
85
+ )
86
+ while True:
87
+ self._train_epoch()
88
+
89
+ # check whether training is finished
90
+ if self.finish_train:
91
+ break
92
+
93
+ self.tqdm.close()
94
+ logger.info("Finished training.")
95
+
96
+ def save_checkpoint(self, checkpoint_path: str):
97
+ state_dict = {
98
+ "model": {
99
+ "repcodec": self.model["repcodec"].state_dict()
100
+ },
101
+ "optimizer": {
102
+ "repcodec": self.optimizer["repcodec"].state_dict(),
103
+ },
104
+ "scheduler": {
105
+ "repcodec": self.scheduler["repcodec"].state_dict(),
106
+ },
107
+ "steps": self.steps,
108
+ "epochs": self.epochs,
109
+ }
110
+
111
+ if not os.path.exists(os.path.dirname(checkpoint_path)):
112
+ os.makedirs(os.path.dirname(checkpoint_path))
113
+ torch.save(state_dict, checkpoint_path)
114
+
115
+ def load_checkpoint(
116
+ self,
117
+ checkpoint_path: str,
118
+ strict: bool = True,
119
+ load_only_params: bool = False
120
+ ):
121
+ state_dict = torch.load(checkpoint_path, map_location="cpu")
122
+ self.model["repcodec"].load_state_dict(
123
+ state_dict["model"]["repcodec"], strict=strict
124
+ )
125
+
126
+ if not load_only_params:
127
+ self.steps = state_dict["steps"]
128
+ self.epochs = state_dict["epochs"]
129
+ self.optimizer["repcodec"].load_state_dict(
130
+ state_dict["optimizer"]["repcodec"]
131
+ )
132
+ self.scheduler["repcodec"].load_state_dict(
133
+ state_dict["scheduler"]["repcodec"]
134
+ )
135
+
136
+ def _train_epoch(self):
137
+ """One epoch of training."""
138
+ for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
139
+ # train one step
140
+ self._train_step(batch)
141
+
142
+ # check interval
143
+ self._check_log_interval()
144
+ self._check_eval_interval()
145
+ self._check_save_interval()
146
+
147
+ # check whether training is finished
148
+ if self.finish_train:
149
+ return
150
+
151
+ # update
152
+ self.epochs += 1
153
+ self.train_steps_per_epoch = train_steps_per_epoch
154
+ if train_steps_per_epoch > 200:
155
+ logger.info(
156
+ f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
157
+ f"({self.train_steps_per_epoch} steps per epoch)."
158
+ )
159
+
160
+ def _eval_epoch(self):
161
+ """One epoch of evaluation."""
162
+ logger.info(f"(Steps: {self.steps}) Start evaluation.")
163
+ # change mode
164
+ for key in self.model.keys():
165
+ self.model[key].eval()
166
+
167
+ # calculate loss for each batch
168
+ for eval_steps_per_epoch, batch in enumerate(
169
+ tqdm(self.data_loader["dev"], desc="[eval]"), 1
170
+ ):
171
+ # eval one step
172
+ self._eval_step(batch)
173
+
174
+ logger.info(
175
+ f"(Steps: {self.steps}) Finished evaluation "
176
+ f"({eval_steps_per_epoch} steps per epoch)."
177
+ )
178
+
179
+ # average loss
180
+ for key in self.total_eval_loss.keys():
181
+ self.total_eval_loss[key] /= eval_steps_per_epoch
182
+ logger.info(
183
+ f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
184
+ )
185
+
186
+ # record
187
+ self._write_to_tensorboard(self.total_eval_loss)
188
+
189
+ # reset
190
+ self.total_eval_loss = defaultdict(float)
191
+
192
+ # restore mode
193
+ for key in self.model.keys():
194
+ self.model[key].train()
195
+
196
+ def _metric_loss(self, predict_y, natural_y, mode='train'):
197
+ """Metric losses."""
198
+ metric_loss = 0.0
199
+
200
+ repr_reconstruct_loss = self.criterion["repr_reconstruct_loss"](predict_y, natural_y)
201
+ repr_reconstruct_loss *= self.config["lambda_repr_reconstruct_loss"]
202
+ self._record_loss("reconstruct_loss", repr_reconstruct_loss, mode=mode)
203
+ metric_loss += repr_reconstruct_loss
204
+
205
+ return metric_loss
206
+
207
+ def _update_repcodec(self, repr_loss):
209
+ """Update the RepCodec model."""
209
+ self.optimizer["repcodec"].zero_grad()
210
+ repr_loss.backward()
211
+ if self.config["grad_norm"] > 0:
212
+ torch.nn.utils.clip_grad_norm_(
213
+ self.model["repcodec"].parameters(),
214
+ self.config["grad_norm"],
215
+ )
216
+ self.optimizer["repcodec"].step()
217
+ self.scheduler["repcodec"].step()
218
+
219
+ def _record_loss(self, name: str, loss, mode='train'):
220
+ """Record loss."""
221
+ if torch.is_tensor(loss):
222
+ loss = loss.item()
223
+
224
+ if mode == 'train':
225
+ self.total_train_loss[f"train/{name}"] += loss
226
+ elif mode == 'eval':
227
+ self.total_eval_loss[f"eval/{name}"] += loss
228
+ else:
229
+ raise NotImplementedError(f"Mode ({mode}) is not supported!")
230
+
231
+ def _write_to_tensorboard(self, loss):
232
+ """Write to tensorboard."""
233
+ for key, value in loss.items():
234
+ self.writer.add_scalar(key, value, self.steps)
235
+
236
+ def _check_save_interval(self):
237
+ if self.steps and (self.steps % self.config["save_interval_steps"] == 0):
238
+ self.save_checkpoint(
239
+ os.path.join(self.config["outdir"], f"checkpoint-{self.steps}steps.pkl")
240
+ )
241
+ logger.info(f"Successfully saved checkpoint @ {self.steps} steps.")
242
+
243
+ def _check_eval_interval(self):
244
+ if self.steps % self.config["eval_interval_steps"] == 0:
245
+ self._eval_epoch()
246
+
247
+ def _check_log_interval(self):
248
+ if self.steps % self.config["log_interval_steps"] == 0:
249
+ for key in self.total_train_loss.keys():
250
+ self.total_train_loss[key] /= self.config["log_interval_steps"]
251
+ logger.info(
252
+ f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
253
+ )
254
+ self._write_to_tensorboard(self.total_train_loss)
255
+
256
+ # reset
257
+ self.total_train_loss = defaultdict(float)
258
+
259
+ def _check_train_finish(self):
260
+ if self.steps >= self.train_max_steps:
261
+ self.finish_train = True
262
+ else:
263
+ self.finish_train = False
264
+ return self.finish_train
265
+
266
+ def _perplexity(self, perplexity, label=None, mode='train'):
267
+ if label:
268
+ name = f"{mode}/ppl_{label}"
269
+ else:
270
+ name = f"{mode}/ppl"
271
+ if torch.numel(perplexity) > 1:
272
+ perplexity = perplexity.tolist()
273
+ for idx, ppl in enumerate(perplexity):
274
+ self._record_loss(f"{name}_{idx}", ppl, mode=mode)
275
+ else:
276
+ self._record_loss(name, perplexity, mode=mode)
277
+
278
+ def _vq_loss(self, vqloss, label=None, mode='train'):
279
+ if label:
280
+ name = f"{mode}/vqloss_{label}"
281
+ else:
282
+ name = f"{mode}/vqloss"
283
+ vqloss = torch.sum(vqloss)
284
+ vqloss *= self.config["lambda_vq_loss"]
285
+ self._record_loss(name, vqloss, mode=mode)
286
+
287
+ return vqloss
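For reference, a minimal sketch of inspecting a checkpoint produced by save_checkpoint() above; the checkpoint path is a placeholder.

import torch

ckpt = torch.load("exp/ex_dim768_mse/checkpoint-200000steps.pkl", map_location="cpu")
print(ckpt["steps"], ckpt["epochs"])        # training progress saved alongside the weights
repcodec_state = ckpt["model"]["repcodec"]  # RepCodec weights; optimizer/scheduler states sit under the same top-level keys
# A freshly built RepCodec(**config["model_params"]) can then load_state_dict(repcodec_state).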
__pycache__/post_process_audio.cpython-310.pyc ADDED
Binary file (3.1 kB). View file
 
__pycache__/vocoder.cpython-310.pyc ADDED
Binary file (5.75 kB). View file
 
decoders/config.yaml ADDED
@@ -0,0 +1,15 @@
1
+ backbone:
2
+ class_path: vocos.models.VocosBackbone
3
+ init_args:
4
+ input_channels: 1024
5
+ dim: 512
6
+ intermediate_dim: 1536
7
+ num_layers: 8
8
+
9
+ head:
10
+ class_path: vocos.heads.ISTFTHead
11
+ init_args:
12
+ dim: 512
13
+ n_fft: 3528
14
+ hop_length: 882
15
+ padding: same
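This file follows the vocos class_path/init_args layout; below is a minimal sketch of building the corresponding modules, assuming the standard vocos package API (how the decoder_*.pth checkpoints that follow are loaded into these modules is not shown here).

import yaml
import torch
from vocos.models import VocosBackbone
from vocos.heads import ISTFTHead

with open("decoders/config.yaml") as f:
    cfg = yaml.safe_load(f)

backbone = VocosBackbone(**cfg["backbone"]["init_args"])  # 1024-dim input features -> 512-dim hidden states
head = ISTFTHead(**cfg["head"]["init_args"])              # inverse-STFT head (n_fft 3528, hop_length 882)

feats = torch.randn(1, 1024, 100)  # (batch, input_channels, frames)
audio = head(backbone(feats))      # waveform output; weights are random here until a decoder checkpoint is loaded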
decoders/decoder_131000.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b99f0be84eeef3a32f29cd55beb89727fd0b2fd0df3dbad3023508f4c7185c37
3
+ size 72611958
decoders/decoder_151000.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8af97a29d3483f9d4a3755992837501bd7d6caa1a69382ed16e64039e0ea0998
3
+ size 72610550
descriptaudiocodec/dac/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ __version__ = "1.0.0"
2
+
3
+ # preserved here for legacy reasons
4
+ __model_version__ = "latest"
5
+
6
+ import audiotools
7
+
8
+ audiotools.ml.BaseModel.INTERN += ["dac.**"]
9
+ audiotools.ml.BaseModel.EXTERN += ["einops"]
10
+
11
+
12
+ from . import nn
13
+ from . import model
14
+ from . import utils
15
+ from .model import DAC
16
+ from .model import DACFile
descriptaudiocodec/dac/__main__.py ADDED
@@ -0,0 +1,36 @@
1
+ import sys
2
+
3
+ import argbind
4
+
5
+ from dac.utils import download
6
+ from dac.utils.decode import decode
7
+ from dac.utils.encode import encode
8
+
9
+ STAGES = ["encode", "decode", "download"]
10
+
11
+
12
+ def run(stage: str):
13
+ """Run stages.
14
+
15
+ Parameters
16
+ ----------
17
+ stage : str
18
+ Stage to run
19
+ """
20
+ if stage not in STAGES:
21
+ raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22
+ stage_fn = globals()[stage]
23
+
24
+ if stage == "download":
25
+ stage_fn()
26
+ return
27
+
28
+ stage_fn()
29
+
30
+
31
+ if __name__ == "__main__":
32
+ group = sys.argv.pop(1)
33
+ args = argbind.parse_args(group=group)
34
+
35
+ with argbind.scope(args):
36
+ run(group)
descriptaudiocodec/dac/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (516 Bytes). View file
 
descriptaudiocodec/dac/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (488 Bytes). View file
 
descriptaudiocodec/dac/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (489 Bytes). View file
 
descriptaudiocodec/dac/compare/__init__.py ADDED
File without changes
descriptaudiocodec/dac/compare/encodec.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ from audiotools import AudioSignal
3
+ from audiotools.ml import BaseModel
4
+ from encodec import EncodecModel
5
+
6
+
7
+ class Encodec(BaseModel):
8
+ def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0):
9
+ super().__init__()
10
+
11
+ if sample_rate == 24000:
12
+ self.model = EncodecModel.encodec_model_24khz()
13
+ else:
14
+ self.model = EncodecModel.encodec_model_48khz()
15
+ self.model.set_target_bandwidth(bandwidth)
16
+ self.sample_rate = 44100
17
+
18
+ def forward(
19
+ self,
20
+ audio_data: torch.Tensor,
21
+ sample_rate: int = 44100,
22
+ n_quantizers: int = None,
23
+ ):
24
+ signal = AudioSignal(audio_data, sample_rate)
25
+ signal.resample(self.model.sample_rate)
26
+ recons = self.model(signal.audio_data)
27
+ recons = AudioSignal(recons, self.model.sample_rate)
28
+ recons.resample(sample_rate)
29
+ return {"audio": recons.audio_data}
30
+
31
+
32
+ if __name__ == "__main__":
33
+ import numpy as np
34
+ from functools import partial
35
+
36
+ model = Encodec()
37
+
38
+ for n, m in model.named_modules():
39
+ o = m.extra_repr()
40
+ p = sum([np.prod(p.size()) for p in m.parameters()])
41
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
42
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
43
+ print(model)
44
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
45
+
46
+ length = 88200 * 2
47
+ x = torch.randn(1, 1, length).to(model.device)
48
+ x.requires_grad_(True)
49
+ x.retain_grad()
50
+
51
+ # Make a forward pass
52
+ out = model(x)["audio"]
53
+
54
+ print(x.shape, out.shape)
descriptaudiocodec/dac/model/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .base import CodecMixin
2
+ from .base import DACFile
3
+ from .dac import DAC
4
+ from .discriminator import Discriminator
descriptaudiocodec/dac/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (357 Bytes). View file
 
descriptaudiocodec/dac/model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (330 Bytes). View file
 
descriptaudiocodec/dac/model/__pycache__/base.cpython-310.pyc ADDED
Binary file (7.26 kB). View file
 
descriptaudiocodec/dac/model/__pycache__/base.cpython-39.pyc ADDED
Binary file (7.18 kB). View file
 
descriptaudiocodec/dac/model/__pycache__/dac.cpython-310.pyc ADDED
Binary file (10.8 kB). View file